"""Tests for the paper-trading service."""

from __future__ import annotations

from datetime import date, datetime, timedelta, timezone

import pytest

from app.exceptions import ValidationError
from app.models.benchmark_price import BenchmarkPrice
from app.models.ohlcv import OHLCVRecord
from app.models.paper_trade import PaperTrade
from app.models.ticker import Ticker
from app.models.user import User
from app.services import paper_trade_service as svc
from tests.conftest import _test_session_factory  # type: ignore


@pytest.fixture
async def session():
    """Yield a session from the shared test session factory, closing it afterwards."""
    async with _test_session_factory() as s:
        yield s


async def _seed(session, symbol: str, close: float) -> int:
    """Ensure user 1 exists, create ticker ``symbol`` with one bar, and return its id.

    The single OHLCV bar is dated ``date.today()`` and has open/high/low/close
    all equal to ``close``, so it is the ticker's latest close.
    """
    user = await session.get(User, 1)
    if user is None:
        session.add(User(id=1, username="u", password_hash="x", role="user", has_access=True))
        await session.flush()
    t = Ticker(symbol=symbol)
    session.add(t)
    await session.flush()  # assigns t.id, which the OHLCV row below needs
    session.add(
        OHLCVRecord(
            ticker_id=t.id,
            date=date.today(),
            open=close,
            high=close,
            low=close,
            close=close,
            volume=1,
        )
    )
    await session.commit()
    return t.id


async def test_create_and_list_open(session):
    await _seed(session, "AAA", close=110.0)
    await svc.create_trade(
        session, 1, symbol="AAA", direction="long",
        entry_price=100.0, shares=10, stop_loss=95.0, target=120.0,
    )
    rows = await svc.list_trades(session, 1)
    assert len(rows) == 1
    row = rows[0]
    assert row["symbol"] == "AAA"
    assert row["status"] == "open"
    assert row["current_price"] == 110.0  # marked to the latest close


async def test_close_uses_current_price(session):
    await _seed(session, "AAA", close=112.0)
    trade = await svc.create_trade(
        session, 1, symbol="AAA", direction="long",
        entry_price=100.0, shares=5, stop_loss=95.0, target=120.0,
    )
    closed = await svc.close_trade(session, 1, trade.id)
    assert closed.status == "closed"
    assert closed.close_price == 112.0
    assert closed.closed_at is not None
    rows = await svc.list_trades(session, 1, status="closed")
    assert rows[0]["current_price"] == 112.0  # closed → realized exit


async def test_close_with_explicit_price(session):
    await _seed(session, "AAA", close=112.0)
    trade = await svc.create_trade(
        session, 1, symbol="AAA", direction="short",
        entry_price=100.0, shares=5, stop_loss=105.0, target=90.0,
    )
    closed = await svc.close_trade(session, 1, trade.id, close_price=93.0)
    assert closed.status == "closed"
    assert closed.close_price == 93.0  # explicit price wins over the 112.0 close


async def test_invalid_direction_rejected(session):
    await _seed(session, "AAA", close=100.0)
    with pytest.raises(ValidationError):
        await svc.create_trade(
            session, 1, symbol="AAA", direction="sideways",
            entry_price=100.0, shares=1, stop_loss=95.0, target=110.0,
        )


async def test_double_close_rejected(session):
    await _seed(session, "AAA", close=100.0)
    trade = await svc.create_trade(
        session, 1, symbol="AAA", direction="long",
        entry_price=100.0, shares=1, stop_loss=95.0, target=110.0,
    )
    await svc.close_trade(session, 1, trade.id)
    with pytest.raises(ValidationError):
        await svc.close_trade(session, 1, trade.id)


async def _add_bars(
    session, ticker_id: int, highs_lows: list[tuple[float, float]], start: date
) -> None:
    """Append one daily bar per ``(high, low)`` pair, dated ``start + 1``, ``start + 2``, ...

    Each bar's open and close are the midpoint of its high and low.
    """
    for offset, (hi, lo) in enumerate(highs_lows, start=1):
        mid = (hi + lo) / 2
        session.add(
            OHLCVRecord(
                ticker_id=ticker_id,
                date=start + timedelta(days=offset),
                open=mid,
                high=hi,
                low=lo,
                close=mid,
                volume=1,
            )
        )
    await session.commit()


async def test_resolve_closes_on_target(session):
    tid = await _seed(session, "AAA", close=100.0)
    trade = await svc.create_trade(
        session, 1, symbol="AAA", direction="long",
        entry_price=100.0, shares=10, stop_loss=95.0, target=110.0,
    )
    # later bars: a day that trades up through 110
    await _add_bars(session, tid, [(103, 101), (111, 108)], start=date.today())
    closed = await svc.resolve_open_trades(session)
    assert closed == 1
    await session.refresh(trade)
    assert trade.status == "closed"
    assert trade.close_price == 110.0  # closed at target


async def test_resolve_closes_on_stop(session):
    tid = await _seed(session, "AAA", close=100.0)
    trade = await svc.create_trade(
        session, 1, symbol="AAA", direction="long",
        entry_price=100.0, shares=10, stop_loss=95.0, target=110.0,
    )
    await _add_bars(session, tid, [(101, 94)], start=date.today())  # low pierces stop
    closed = await svc.resolve_open_trades(session)
    assert closed == 1
    await session.refresh(trade)
    assert trade.status == "closed"
    assert trade.close_price == 95.0  # closed at stop


async def test_resolve_leaves_open_when_neither_hit(session):
    tid = await _seed(session, "AAA", close=100.0)
    await svc.create_trade(
        session, 1, symbol="AAA", direction="long",
        entry_price=100.0, shares=10, stop_loss=95.0, target=110.0,
    )
    await _add_bars(session, tid, [(103, 98), (104, 99)], start=date.today())  # range-bound
    closed = await svc.resolve_open_trades(session)
    assert closed == 0
    rows = await svc.list_trades(session, 1, status="open")
    assert len(rows) == 1


async def _seed_benchmark(session, points: dict[date, float]) -> None:
    """Insert one SPY ``BenchmarkPrice`` row per ``{date: close}`` entry and commit."""
    for d, close in points.items():
        session.add(BenchmarkPrice(symbol="SPY", date=d, close=close))
    await session.commit()


async def _add_open_trade(
    session, ticker_id: int, direction: str, *, entry: float, shares: float, days_ago: int
) -> None:
    """Insert an open ``PaperTrade`` for user 1 directly, backdated by ``days_ago`` days.

    The stop and target are fixed at 95% and 120% of ``entry``.
    """
    # NOTE(review): opened_at is UTC-aware, while the bar/benchmark dates in these
    # tests come from the local ``date.today()``. Near midnight the two calendars
    # can differ by a day. If the alpha tests flake, confirm how the service
    # aligns opened_at with benchmark dates.
    session.add(
        PaperTrade(
            user_id=1,
            ticker_id=ticker_id,
            direction=direction,
            entry_price=entry,
            shares=shares,
            stop_loss=entry * 0.95,
            target=entry * 1.2,
            status="open",
            opened_at=datetime.now(timezone.utc) - timedelta(days=days_ago),
        )
    )
    await session.commit()


async def test_alpha_long_open(session):
    tid = await _seed(session, "AAA", close=110.0)  # current price 110 → +10% on a 100 entry
    today = date.today()
    await _seed_benchmark(
        session, {today - timedelta(days=10): 400.0, today: 420.0}
    )  # SPY +5%
    await _add_open_trade(session, tid, "long", entry=100.0, shares=10, days_ago=10)
    row = (await svc.list_trades(session, 1, status="open"))[0]
    assert row["benchmark_return_pct"] == pytest.approx(5.0)
    assert row["alpha_pct"] == pytest.approx(5.0)  # +10% trade − 5% bench
    assert row["alpha_usd"] == pytest.approx(50.0)  # 5% of 100*10


async def test_alpha_short_and_missing_benchmark(session):
    tid = await _seed(session, "BBB", close=90.0)  # price fell to 90 → short +10%
    today = date.today()
    await _add_open_trade(session, tid, "short", entry=100.0, shares=4, days_ago=10)

    # No benchmark data yet → alpha unset, not an error.
    row = (await svc.list_trades(session, 1, status="open"))[0]
    assert row["alpha_pct"] is None
    assert row["benchmark_return_pct"] is None

    # Flat benchmark → alpha equals the (direction-signed) trade return.
    await _seed_benchmark(session, {today - timedelta(days=10): 400.0, today: 400.0})
    row = (await svc.list_trades(session, 1, status="open"))[0]
    assert row["benchmark_return_pct"] == pytest.approx(0.0)
    assert row["alpha_pct"] == pytest.approx(10.0)