"""Paper-trading service: take, mark-to-market, and close simulated trades.""" from __future__ import annotations from datetime import date, datetime, timezone from sqlalchemy import and_, func, select from sqlalchemy.ext.asyncio import AsyncSession from app.exceptions import NotFoundError, ValidationError from app.models.ohlcv import OHLCVRecord from app.models.paper_trade import PaperTrade from app.models.ticker import Ticker from app.services import benchmark_service, settings_store from app.services.outcome_service import ( OUTCOME_AMBIGUOUS, OUTCOME_STOP_HIT, OUTCOME_TARGET_HIT, Bar, evaluate_setup_against_bars, ) # Exit policy for OPEN paper trades (auto-close). "trailing" rides a trailing stop # (validated as the best exit in the backtest); "target" closes at the setup's # stop/target. Stored in SystemSetting so it's tunable + transparent in the UI. KEY_EXIT_MODE = "paper_exit_mode" KEY_TRAILING_PCT = "paper_trailing_pct" DEFAULT_EXIT_MODE = "trailing" DEFAULT_TRAILING_PCT = 12.0 async def get_exit_policy(db: AsyncSession) -> dict: """Active auto-exit policy: {'mode': 'trailing'|'target', 'trailing_pct': float}.""" mode = (await settings_store.get_value(db, KEY_EXIT_MODE, DEFAULT_EXIT_MODE)).strip().lower() if mode not in ("trailing", "target"): mode = DEFAULT_EXIT_MODE raw = await settings_store.get_value(db, KEY_TRAILING_PCT, str(DEFAULT_TRAILING_PCT)) try: pct = float(raw) except (TypeError, ValueError): pct = DEFAULT_TRAILING_PCT pct = max(0.5, min(90.0, pct)) return {"mode": mode, "trailing_pct": pct} async def set_exit_policy( db: AsyncSession, *, mode: str | None = None, trailing_pct: float | None = None ) -> dict: """Persist the auto-exit policy (admin). Validates inputs.""" if mode is not None: mode = mode.strip().lower() if mode not in ("trailing", "target"): raise ValidationError("mode must be 'trailing' or 'target'") await settings_store.upsert_setting(db, KEY_EXIT_MODE, mode) if trailing_pct is not None: if not 0.5 <= float(trailing_pct) <= 90.0: raise ValidationError("trailing_pct must be between 0.5 and 90") await settings_store.upsert_setting(db, KEY_TRAILING_PCT, str(float(trailing_pct))) await db.commit() return await get_exit_policy(db) async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker: normalised = symbol.strip().upper() result = await db.execute(select(Ticker).where(Ticker.symbol == normalised)) ticker = result.scalar_one_or_none() if ticker is None: raise NotFoundError(f"Ticker not found: {normalised}") return ticker async def _latest_closes(db: AsyncSession, ticker_ids: set[int]) -> dict[int, float]: """Latest stored close per ticker.""" if not ticker_ids: return {} latest = ( select(OHLCVRecord.ticker_id, func.max(OHLCVRecord.date).label("md")) .where(OHLCVRecord.ticker_id.in_(ticker_ids)) .group_by(OHLCVRecord.ticker_id) .subquery() ) stmt = select(OHLCVRecord.ticker_id, OHLCVRecord.close).join( latest, and_( OHLCVRecord.ticker_id == latest.c.ticker_id, OHLCVRecord.date == latest.c.md, ), ) result = await db.execute(stmt) return {tid: float(close) for tid, close in result.all()} async def _max_high_after(db: AsyncSession, ticker_id: int, since: date) -> float | None: """Highest high strictly after ``since`` — the running peak for a trailing stop.""" result = await db.execute( select(func.max(OHLCVRecord.high)).where( OHLCVRecord.ticker_id == ticker_id, OHLCVRecord.date > since ) ) v = result.scalar() return float(v) if v is not None else None def _trailing_close( direction: str, entry: float, init_stop: float, trail_frac: float, bars: list[Bar] ) -> tuple[float, date, str] | None: """Walk post-entry bars; return (price, date, reason) when the trailing or initial stop is hit, else None. The stop only ratchets up: max(init_stop, peak*(1-trail)) for a long. reason = 'trailing' once it's above the initial stop, else 'stop'.""" long = direction == "long" peak = entry for b in bars: if long: level = max(init_stop, peak * (1 - trail_frac)) if b.low <= level: return level, b.date, ("trailing" if level > init_stop else "stop") if b.high > peak: peak = b.high else: level = min(init_stop, peak * (1 + trail_frac)) if b.high >= level: return level, b.date, ("trailing" if level < init_stop else "stop") if b.low < peak: peak = b.low return None async def create_trade( db: AsyncSession, user_id: int, *, symbol: str, direction: str, entry_price: float, shares: float, stop_loss: float, target: float, ) -> PaperTrade: direction = direction.strip().lower() if direction not in ("long", "short"): raise ValidationError("direction must be 'long' or 'short'") if shares <= 0 or entry_price <= 0: raise ValidationError("shares and entry_price must be positive") ticker = await _get_ticker(db, symbol) trade = PaperTrade( user_id=user_id, ticker_id=ticker.id, direction=direction, entry_price=entry_price, shares=shares, stop_loss=stop_loss, target=target, status="open", opened_at=datetime.now(timezone.utc), ) db.add(trade) await db.commit() await db.refresh(trade) return trade def _to_dict( trade: PaperTrade, symbol: str, current_price: float | None, benchmark_closes: dict[date, float] | None = None, trailing: tuple[float, float | None] | None = None, ) -> dict: # For open trades, mark to market; for closed, the realized exit price. ref = current_price if trade.status == "open" else trade.close_price # Alpha = trade return − benchmark (SPY) return over the same holding period. benchmark_return = None alpha_pct = None alpha_usd = None if ref is not None and trade.entry_price and benchmark_closes: sign = 1.0 if trade.direction == "long" else -1.0 trade_return = (ref - trade.entry_price) / trade.entry_price * 100.0 * sign as_of = ( trade.closed_at.date() if trade.status == "closed" and trade.closed_at is not None else date.today() ) benchmark_return = benchmark_service.benchmark_return_pct( benchmark_closes, trade.opened_at.date(), as_of ) if benchmark_return is not None: alpha_pct = trade_return - benchmark_return alpha_usd = alpha_pct / 100.0 * trade.entry_price * trade.shares return { "id": trade.id, "symbol": symbol, "direction": trade.direction, "entry_price": trade.entry_price, "shares": trade.shares, "stop_loss": trade.stop_loss, "target": trade.target, "status": trade.status, "opened_at": trade.opened_at, "close_price": trade.close_price, "closed_at": trade.closed_at, "current_price": ref, "benchmark_return_pct": benchmark_return, "alpha_pct": alpha_pct, "alpha_usd": alpha_usd, "close_reason": trade.close_reason, "trailing_stop": trailing[0] if trailing else None, "trailing_distance_pct": trailing[1] if trailing else None, } async def list_trades( db: AsyncSession, user_id: int | None = None, status: str | None = None, ) -> list[dict]: stmt = ( select(PaperTrade, Ticker.symbol) .join(Ticker, PaperTrade.ticker_id == Ticker.id) ) if user_id is not None: # None → all users (single-user app; used by the digest) stmt = stmt.where(PaperTrade.user_id == user_id) if status is not None: stmt = stmt.where(PaperTrade.status == status) stmt = stmt.order_by(PaperTrade.opened_at.desc()) rows = (await db.execute(stmt)).all() open_ids = {t.ticker_id for t, _ in rows if t.status == "open"} prices = await _latest_closes(db, open_ids) # Benchmark closes for alpha — populated by the daily/benchmark job. Empty until # that runs once, in which case alpha is simply left unset (a read path never # makes a provider call). benchmark_closes = await benchmark_service.load_benchmark_closes(db) # Current trailing-stop level + distance for open trades (when trailing is active). policy = await get_exit_policy(db) trailing_info: dict[int, tuple[float, float | None]] = {} if policy["mode"] == "trailing": trail_frac = policy["trailing_pct"] / 100.0 for t, _ in rows: if t.status != "open": continue max_high = await _max_high_after(db, t.ticker_id, t.opened_at.date()) peak = max(t.entry_price, max_high) if max_high is not None else t.entry_price long = t.direction == "long" level = ( max(t.stop_loss, peak * (1 - trail_frac)) if long else min(t.stop_loss, peak * (1 + trail_frac)) ) cur = prices.get(t.ticker_id) dist = None if cur: dist = ((cur - level) / cur * 100.0) if long else ((level - cur) / cur * 100.0) trailing_info[t.id] = (level, dist) return [ _to_dict(t, sym, prices.get(t.ticker_id), benchmark_closes, trailing_info.get(t.id)) for t, sym in rows ] async def close_trade( db: AsyncSession, user_id: int, trade_id: int, close_price: float | None = None, ) -> PaperTrade: result = await db.execute( select(PaperTrade).where( PaperTrade.id == trade_id, PaperTrade.user_id == user_id, ) ) trade = result.scalar_one_or_none() if trade is None: raise NotFoundError(f"Paper trade not found: {trade_id}") if trade.status == "closed": raise ValidationError("Trade is already closed") if close_price is None: prices = await _latest_closes(db, {trade.ticker_id}) close_price = prices.get(trade.ticker_id) if close_price is None: raise ValidationError("No current price available to close at; supply close_price") trade.status = "closed" trade.close_price = float(close_price) trade.close_reason = "manual" trade.closed_at = datetime.now(timezone.utc) await db.commit() await db.refresh(trade) return trade async def resolve_open_trades(db: AsyncSession) -> int: """Auto-close open trades whose stop or target was hit in the daily bars. Walks the bars after each trade's open (same logic as the outcome evaluator). Target hit → close at the target; stop (or an ambiguous same-bar touch) → close at the stop. Trades that have hit neither stay open. Returns the count closed. """ result = await db.execute(select(PaperTrade).where(PaperTrade.status == "open")) open_trades = list(result.scalars().all()) if not open_trades: return 0 policy = await get_exit_policy(db) mode = policy["mode"] trail_frac = policy["trailing_pct"] / 100.0 closed = 0 for trade in open_trades: bars_result = await db.execute( select(OHLCVRecord.date, OHLCVRecord.high, OHLCVRecord.low) .where( OHLCVRecord.ticker_id == trade.ticker_id, OHLCVRecord.date > trade.opened_at.date(), ) .order_by(OHLCVRecord.date.asc()) ) bars = [Bar(date=d, high=h, low=lo) for d, h, lo in bars_result.all()] if not bars: continue if mode == "trailing": hit = _trailing_close(trade.direction, trade.entry_price, trade.stop_loss, trail_frac, bars) if hit is None: continue # neither the trailing nor the initial stop reached yet close_price, close_date, reason = hit else: # max_bars beyond the data so a still-open trade returns undecided (not "expired"). outcome, outcome_date = evaluate_setup_against_bars( trade.direction, trade.stop_loss, trade.target, bars, max_bars=len(bars) + 1 ) if outcome == OUTCOME_TARGET_HIT: close_price, close_date, reason = trade.target, outcome_date, "target" elif outcome in (OUTCOME_STOP_HIT, OUTCOME_AMBIGUOUS): close_price, close_date, reason = trade.stop_loss, outcome_date, "stop" else: continue trade.status = "closed" trade.close_price = float(close_price) trade.close_reason = reason trade.closed_at = datetime.combine(close_date, datetime.min.time(), tzinfo=timezone.utc) closed += 1 if closed: await db.commit() return closed