434 lines
15 KiB
Python
434 lines
15 KiB
Python
"""Preservation tests for R:R scanner target quality bugfix.

Verify that the fix does NOT change behavior for zero-candidate and
single-candidate scenarios, and that get_trade_setups sorting is unchanged.

The fix only changes selection logic when MULTIPLE candidates exist.
Zero-candidate and single-candidate scenarios must produce identical results.

**Validates: Requirements 3.1, 3.2, 3.3, 3.5**
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from datetime import date, datetime, timedelta, timezone
|
||
|
||
import pytest
|
||
from hypothesis import given, settings, HealthCheck, strategies as st
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.models.ohlcv import OHLCVRecord
|
||
from app.models.sr_level import SRLevel
|
||
from app.models.ticker import Ticker
|
||
from app.models.trade_setup import TradeSetup
|
||
from app.models.score import CompositeScore
|
||
from app.services.rr_scanner_service import scan_ticker, get_trade_setups
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Session fixtures
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@pytest.fixture
async def scan_session() -> AsyncSession:
    """Yield a session from the test session factory for scan_ticker tests.

    No explicit transaction is opened by this fixture; the session is handed
    to the test as-is and closed when the ``async with`` block exits.
    Intended for code paths such as scan_ticker, which (per the original
    note) commits on its own.
    """
    # Imported inside the fixture rather than at module level.
    from tests.conftest import _test_session_factory

    async with _test_session_factory() as db:
        yield db
|
||
|
||
|
||
@pytest.fixture
async def db_session() -> AsyncSession:
    """Yield a session wrapped in an explicit transaction.

    Used by the get_trade_setups tests. A transaction is begun before the
    session is handed to the test and ``rollback()`` is called once the test
    returns control to the fixture.
    """
    # Imported inside the fixture rather than at module level.
    from tests.conftest import _test_session_factory

    # Both context managers are entered in order: session first, then the
    # transaction opened on that session.
    async with _test_session_factory() as txn_session, txn_session.begin():
        yield txn_session
        await txn_session.rollback()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _make_ohlcv_bars(
    ticker_id: int,
    num_bars: int = 20,
    base_close: float = 100.0,
) -> list[OHLCVRecord]:
    """Build ``num_bars`` consecutive daily OHLCV records for one ticker.

    Closes cycle through ``base_close - 0.5``, ``base_close`` and
    ``base_close + 0.5``. Every bar has ``high = close + 1.0`` and
    ``low = close - 1.0`` (a 2.0 high/low range), ``open = close - 0.3`` and
    a fixed volume of 100,000. Dates start at 2024-01-01 and advance one
    calendar day per bar.

    Args:
        ticker_id: Value stored in each record's ``ticker_id`` field.
        num_bars: Number of records to generate.
        base_close: Centre of the close-price oscillation.

    Returns:
        The generated records in chronological order.
    """
    first_day = date(2024, 1, 1)
    # (idx % 3 - 1) walks -1, 0, +1, so the close moves ±0.5 around base_close.
    closes = [base_close + (idx % 3 - 1) * 0.5 for idx in range(num_bars)]
    return [
        OHLCVRecord(
            ticker_id=ticker_id,
            date=first_day + timedelta(days=offset),
            open=px - 0.3,
            high=px + 1.0,
            low=px - 1.0,
            close=px,
            volume=100_000,
        )
        for offset, px in enumerate(closes)
    ]
|
||
|
||
|
||
# ===========================================================================
|
||
# 7.1 [PBT-preservation] Property test: zero-candidate and single-candidate
|
||
# scenarios produce the same output as the original code.
|
||
# ===========================================================================
|
||
|
||
@st.composite
def zero_candidate_scenario(draw: st.DrawFn) -> dict:
    """Generate a scenario in which no S/R level should qualify as a target.

    Two variants are produced:

    - ``"no_levels"``: an empty level list.
    - ``"below_threshold"``: one to three levels priced in [100.5, 103.5]
      plus zero to two levels priced in [96.5, 99.5]. Each level gets a
      random type ("resistance" or "support") and a strength in [10, 100].

    The price bands follow the original test notes: with entry ≈ 100 and
    risk ≈ 3, a reward under 4.5 keeps R:R below 1.5 in both directions
    (assumes scan_ticker derives entry/risk that way — confirm against the
    service). Per those notes scan_ticker does not filter by level type,
    only by whether the price sits above or below entry, which is why the
    type is drawn freely.

    Returns:
        A dict with keys ``"variant"`` and ``"levels"``; each level is a
        dict with keys ``"price"``, ``"type"`` and ``"strength"``.
    """

    def _draw_level(low: float, high: float) -> dict:
        # Draw order (price, type, strength) is kept fixed on purpose so the
        # generated examples match the order the fields are consumed in.
        price = draw(st.floats(min_value=low, max_value=high))
        kind = draw(st.sampled_from(["resistance", "support"]))
        power = draw(st.integers(min_value=10, max_value=100))
        return {"price": price, "type": kind, "strength": power}

    variant = draw(st.sampled_from(["no_levels", "below_threshold"]))
    if variant == "no_levels":
        return {"variant": variant, "levels": []}

    # Levels just above entry: too close for a long to reach the threshold.
    above_count = draw(st.integers(min_value=1, max_value=3))
    levels = [_draw_level(100.5, 103.5) for _ in range(above_count)]

    # Optional levels just below entry: too close for a short as well.
    below_count = draw(st.integers(min_value=0, max_value=2))
    levels.extend(_draw_level(96.5, 99.5) for _ in range(below_count))

    return {"variant": variant, "levels": levels}
|
||
|
||
|
||
@st.composite
def single_candidate_scenario(draw: st.DrawFn) -> dict:
    """Generate a scenario with exactly one S/R level for one direction.

    - ``"long"``: a single "resistance" level priced in [105.0, 150.0].
    - ``"short"``: a single "support" level priced in [50.0, 95.0].

    Strength is drawn from [1, 100]. The price ranges follow the original
    note that, with entry ≈ 100 and risk ≈ 3, a level at least 4.5 away
    yields R:R >= 1.5 (assumption about scan_ticker — confirm against the
    service).

    Returns:
        A dict with keys ``"direction"`` and ``"level"``; the level is a
        dict with keys ``"price"``, ``"type"`` and ``"strength"``.
    """
    side = draw(st.sampled_from(["long", "short"]))

    # Pick the price window and level type that go with the chosen side.
    if side == "long":
        (low, high), kind = (105.0, 150.0), "resistance"
    else:
        (low, high), kind = (50.0, 95.0), "support"

    price = draw(st.floats(min_value=low, max_value=high))
    strength = draw(st.integers(min_value=1, max_value=100))
    return {
        "direction": side,
        "level": {"price": price, "type": kind, "strength": strength},
    }
|
||
|
||
|
||
@pytest.mark.asyncio
@given(scenario=zero_candidate_scenario())
@settings(
    max_examples=15,
    deadline=None,
    suppress_health_check=[HealthCheck.function_scoped_fixture],
)
async def test_property_zero_candidates_produce_no_setup(
    scenario: dict,
    scan_session: AsyncSession,
):
    """**Validates: Requirements 3.1, 3.2**

    Property: for every generated zero-candidate scenario (no levels, or
    only levels too close to entry), scan_ticker returns an empty list.

    The ``scan_session`` fixture is requested but not used directly; each
    example rebuilds the schema and opens its own session so examples do not
    see each other's rows.
    """
    from tests.conftest import _test_engine, _test_session_factory
    from app.database import Base

    # Reset the schema so the fixed symbol "PRSV0" can be inserted again for
    # every Hypothesis example.
    async with _test_engine.begin() as connection:
        await connection.run_sync(Base.metadata.drop_all)
        await connection.run_sync(Base.metadata.create_all)

    async with _test_session_factory() as session:
        symbol_row = Ticker(symbol="PRSV0")
        session.add(symbol_row)
        await session.flush()  # flush before reading symbol_row.id below

        session.add_all(
            _make_ohlcv_bars(symbol_row.id, num_bars=20, base_close=100.0)
        )
        session.add_all(
            [
                SRLevel(
                    ticker_id=symbol_row.id,
                    price_level=spec["price"],
                    type=spec["type"],
                    strength=spec["strength"],
                    detection_method="volume_profile",
                )
                for spec in scenario.get("levels", [])
            ]
        )
        await session.commit()

        setups = await scan_ticker(session, "PRSV0", rr_threshold=1.5)

        assert setups == [], (
            f"Expected no setups for zero-candidate scenario "
            f"(variant={scenario.get('variant', 'unknown')}), got {len(setups)}"
        )
|
||
|
||
|
||
@pytest.mark.asyncio
@given(scenario=single_candidate_scenario())
@settings(
    max_examples=15,
    deadline=None,
    suppress_health_check=[HealthCheck.function_scoped_fixture],
)
async def test_property_single_candidate_selected_unchanged(
    scenario: dict,
    scan_session: AsyncSession,
):
    """**Validates: Requirements 3.3**

    Property: when the only S/R level present is one that meets the R:R
    threshold, scan_ticker returns exactly one setup for that direction and
    its target equals the level's price (within 0.01).

    The ``scan_session`` fixture is requested but not used directly; each
    example rebuilds the schema and opens its own session.
    """
    from tests.conftest import _test_engine, _test_session_factory
    from app.database import Base

    # Reset the schema so the fixed symbol "PRSV1" can be inserted again for
    # every Hypothesis example.
    async with _test_engine.begin() as connection:
        await connection.run_sync(Base.metadata.drop_all)
        await connection.run_sync(Base.metadata.create_all)

    candidate = scenario["level"]
    wanted_side = scenario["direction"]

    async with _test_session_factory() as session:
        symbol_row = Ticker(symbol="PRSV1")
        session.add(symbol_row)
        await session.flush()  # flush before reading symbol_row.id below

        session.add_all(
            _make_ohlcv_bars(symbol_row.id, num_bars=20, base_close=100.0)
        )
        session.add(
            SRLevel(
                ticker_id=symbol_row.id,
                price_level=candidate["price"],
                type=candidate["type"],
                strength=candidate["strength"],
                detection_method="volume_profile",
            )
        )
        await session.commit()

        setups = await scan_ticker(session, "PRSV1", rr_threshold=1.5)

        # Only setups for the generated direction are relevant here.
        dir_setups = [found for found in setups if found.direction == wanted_side]
        assert len(dir_setups) == 1, (
            f"Expected exactly one {wanted_side} setup for single candidate, "
            f"got {len(dir_setups)}"
        )

        selected_target = dir_setups[0].target
        expected_target = round(candidate["price"], 4)
        assert selected_target == pytest.approx(expected_target, abs=0.01), (
            f"Single candidate: expected target={expected_target}, "
            f"got {selected_target}"
        )
|
||
|
||
|
||
# ===========================================================================
|
||
# 7.2 Unit test: no S/R levels → no setup produced
|
||
# ===========================================================================
|
||
|
||
@pytest.mark.asyncio
async def test_no_sr_levels_produces_no_setup(scan_session: AsyncSession):
    """**Validates: Requirements 3.1**

    A ticker that has OHLCV bars but no S/R levels must yield an empty list
    from scan_ticker.
    """
    symbol_row = Ticker(symbol="NOSRL")
    scan_session.add(symbol_row)
    await scan_session.flush()  # flush before reading symbol_row.id below

    scan_session.add_all(
        _make_ohlcv_bars(symbol_row.id, num_bars=20, base_close=100.0)
    )
    await scan_session.flush()

    setups = await scan_ticker(scan_session, "NOSRL", rr_threshold=1.5)

    assert setups == [], (
        f"Expected no setups when no SR levels exist, got {len(setups)}"
    )
|
||
|
||
|
||
# ===========================================================================
|
||
# 7.3 Unit test: single candidate meets threshold → selected
|
||
# ===========================================================================
|
||
|
||
@pytest.mark.asyncio
async def test_single_resistance_above_threshold_selected(scan_session: AsyncSession):
    """**Validates: Requirements 3.3**

    With a single resistance level at 110.0 above entry, scan_ticker must
    return exactly one long setup whose target is that level.

    Per the original test notes: entry ≈ 100, ATR ≈ 2, risk ≈ 3, so a
    resistance at 110 gives R:R ≈ 3.33, which clears the 1.5 threshold.
    """
    symbol_row = Ticker(symbol="SINGL")
    scan_session.add(symbol_row)
    await scan_session.flush()  # flush before reading symbol_row.id below

    scan_session.add_all(
        _make_ohlcv_bars(symbol_row.id, num_bars=20, base_close=100.0)
    )
    scan_session.add(
        SRLevel(
            ticker_id=symbol_row.id,
            price_level=110.0,
            type="resistance",
            strength=60,
            detection_method="volume_profile",
        )
    )
    await scan_session.flush()

    found = await scan_ticker(scan_session, "SINGL", rr_threshold=1.5)

    long_setups = [item for item in found if item.direction == "long"]
    assert len(long_setups) == 1, (
        f"Expected exactly one long setup, got {len(long_setups)}"
    )

    only_long = long_setups[0]
    assert only_long.target == pytest.approx(110.0, abs=0.01), (
        f"Expected target=110.0, got {only_long.target}"
    )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_single_support_below_threshold_selected(scan_session: AsyncSession):
    """**Validates: Requirements 3.3**

    With a single support level at 90.0 below entry, scan_ticker must return
    exactly one short setup whose target is that level.

    Per the original test notes: entry ≈ 100, ATR ≈ 2, risk ≈ 3, so a
    support at 90 gives R:R ≈ 3.33, which clears the 1.5 threshold.
    """
    symbol_row = Ticker(symbol="SINGS")
    scan_session.add(symbol_row)
    await scan_session.flush()  # flush before reading symbol_row.id below

    scan_session.add_all(
        _make_ohlcv_bars(symbol_row.id, num_bars=20, base_close=100.0)
    )
    scan_session.add(
        SRLevel(
            ticker_id=symbol_row.id,
            price_level=90.0,
            type="support",
            strength=55,
            detection_method="pivot_point",
        )
    )
    await scan_session.flush()

    found = await scan_ticker(scan_session, "SINGS", rr_threshold=1.5)

    short_setups = [item for item in found if item.direction == "short"]
    assert len(short_setups) == 1, (
        f"Expected exactly one short setup, got {len(short_setups)}"
    )

    only_short = short_setups[0]
    assert only_short.target == pytest.approx(90.0, abs=0.01), (
        f"Expected target=90.0, got {only_short.target}"
    )
|
||
|
||
|
||
# ===========================================================================
|
||
# 7.4 Unit test: get_trade_setups sorting is unchanged (R:R desc, composite desc)
|
||
# ===========================================================================
|
||
|
||
@pytest.mark.asyncio
async def test_get_trade_setups_sorting_rr_desc_composite_desc(db_session: AsyncSession):
    """**Validates: Requirements 3.5**

    get_trade_setups must order its results by ``rr_ratio`` descending and
    break ties on ``composite_score`` descending.

    Four setups are stored, one per ticker. The expected order is
    SORTD (rr=5.0), SORTC (rr=3.0, comp=80), SORTB (rr=3.0, comp=50),
    SORTA (rr=1.5).
    """
    detected = datetime.now(timezone.utc)

    # One ticker per setup, keyed by symbol for lookup when building setups.
    tickers = {
        symbol: Ticker(symbol=symbol)
        for symbol in ("SORTA", "SORTB", "SORTC", "SORTD")
    }
    db_session.add_all(list(tickers.values()))
    await db_session.flush()  # flush before reading the ticker ids below

    # (symbol, direction, stop_loss, target, rr_ratio, composite_score)
    # SORTA has the highest composite score but the lowest R:R, so it only
    # ends up last if R:R is the primary sort key.
    setup_specs = [
        ("SORTA", "long", 97.0, 104.5, 1.5, 90.0),
        ("SORTB", "long", 97.0, 109.0, 3.0, 50.0),
        ("SORTC", "short", 103.0, 91.0, 3.0, 80.0),
        ("SORTD", "long", 97.0, 115.0, 5.0, 30.0),
    ]
    db_session.add_all(
        [
            TradeSetup(
                ticker_id=tickers[symbol].id,
                direction=side,
                entry_price=100.0,
                stop_loss=stop,
                target=goal,
                rr_ratio=ratio,
                composite_score=score,
                detected_at=detected,
            )
            for symbol, side, stop, goal, ratio, score in setup_specs
        ]
    )
    await db_session.flush()

    results = await get_trade_setups(db_session)

    assert len(results) == 4, f"Expected 4 setups, got {len(results)}"

    # Primary key: rr_ratio descending.
    rr_values = [row["rr_ratio"] for row in results]
    assert rr_values == [5.0, 3.0, 3.0, 1.5], (
        f"Expected rr_ratio order [5.0, 3.0, 3.0, 1.5], got {rr_values}"
    )

    # Secondary key: composite_score descending among the rr_ratio=3.0 tie.
    tied_composites = [
        row["composite_score"] for row in results if row["rr_ratio"] == 3.0
    ]
    assert tied_composites == [80.0, 50.0], (
        f"Expected composite_score order [80.0, 50.0] for tied R:R, "
        f"got {tied_composites}"
    )

    # The full symbol order pins both keys at once.
    symbols = [row["symbol"] for row in results]
    assert symbols == ["SORTD", "SORTC", "SORTB", "SORTA"], (
        f"Expected symbol order ['SORTD', 'SORTC', 'SORTB', 'SORTA'], "
        f"got {symbols}"
    )
|