"""Unit tests for the trade setup outcome evaluation service."""

from __future__ import annotations

from collections.abc import AsyncIterator
from datetime import date, datetime, timedelta, timezone

import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.ohlcv import OHLCVRecord
from app.models.ticker import Ticker
from app.models.trade_setup import TradeSetup
from app.services.outcome_service import (
    OUTCOME_AMBIGUOUS,
    OUTCOME_EXPIRED,
    OUTCOME_STOP_HIT,
    OUTCOME_TARGET_HIT,
    Bar,
    evaluate_pending_setups,
    evaluate_setup_against_bars,
    get_performance_stats,
)


@pytest.fixture
async def outcome_session() -> AsyncIterator[AsyncSession]:
    """DB session compatible with evaluate_pending_setups (which commits).

    Yields a session opened from the test session factory and closes it when
    the test finishes.

    NOTE(review): nothing here rolls back rows committed by the service;
    isolation between tests presumably comes from tests.conftest — confirm.
    """
    # Imported lazily, as in the original; presumably to avoid importing the
    # conftest module at collection time — TODO confirm.
    from tests.conftest import _test_session_factory

    async with _test_session_factory() as session:
        yield session


def _bars(*hl: tuple[float, float], start: date = date(2026, 1, 5)) -> list[Bar]:
    """Build one ``Bar`` per ``(high, low)`` pair.

    Bars are dated on consecutive calendar days beginning at ``start``, in
    the order the pairs are given.
    """
    return [
        Bar(date=start + timedelta(days=i), high=high, low=low)
        for i, (high, low) in enumerate(hl)
    ]


# ---------------------------------------------------------------------------
# evaluate_setup_against_bars — pure logic
# ---------------------------------------------------------------------------


class TestEvaluateSetupAgainstBars:
    """Tests of the bar-walking logic using in-memory bars only.

    Calls pass ``(direction, stop, target, bars)`` positionally; the expected
    result is an ``(outcome, outcome_date)`` pair.
    """

    def test_long_target_hit(self):
        # entry ~100, stop 95, target 110
        bars = _bars((105, 99), (111, 104))
        outcome, outcome_date = evaluate_setup_against_bars("long", 95, 110, bars)
        assert outcome == OUTCOME_TARGET_HIT
        assert outcome_date == bars[1].date

    def test_long_stop_hit(self):
        bars = _bars((105, 99), (103, 94))
        outcome, outcome_date = evaluate_setup_against_bars("long", 95, 110, bars)
        assert outcome == OUTCOME_STOP_HIT
        assert outcome_date == bars[1].date

    def test_long_stop_before_target_across_bars(self):
        # Stop hit on bar 0, target would hit on bar 1 — stop wins (first bar decides)
        bars = _bars((100, 94), (112, 100))
        outcome, _ = evaluate_setup_against_bars("long", 95, 110, bars)
        assert outcome == OUTCOME_STOP_HIT

    def test_short_target_hit(self):
        # short: entry ~100, stop 105, target 90
        bars = _bars((102, 96), (98, 89))
        outcome, outcome_date = evaluate_setup_against_bars("short", 105, 90, bars)
        assert outcome == OUTCOME_TARGET_HIT
        assert outcome_date == bars[1].date

    def test_short_stop_hit(self):
        bars = _bars((102, 96), (106, 98))
        outcome, _ = evaluate_setup_against_bars("short", 105, 90, bars)
        assert outcome == OUTCOME_STOP_HIT

    def test_ambiguous_when_both_levels_in_same_bar(self):
        # one giant bar spans both stop (95) and target (110)
        bars = _bars((112, 94))
        outcome, _ = evaluate_setup_against_bars("long", 95, 110, bars)
        assert outcome == OUTCOME_AMBIGUOUS

    def test_pending_when_not_enough_bars(self):
        bars = _bars((105, 99), (104, 98))
        outcome, outcome_date = evaluate_setup_against_bars(
            "long", 95, 110, bars, max_bars=30
        )
        assert outcome is None
        assert outcome_date is None

    def test_expired_after_max_bars(self):
        bars = _bars(*[(105, 99)] * 10)
        outcome, outcome_date = evaluate_setup_against_bars(
            "long", 95, 110, bars, max_bars=10
        )
        assert outcome == OUTCOME_EXPIRED
        # Expiry is dated on the last bar inside the window.
        assert outcome_date == bars[9].date

    def test_hit_beyond_max_bars_is_ignored(self):
        # target hit on bar 11 but window is 10 — expired
        bars = _bars(*[(105, 99)] * 10, (115, 105))
        outcome, _ = evaluate_setup_against_bars("long", 95, 110, bars, max_bars=10)
        assert outcome == OUTCOME_EXPIRED

    def test_no_bars_is_pending(self):
        outcome, _ = evaluate_setup_against_bars("long", 95, 110, [])
        assert outcome is None


# ---------------------------------------------------------------------------
# evaluate_pending_setups — DB integration
# ---------------------------------------------------------------------------


async def _make_ticker(db: AsyncSession, symbol: str = "AAPL") -> Ticker:
    """Add a ``Ticker`` with the given symbol to the session and flush it.

    The flush happens here so that ``ticker.id`` can be used by the other
    helpers right away.
    """
    ticker = Ticker(symbol=symbol)
    db.add(ticker)
    await db.flush()
    return ticker


def _make_setup(
    ticker: Ticker,
    direction: str = "long",
    entry: float = 100.0,
    stop: float = 95.0,
    target: float = 110.0,
    rr: float = 2.0,
    detected: datetime | None = None,
    **kwargs,
) -> TradeSetup:
    """Build (but do not persist) a ``TradeSetup`` for ``ticker``.

    Defaults describe a long setup: entry 100, stop 95, target 110, R:R 2.0,
    composite score 50.0, detected 2026-01-02 21:00 UTC. Extra keyword
    arguments are forwarded to the ``TradeSetup`` constructor unchanged.
    """
    return TradeSetup(
        ticker_id=ticker.id,
        direction=direction,
        entry_price=entry,
        stop_loss=stop,
        target=target,
        rr_ratio=rr,
        composite_score=50.0,
        detected_at=detected or datetime(2026, 1, 2, 21, 0, tzinfo=timezone.utc),
        **kwargs,
    )


def _add_bars(
    db: AsyncSession,
    ticker: Ticker,
    *hl: tuple[float, float],
    start: date = date(2026, 1, 5),
) -> None:
    """Add one ``OHLCVRecord`` per ``(high, low)`` pair to the session.

    Records are dated on consecutive calendar days beginning at ``start``.
    Open and close are both set to the bar's midpoint and volume is fixed at
    1,000,000. The session is not flushed here; callers do that themselves.
    """
    for i, (high, low) in enumerate(hl):
        db.add(OHLCVRecord(
            ticker_id=ticker.id,
            date=start + timedelta(days=i),
            open=(high + low) / 2,
            high=high,
            low=low,
            close=(high + low) / 2,
            volume=1_000_000,
        ))


class TestEvaluatePendingSetups:
    """Database-backed tests for ``evaluate_pending_setups``.

    These use the ``outcome_session`` fixture because the service commits.
    """

    async def test_writes_outcome_and_metadata(self, outcome_session: AsyncSession):
        ticker = await _make_ticker(outcome_session)
        setup = _make_setup(ticker)
        outcome_session.add(setup)
        _add_bars(outcome_session, ticker, (105, 99), (111, 104))
        await outcome_session.flush()

        summary = await evaluate_pending_setups(outcome_session)

        assert summary["evaluated"] == 1
        assert summary["by_outcome"] == {OUTCOME_TARGET_HIT: 1}

        result = await outcome_session.execute(select(TradeSetup))
        stored = result.scalar_one()
        assert stored.actual_outcome == OUTCOME_TARGET_HIT
        # Second bar (start date + 1 day) is the one that reaches the target.
        assert stored.outcome_date == date(2026, 1, 6)
        assert stored.evaluated_at is not None

    async def test_undecided_setup_stays_pending(self, outcome_session: AsyncSession):
        ticker = await _make_ticker(outcome_session)
        outcome_session.add(_make_setup(ticker))
        _add_bars(outcome_session, ticker, (105, 99))  # no level hit, < max_bars
        await outcome_session.flush()

        summary = await evaluate_pending_setups(outcome_session)

        assert summary["evaluated"] == 0
        assert summary["still_pending"] == 1

        result = await outcome_session.execute(select(TradeSetup))
        assert result.scalar_one().actual_outcome is None

    async def test_only_bars_after_detection_are_used(
        self, outcome_session: AsyncSession
    ):
        ticker = await _make_ticker(outcome_session)
        # bar on the detection date itself would hit the stop — must be ignored
        _add_bars(outcome_session, ticker, (100, 90), start=date(2026, 1, 2))
        _add_bars(outcome_session, ticker, (111, 104), start=date(2026, 1, 5))
        outcome_session.add(_make_setup(ticker))
        await outcome_session.flush()

        await evaluate_pending_setups(outcome_session)

        result = await outcome_session.execute(select(TradeSetup))
        assert result.scalar_one().actual_outcome == OUTCOME_TARGET_HIT

    async def test_already_evaluated_setups_are_skipped(
        self, outcome_session: AsyncSession
    ):
        ticker = await _make_ticker(outcome_session)
        outcome_session.add(_make_setup(ticker, actual_outcome=OUTCOME_STOP_HIT))
        await outcome_session.flush()

        summary = await evaluate_pending_setups(outcome_session)

        assert summary["evaluated"] == 0
        assert summary["still_pending"] == 0


# ---------------------------------------------------------------------------
# get_performance_stats
# ---------------------------------------------------------------------------


class TestGetPerformanceStats:
    """Tests for the aggregated statistics returned by ``get_performance_stats``."""

    async def test_empty_database(self, db_session: AsyncSession):
        stats = await get_performance_stats(db_session)

        assert stats["overall"]["total"] == 0
        assert stats["overall"]["hit_rate"] is None
        assert stats["pending"] == 0

    async def test_aggregation(self, db_session: AsyncSession):
        ticker = await _make_ticker(db_session)
        db_session.add(_make_setup(
            ticker,
            direction="long",
            rr=3.0,
            actual_outcome=OUTCOME_TARGET_HIT,
            confidence_score=80.0,
            recommended_action="LONG_HIGH",
        ))
        db_session.add(_make_setup(
            ticker,
            direction="long",
            rr=2.0,
            actual_outcome=OUTCOME_STOP_HIT,
            confidence_score=55.0,
            recommended_action="LONG_MODERATE",
        ))
        db_session.add(_make_setup(
            ticker,
            direction="short",
            rr=2.5,
            actual_outcome=OUTCOME_EXPIRED,
            confidence_score=40.0,
            recommended_action="NEUTRAL",
        ))
        db_session.add(_make_setup(ticker, direction="short"))  # pending
        await db_session.flush()

        stats = await get_performance_stats(db_session)

        overall = stats["overall"]
        # The pending setup is excluded from the evaluated total.
        assert overall["total"] == 3
        assert overall["wins"] == 1
        assert overall["losses"] == 1
        assert overall["expired"] == 1
        assert overall["hit_rate"] == 50.0
        # realized: +3.0 (win), -1.0 (loss), 0.0 (expired) → avg 0.667
        assert overall["avg_r"] == pytest.approx(0.667, abs=0.001)
        assert overall["total_r"] == pytest.approx(2.0)

        assert stats["pending"] == 1
        assert stats["by_direction"]["long"]["total"] == 2
        assert stats["by_direction"]["short"]["total"] == 1
        assert stats["by_action"]["LONG_HIGH"]["wins"] == 1
        assert stats["by_confidence"]["≥70%"]["wins"] == 1
        assert stats["by_confidence"]["50-70%"]["losses"] == 1
        assert stats["by_confidence"]["<50%"]["expired"] == 1

    async def test_ambiguous_counts_as_loss(self, db_session: AsyncSession):
        ticker = await _make_ticker(db_session)
        db_session.add(_make_setup(ticker, actual_outcome=OUTCOME_AMBIGUOUS))
        await db_session.flush()

        stats = await get_performance_stats(db_session)

        assert stats["overall"]["losses"] == 1
        assert stats["overall"]["hit_rate"] == 0.0
        assert stats["overall"]["avg_r"] == -1.0