major update
This commit is contained in:
433
tests/unit/test_rr_scanner_preservation.py
Normal file
433
tests/unit/test_rr_scanner_preservation.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Preservation tests for R:R scanner target quality bugfix.
|
||||
|
||||
Verify that the fix does NOT change behavior for zero-candidate and
|
||||
single-candidate scenarios, and that get_trade_setups sorting is unchanged.
|
||||
|
||||
The fix only changes selection logic when MULTIPLE candidates exist.
|
||||
Zero-candidate and single-candidate scenarios must produce identical results.
|
||||
|
||||
**Validates: Requirements 3.1, 3.2, 3.3, 3.5**
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings, HealthCheck, strategies as st
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.ohlcv import OHLCVRecord
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.ticker import Ticker
|
||||
from app.models.trade_setup import TradeSetup
|
||||
from app.models.score import CompositeScore
|
||||
from app.services.rr_scanner_service import scan_ticker, get_trade_setups
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
async def scan_session() -> AsyncSession:
|
||||
"""Provide a DB session compatible with scan_ticker (which commits)."""
|
||||
from tests.conftest import _test_session_factory
|
||||
|
||||
async with _test_session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session() -> AsyncSession:
|
||||
"""Provide a transactional DB session for get_trade_setups tests."""
|
||||
from tests.conftest import _test_session_factory
|
||||
|
||||
async with _test_session_factory() as session:
|
||||
async with session.begin():
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_ohlcv_bars(
|
||||
ticker_id: int,
|
||||
num_bars: int = 20,
|
||||
base_close: float = 100.0,
|
||||
) -> list[OHLCVRecord]:
|
||||
"""Generate OHLCV bars closing around base_close with ATR ≈ 2.0."""
|
||||
bars: list[OHLCVRecord] = []
|
||||
start = date(2024, 1, 1)
|
||||
for i in range(num_bars):
|
||||
close = base_close + (i % 3 - 1) * 0.5 # oscillate ±0.5
|
||||
bars.append(OHLCVRecord(
|
||||
ticker_id=ticker_id,
|
||||
date=start + timedelta(days=i),
|
||||
open=close - 0.3,
|
||||
high=close + 1.0,
|
||||
low=close - 1.0,
|
||||
close=close,
|
||||
volume=100_000,
|
||||
))
|
||||
return bars
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 7.1 [PBT-preservation] Property test: zero-candidate and single-candidate
|
||||
# scenarios produce the same output as the original code.
|
||||
# ===========================================================================
|
||||
|
||||
@st.composite
|
||||
def zero_candidate_scenario(draw: st.DrawFn) -> dict:
|
||||
"""Generate a scenario where no S/R levels qualify as candidates.
|
||||
|
||||
Variants:
|
||||
- No SR levels at all
|
||||
- All levels below entry (no long targets) and all above entry (no short targets)
|
||||
but all below the R:R threshold for their respective directions
|
||||
- Levels in the right direction but below R:R threshold
|
||||
|
||||
Note: scan_ticker does NOT filter by SR level type — it only checks whether
|
||||
the price_level is above or below entry. So "wrong side" means all levels
|
||||
are clustered near entry and below threshold in both directions.
|
||||
"""
|
||||
variant = draw(st.sampled_from(["no_levels", "below_threshold"]))
|
||||
|
||||
if variant == "no_levels":
|
||||
return {"variant": variant, "levels": []}
|
||||
|
||||
else: # below_threshold
|
||||
# All levels close to entry so R:R < 1.5 with risk ≈ 3
|
||||
# For longs: reward < 4.5 → price < 104.5
|
||||
# For shorts: reward < 4.5 → price > 95.5
|
||||
# Place all levels in the 96–104 band (below threshold both ways)
|
||||
num = draw(st.integers(min_value=1, max_value=3))
|
||||
levels = []
|
||||
for _ in range(num):
|
||||
price = draw(st.floats(min_value=100.5, max_value=103.5))
|
||||
levels.append({
|
||||
"price": price,
|
||||
"type": draw(st.sampled_from(["resistance", "support"])),
|
||||
"strength": draw(st.integers(min_value=10, max_value=100)),
|
||||
})
|
||||
# Also add some below entry but still below threshold
|
||||
for _ in range(draw(st.integers(min_value=0, max_value=2))):
|
||||
price = draw(st.floats(min_value=96.5, max_value=99.5))
|
||||
levels.append({
|
||||
"price": price,
|
||||
"type": draw(st.sampled_from(["resistance", "support"])),
|
||||
"strength": draw(st.integers(min_value=10, max_value=100)),
|
||||
})
|
||||
return {"variant": variant, "levels": levels}
|
||||
|
||||
|
||||
@st.composite
|
||||
def single_candidate_scenario(draw: st.DrawFn) -> dict:
|
||||
"""Generate a scenario with exactly one S/R level that meets the R:R threshold.
|
||||
|
||||
For longs: one resistance above entry with R:R >= 1.5 (price >= 104.5 with risk ≈ 3).
|
||||
"""
|
||||
direction = draw(st.sampled_from(["long", "short"]))
|
||||
|
||||
if direction == "long":
|
||||
# Single resistance above entry meeting threshold
|
||||
price = draw(st.floats(min_value=105.0, max_value=150.0))
|
||||
strength = draw(st.integers(min_value=1, max_value=100))
|
||||
return {
|
||||
"direction": direction,
|
||||
"level": {"price": price, "type": "resistance", "strength": strength},
|
||||
}
|
||||
else:
|
||||
# Single support below entry meeting threshold
|
||||
price = draw(st.floats(min_value=50.0, max_value=95.0))
|
||||
strength = draw(st.integers(min_value=1, max_value=100))
|
||||
return {
|
||||
"direction": direction,
|
||||
"level": {"price": price, "type": "support", "strength": strength},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@given(scenario=zero_candidate_scenario())
|
||||
@settings(
|
||||
max_examples=15,
|
||||
deadline=None,
|
||||
suppress_health_check=[HealthCheck.function_scoped_fixture],
|
||||
)
|
||||
async def test_property_zero_candidates_produce_no_setup(
|
||||
scenario: dict,
|
||||
scan_session: AsyncSession,
|
||||
):
|
||||
"""**Validates: Requirements 3.1, 3.2**
|
||||
|
||||
Property: when zero candidate S/R levels exist (no levels, wrong side,
|
||||
or below threshold), scan_ticker produces no setup — unchanged from
|
||||
original behavior.
|
||||
"""
|
||||
from tests.conftest import _test_engine, _test_session_factory
|
||||
from app.database import Base
|
||||
|
||||
async with _test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with _test_session_factory() as session:
|
||||
ticker = Ticker(symbol="PRSV0")
|
||||
session.add(ticker)
|
||||
await session.flush()
|
||||
|
||||
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
|
||||
session.add_all(bars)
|
||||
|
||||
for lv_data in scenario.get("levels", []):
|
||||
session.add(SRLevel(
|
||||
ticker_id=ticker.id,
|
||||
price_level=lv_data["price"],
|
||||
type=lv_data["type"],
|
||||
strength=lv_data["strength"],
|
||||
detection_method="volume_profile",
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
setups = await scan_ticker(session, "PRSV0", rr_threshold=1.5)
|
||||
|
||||
assert setups == [], (
|
||||
f"Expected no setups for zero-candidate scenario "
|
||||
f"(variant={scenario.get('variant', 'unknown')}), got {len(setups)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@given(scenario=single_candidate_scenario())
|
||||
@settings(
|
||||
max_examples=15,
|
||||
deadline=None,
|
||||
suppress_health_check=[HealthCheck.function_scoped_fixture],
|
||||
)
|
||||
async def test_property_single_candidate_selected_unchanged(
|
||||
scenario: dict,
|
||||
scan_session: AsyncSession,
|
||||
):
|
||||
"""**Validates: Requirements 3.3**
|
||||
|
||||
Property: when exactly one candidate S/R level meets the R:R threshold,
|
||||
scan_ticker selects it — same as the original code would.
|
||||
"""
|
||||
from tests.conftest import _test_engine, _test_session_factory
|
||||
from app.database import Base
|
||||
|
||||
async with _test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with _test_session_factory() as session:
|
||||
ticker = Ticker(symbol="PRSV1")
|
||||
session.add(ticker)
|
||||
await session.flush()
|
||||
|
||||
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
|
||||
session.add_all(bars)
|
||||
|
||||
lv = scenario["level"]
|
||||
session.add(SRLevel(
|
||||
ticker_id=ticker.id,
|
||||
price_level=lv["price"],
|
||||
type=lv["type"],
|
||||
strength=lv["strength"],
|
||||
detection_method="volume_profile",
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
setups = await scan_ticker(session, "PRSV1", rr_threshold=1.5)
|
||||
|
||||
direction = scenario["direction"]
|
||||
dir_setups = [s for s in setups if s.direction == direction]
|
||||
assert len(dir_setups) == 1, (
|
||||
f"Expected exactly one {direction} setup for single candidate, "
|
||||
f"got {len(dir_setups)}"
|
||||
)
|
||||
|
||||
selected_target = dir_setups[0].target
|
||||
expected_target = round(lv["price"], 4)
|
||||
assert selected_target == pytest.approx(expected_target, abs=0.01), (
|
||||
f"Single candidate: expected target={expected_target}, "
|
||||
f"got {selected_target}"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 7.2 Unit test: no S/R levels → no setup produced
|
||||
# ===========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_sr_levels_produces_no_setup(scan_session: AsyncSession):
|
||||
"""**Validates: Requirements 3.1**
|
||||
|
||||
When a ticker has OHLCV data but no S/R levels at all,
|
||||
scan_ticker should return an empty list.
|
||||
"""
|
||||
ticker = Ticker(symbol="NOSRL")
|
||||
scan_session.add(ticker)
|
||||
await scan_session.flush()
|
||||
|
||||
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
|
||||
scan_session.add_all(bars)
|
||||
await scan_session.flush()
|
||||
|
||||
setups = await scan_ticker(scan_session, "NOSRL", rr_threshold=1.5)
|
||||
|
||||
assert setups == [], (
|
||||
f"Expected no setups when no SR levels exist, got {len(setups)}"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 7.3 Unit test: single candidate meets threshold → selected
|
||||
# ===========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_resistance_above_threshold_selected(scan_session: AsyncSession):
|
||||
"""**Validates: Requirements 3.3**
|
||||
|
||||
When exactly one resistance level above entry meets the R:R threshold,
|
||||
it should be selected as the long setup target.
|
||||
|
||||
Entry ≈ 100, ATR ≈ 2, risk ≈ 3. Resistance at 110 → R:R ≈ 3.33 (>= 1.5).
|
||||
"""
|
||||
ticker = Ticker(symbol="SINGL")
|
||||
scan_session.add(ticker)
|
||||
await scan_session.flush()
|
||||
|
||||
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
|
||||
scan_session.add_all(bars)
|
||||
|
||||
level = SRLevel(
|
||||
ticker_id=ticker.id,
|
||||
price_level=110.0,
|
||||
type="resistance",
|
||||
strength=60,
|
||||
detection_method="volume_profile",
|
||||
)
|
||||
scan_session.add(level)
|
||||
await scan_session.flush()
|
||||
|
||||
setups = await scan_ticker(scan_session, "SINGL", rr_threshold=1.5)
|
||||
|
||||
long_setups = [s for s in setups if s.direction == "long"]
|
||||
assert len(long_setups) == 1, (
|
||||
f"Expected exactly one long setup, got {len(long_setups)}"
|
||||
)
|
||||
assert long_setups[0].target == pytest.approx(110.0, abs=0.01), (
|
||||
f"Expected target=110.0, got {long_setups[0].target}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_support_below_threshold_selected(scan_session: AsyncSession):
|
||||
"""**Validates: Requirements 3.3**
|
||||
|
||||
When exactly one support level below entry meets the R:R threshold,
|
||||
it should be selected as the short setup target.
|
||||
|
||||
Entry ≈ 100, ATR ≈ 2, risk ≈ 3. Support at 90 → R:R ≈ 3.33 (>= 1.5).
|
||||
"""
|
||||
ticker = Ticker(symbol="SINGS")
|
||||
scan_session.add(ticker)
|
||||
await scan_session.flush()
|
||||
|
||||
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
|
||||
scan_session.add_all(bars)
|
||||
|
||||
level = SRLevel(
|
||||
ticker_id=ticker.id,
|
||||
price_level=90.0,
|
||||
type="support",
|
||||
strength=55,
|
||||
detection_method="pivot_point",
|
||||
)
|
||||
scan_session.add(level)
|
||||
await scan_session.flush()
|
||||
|
||||
setups = await scan_ticker(scan_session, "SINGS", rr_threshold=1.5)
|
||||
|
||||
short_setups = [s for s in setups if s.direction == "short"]
|
||||
assert len(short_setups) == 1, (
|
||||
f"Expected exactly one short setup, got {len(short_setups)}"
|
||||
)
|
||||
assert short_setups[0].target == pytest.approx(90.0, abs=0.01), (
|
||||
f"Expected target=90.0, got {short_setups[0].target}"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 7.4 Unit test: get_trade_setups sorting is unchanged (R:R desc, composite desc)
|
||||
# ===========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_trade_setups_sorting_rr_desc_composite_desc(db_session: AsyncSession):
|
||||
"""**Validates: Requirements 3.5**
|
||||
|
||||
get_trade_setups must return results sorted by rr_ratio descending,
|
||||
with composite_score descending as the secondary sort key.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Create tickers for each setup
|
||||
ticker_a = Ticker(symbol="SORTA")
|
||||
ticker_b = Ticker(symbol="SORTB")
|
||||
ticker_c = Ticker(symbol="SORTC")
|
||||
ticker_d = Ticker(symbol="SORTD")
|
||||
db_session.add_all([ticker_a, ticker_b, ticker_c, ticker_d])
|
||||
await db_session.flush()
|
||||
|
||||
# Create setups with different rr_ratio and composite_score values
|
||||
# Expected order: D (rr=5.0), C (rr=3.0, comp=80), B (rr=3.0, comp=50), A (rr=1.5)
|
||||
setup_a = TradeSetup(
|
||||
ticker_id=ticker_a.id, direction="long",
|
||||
entry_price=100.0, stop_loss=97.0, target=104.5,
|
||||
rr_ratio=1.5, composite_score=90.0, detected_at=now,
|
||||
)
|
||||
setup_b = TradeSetup(
|
||||
ticker_id=ticker_b.id, direction="long",
|
||||
entry_price=100.0, stop_loss=97.0, target=109.0,
|
||||
rr_ratio=3.0, composite_score=50.0, detected_at=now,
|
||||
)
|
||||
setup_c = TradeSetup(
|
||||
ticker_id=ticker_c.id, direction="short",
|
||||
entry_price=100.0, stop_loss=103.0, target=91.0,
|
||||
rr_ratio=3.0, composite_score=80.0, detected_at=now,
|
||||
)
|
||||
setup_d = TradeSetup(
|
||||
ticker_id=ticker_d.id, direction="long",
|
||||
entry_price=100.0, stop_loss=97.0, target=115.0,
|
||||
rr_ratio=5.0, composite_score=30.0, detected_at=now,
|
||||
)
|
||||
db_session.add_all([setup_a, setup_b, setup_c, setup_d])
|
||||
await db_session.flush()
|
||||
|
||||
results = await get_trade_setups(db_session)
|
||||
|
||||
assert len(results) == 4, f"Expected 4 setups, got {len(results)}"
|
||||
|
||||
# Verify ordering: rr_ratio desc, then composite_score desc
|
||||
rr_values = [r["rr_ratio"] for r in results]
|
||||
assert rr_values == [5.0, 3.0, 3.0, 1.5], (
|
||||
f"Expected rr_ratio order [5.0, 3.0, 3.0, 1.5], got {rr_values}"
|
||||
)
|
||||
|
||||
# For the two setups with rr_ratio=3.0, composite_score should be desc
|
||||
tied_composites = [r["composite_score"] for r in results if r["rr_ratio"] == 3.0]
|
||||
assert tied_composites == [80.0, 50.0], (
|
||||
f"Expected composite_score order [80.0, 50.0] for tied R:R, "
|
||||
f"got {tied_composites}"
|
||||
)
|
||||
|
||||
# Verify symbols match expected order
|
||||
symbols = [r["symbol"] for r in results]
|
||||
assert symbols == ["SORTD", "SORTC", "SORTB", "SORTA"], (
|
||||
f"Expected symbol order ['SORTD', 'SORTC', 'SORTB', 'SORTA'], "
|
||||
f"got {symbols}"
|
||||
)
|
||||
Reference in New Issue
Block a user