"""Watchlist service. Auto-populates top-X tickers by composite score (default 10), supports manual add/remove (tagged, not subject to auto-population), enforces cap (auto + 10 manual, default max 20), and updates auto entries on score recomputation. """ from __future__ import annotations import logging from datetime import datetime, timezone from sqlalchemy import delete, func, select from sqlalchemy.ext.asyncio import AsyncSession from app.exceptions import DuplicateError, NotFoundError, ValidationError from app.models.score import CompositeScore, DimensionScore from app.models.sr_level import SRLevel from app.models.ticker import Ticker from app.models.trade_setup import TradeSetup from app.models.watchlist import WatchlistEntry logger = logging.getLogger(__name__) DEFAULT_AUTO_SIZE = 10 MAX_MANUAL = 10 async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker: normalised = symbol.strip().upper() result = await db.execute(select(Ticker).where(Ticker.symbol == normalised)) ticker = result.scalar_one_or_none() if ticker is None: raise NotFoundError(f"Ticker not found: {normalised}") return ticker async def auto_populate( db: AsyncSession, user_id: int, top_x: int = DEFAULT_AUTO_SIZE, ) -> None: """Auto-populate watchlist with top-X tickers by composite score. Replaces existing auto entries. Manual entries are untouched. """ # Get top-X tickers by composite score (non-stale, descending) stmt = ( select(CompositeScore) .where(CompositeScore.is_stale == False) # noqa: E712 .order_by(CompositeScore.score.desc()) .limit(top_x) ) result = await db.execute(stmt) top_scores = list(result.scalars().all()) top_ticker_ids = {cs.ticker_id for cs in top_scores} # Delete existing auto entries for this user await db.execute( delete(WatchlistEntry).where( WatchlistEntry.user_id == user_id, WatchlistEntry.entry_type == "auto", ) ) # Get manual ticker_ids so we don't duplicate manual_result = await db.execute( select(WatchlistEntry.ticker_id).where( WatchlistEntry.user_id == user_id, WatchlistEntry.entry_type == "manual", ) ) manual_ticker_ids = {row[0] for row in manual_result.all()} now = datetime.now(timezone.utc) for ticker_id in top_ticker_ids: if ticker_id in manual_ticker_ids: continue # Already on watchlist as manual entry = WatchlistEntry( user_id=user_id, ticker_id=ticker_id, entry_type="auto", added_at=now, ) db.add(entry) await db.flush() async def add_manual_entry( db: AsyncSession, user_id: int, symbol: str, ) -> WatchlistEntry: """Add a manual watchlist entry. Raises DuplicateError if already on watchlist. Raises ValidationError if manual cap exceeded. """ ticker = await _get_ticker(db, symbol) # Check if already on watchlist existing = await db.execute( select(WatchlistEntry).where( WatchlistEntry.user_id == user_id, WatchlistEntry.ticker_id == ticker.id, ) ) if existing.scalar_one_or_none() is not None: raise DuplicateError(f"Ticker already on watchlist: {ticker.symbol}") # Count current manual entries count_result = await db.execute( select(func.count()).select_from(WatchlistEntry).where( WatchlistEntry.user_id == user_id, WatchlistEntry.entry_type == "manual", ) ) manual_count = count_result.scalar() or 0 if manual_count >= MAX_MANUAL: raise ValidationError( f"Manual watchlist cap reached ({MAX_MANUAL}). " "Remove an entry before adding a new one." ) # Check total cap total_result = await db.execute( select(func.count()).select_from(WatchlistEntry).where( WatchlistEntry.user_id == user_id, ) ) total_count = total_result.scalar() or 0 max_total = DEFAULT_AUTO_SIZE + MAX_MANUAL if total_count >= max_total: raise ValidationError( f"Watchlist cap reached ({max_total}). " "Remove an entry before adding a new one." ) entry = WatchlistEntry( user_id=user_id, ticker_id=ticker.id, entry_type="manual", added_at=datetime.now(timezone.utc), ) db.add(entry) await db.commit() await db.refresh(entry) return entry async def remove_entry( db: AsyncSession, user_id: int, symbol: str, ) -> None: """Remove a watchlist entry (manual or auto).""" ticker = await _get_ticker(db, symbol) result = await db.execute( select(WatchlistEntry).where( WatchlistEntry.user_id == user_id, WatchlistEntry.ticker_id == ticker.id, ) ) entry = result.scalar_one_or_none() if entry is None: raise NotFoundError(f"Ticker not on watchlist: {ticker.symbol}") await db.delete(entry) await db.commit() async def _enrich_entry( db: AsyncSession, entry: WatchlistEntry, symbol: str, ) -> dict: """Build enriched watchlist entry dict with scores, R:R, and SR levels.""" ticker_id = entry.ticker_id # Composite score comp_result = await db.execute( select(CompositeScore).where(CompositeScore.ticker_id == ticker_id) ) comp = comp_result.scalar_one_or_none() # Dimension scores dim_result = await db.execute( select(DimensionScore).where(DimensionScore.ticker_id == ticker_id) ) dims = [ {"dimension": ds.dimension, "score": ds.score} for ds in dim_result.scalars().all() ] # Best trade setup (highest R:R) for this ticker setup_result = await db.execute( select(TradeSetup) .where(TradeSetup.ticker_id == ticker_id) .order_by(TradeSetup.rr_ratio.desc()) .limit(1) ) setup = setup_result.scalar_one_or_none() # Active SR levels sr_result = await db.execute( select(SRLevel) .where(SRLevel.ticker_id == ticker_id) .order_by(SRLevel.strength.desc()) ) sr_levels = [ { "price_level": lv.price_level, "type": lv.type, "strength": lv.strength, } for lv in sr_result.scalars().all() ] return { "symbol": symbol, "entry_type": entry.entry_type, "composite_score": comp.score if comp else None, "dimensions": dims, "rr_ratio": setup.rr_ratio if setup else None, "rr_direction": setup.direction if setup else None, "sr_levels": sr_levels, "added_at": entry.added_at, } async def get_watchlist( db: AsyncSession, user_id: int, sort_by: str = "composite", ) -> list[dict]: """Get user's watchlist with enriched data. Runs auto_populate first to ensure auto entries are current, then enriches each entry with scores, R:R, and SR levels. sort_by: "composite", "rr", or a dimension name (e.g. "technical", "sr_quality", "sentiment", "fundamental", "momentum"). """ # Auto-populate to refresh auto entries await auto_populate(db, user_id) await db.commit() # Fetch all entries with ticker symbol stmt = ( select(WatchlistEntry, Ticker.symbol) .join(Ticker, WatchlistEntry.ticker_id == Ticker.id) .where(WatchlistEntry.user_id == user_id) ) result = await db.execute(stmt) rows = result.all() entries: list[dict] = [] for entry, symbol in rows: enriched = await _enrich_entry(db, entry, symbol) entries.append(enriched) # Sort if sort_by == "composite": entries.sort( key=lambda e: e["composite_score"] if e["composite_score"] is not None else -1, reverse=True, ) elif sort_by == "rr": entries.sort( key=lambda e: e["rr_ratio"] if e["rr_ratio"] is not None else -1, reverse=True, ) else: # Sort by a specific dimension score def _dim_sort_key(e: dict) -> float: for d in e["dimensions"]: if d["dimension"] == sort_by: return d["score"] return -1.0 entries.sort(key=_dim_sort_key, reverse=True) return entries