"""Paper-trading service: take, mark-to-market, and close simulated trades.""" from __future__ import annotations from datetime import datetime, timezone from sqlalchemy import and_, func, select from sqlalchemy.ext.asyncio import AsyncSession from app.exceptions import NotFoundError, ValidationError from app.models.ohlcv import OHLCVRecord from app.models.paper_trade import PaperTrade from app.models.ticker import Ticker async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker: normalised = symbol.strip().upper() result = await db.execute(select(Ticker).where(Ticker.symbol == normalised)) ticker = result.scalar_one_or_none() if ticker is None: raise NotFoundError(f"Ticker not found: {normalised}") return ticker async def _latest_closes(db: AsyncSession, ticker_ids: set[int]) -> dict[int, float]: """Latest stored close per ticker.""" if not ticker_ids: return {} latest = ( select(OHLCVRecord.ticker_id, func.max(OHLCVRecord.date).label("md")) .where(OHLCVRecord.ticker_id.in_(ticker_ids)) .group_by(OHLCVRecord.ticker_id) .subquery() ) stmt = select(OHLCVRecord.ticker_id, OHLCVRecord.close).join( latest, and_( OHLCVRecord.ticker_id == latest.c.ticker_id, OHLCVRecord.date == latest.c.md, ), ) result = await db.execute(stmt) return {tid: float(close) for tid, close in result.all()} async def create_trade( db: AsyncSession, user_id: int, *, symbol: str, direction: str, entry_price: float, shares: float, stop_loss: float, target: float, ) -> PaperTrade: direction = direction.strip().lower() if direction not in ("long", "short"): raise ValidationError("direction must be 'long' or 'short'") if shares <= 0 or entry_price <= 0: raise ValidationError("shares and entry_price must be positive") ticker = await _get_ticker(db, symbol) trade = PaperTrade( user_id=user_id, ticker_id=ticker.id, direction=direction, entry_price=entry_price, shares=shares, stop_loss=stop_loss, target=target, status="open", opened_at=datetime.now(timezone.utc), ) db.add(trade) await db.commit() await db.refresh(trade) return trade def _to_dict(trade: PaperTrade, symbol: str, current_price: float | None) -> dict: return { "id": trade.id, "symbol": symbol, "direction": trade.direction, "entry_price": trade.entry_price, "shares": trade.shares, "stop_loss": trade.stop_loss, "target": trade.target, "status": trade.status, "opened_at": trade.opened_at, "close_price": trade.close_price, "closed_at": trade.closed_at, # For open trades, mark to market; for closed, the realized exit price. "current_price": current_price if trade.status == "open" else trade.close_price, } async def list_trades( db: AsyncSession, user_id: int, status: str | None = None, ) -> list[dict]: stmt = ( select(PaperTrade, Ticker.symbol) .join(Ticker, PaperTrade.ticker_id == Ticker.id) .where(PaperTrade.user_id == user_id) ) if status is not None: stmt = stmt.where(PaperTrade.status == status) stmt = stmt.order_by(PaperTrade.opened_at.desc()) rows = (await db.execute(stmt)).all() open_ids = {t.ticker_id for t, _ in rows if t.status == "open"} prices = await _latest_closes(db, open_ids) return [_to_dict(t, sym, prices.get(t.ticker_id)) for t, sym in rows] async def close_trade( db: AsyncSession, user_id: int, trade_id: int, close_price: float | None = None, ) -> PaperTrade: result = await db.execute( select(PaperTrade).where( PaperTrade.id == trade_id, PaperTrade.user_id == user_id, ) ) trade = result.scalar_one_or_none() if trade is None: raise NotFoundError(f"Paper trade not found: {trade_id}") if trade.status == "closed": raise ValidationError("Trade is already closed") if close_price is None: prices = await _latest_closes(db, {trade.ticker_id}) close_price = prices.get(trade.ticker_id) if close_price is None: raise ValidationError("No current price available to close at; supply close_price") trade.status = "closed" trade.close_price = float(close_price) trade.closed_at = datetime.now(timezone.utc) await db.commit() await db.refresh(trade) return trade