289 lines
8.2 KiB
Python
289 lines
8.2 KiB
Python
"""Watchlist service.
|
|
|
|
Auto-populates the top-X tickers by composite score (default 10), supports
manual add/remove (manual entries are tagged and are not replaced by
auto-population), enforces caps (at most 10 manual entries; 20 entries in
total with the default auto size), and updates auto entries on score
recomputation.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
|
|
from sqlalchemy import delete, func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.exceptions import DuplicateError, NotFoundError, ValidationError
|
|
from app.models.score import CompositeScore, DimensionScore
|
|
from app.models.sr_level import SRLevel
|
|
from app.models.ticker import Ticker
|
|
from app.models.trade_setup import TradeSetup
|
|
from app.models.watchlist import WatchlistEntry
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Default number of top-scored tickers that auto_populate() places on a watchlist.
DEFAULT_AUTO_SIZE = 10
# Maximum number of "manual" entries a single user may hold.
MAX_MANUAL = 10
|
|
|
|
|
|
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Fetch the ticker row matching ``symbol``.

    The symbol is stripped of surrounding whitespace and upper-cased before
    the lookup.

    Raises NotFoundError if no ticker matches the normalised symbol.
    """
    normalised = symbol.strip().upper()
    lookup = select(Ticker).where(Ticker.symbol == normalised)
    match = (await db.execute(lookup)).scalar_one_or_none()
    if match is not None:
        return match
    raise NotFoundError(f"Ticker not found: {normalised}")
|
|
|
|
|
|
async def auto_populate(
    db: AsyncSession,
    user_id: int,
    top_x: int = DEFAULT_AUTO_SIZE,
) -> None:
    """Rebuild the user's auto entries from the top-X composite scores.

    Every existing "auto" entry of the user is deleted and re-created from
    the highest non-stale composite scores. Manual entries are untouched, and
    a ticker that is already a manual entry is not added again as auto.

    The new rows are flushed but not committed here.
    """
    # Best non-stale composite scores, highest first, at most top_x rows.
    ranking = (
        select(CompositeScore)
        .where(CompositeScore.is_stale == False)  # noqa: E712
        .order_by(CompositeScore.score.desc())
        .limit(top_x)
    )
    ranked_scores = (await db.execute(ranking)).scalars().all()
    candidate_ids = {score.ticker_id for score in ranked_scores}

    # Drop the user's previous auto entries; they are rebuilt below.
    purge_auto = delete(WatchlistEntry).where(
        WatchlistEntry.user_id == user_id,
        WatchlistEntry.entry_type == "auto",
    )
    await db.execute(purge_auto)

    # Tickers the user added manually must not get a second (auto) row.
    manual_query = select(WatchlistEntry.ticker_id).where(
        WatchlistEntry.user_id == user_id,
        WatchlistEntry.entry_type == "manual",
    )
    manual_rows = await db.execute(manual_query)
    pinned_ids = {row[0] for row in manual_rows.all()}

    # One shared timestamp for every entry created in this run.
    stamp = datetime.now(timezone.utc)
    for candidate_id in candidate_ids:
        if candidate_id not in pinned_ids:
            db.add(
                WatchlistEntry(
                    user_id=user_id,
                    ticker_id=candidate_id,
                    entry_type="auto",
                    added_at=stamp,
                )
            )

    await db.flush()
|
|
|
|
|
|
async def add_manual_entry(
    db: AsyncSession,
    user_id: int,
    symbol: str,
) -> WatchlistEntry:
    """Create, commit and return a manual watchlist entry for ``symbol``.

    Raises NotFoundError if the symbol does not resolve to a ticker.
    Raises DuplicateError if the ticker is already on the user's watchlist
    (the check does not distinguish auto from manual entries).
    Raises ValidationError if the manual cap or the total cap is reached.
    """
    ticker = await _get_ticker(db, symbol)

    # Any existing row for this user/ticker pair blocks the insert.
    clash_query = select(WatchlistEntry).where(
        WatchlistEntry.user_id == user_id,
        WatchlistEntry.ticker_id == ticker.id,
    )
    clash = (await db.execute(clash_query)).scalar_one_or_none()
    if clash is not None:
        raise DuplicateError(f"Ticker already on watchlist: {ticker.symbol}")

    # Manual cap: count only the user's "manual" rows.
    manual_query = (
        select(func.count())
        .select_from(WatchlistEntry)
        .where(
            WatchlistEntry.user_id == user_id,
            WatchlistEntry.entry_type == "manual",
        )
    )
    manual_total = (await db.execute(manual_query)).scalar() or 0
    if not manual_total < MAX_MANUAL:
        raise ValidationError(
            f"Manual watchlist cap reached ({MAX_MANUAL}). "
            "Remove an entry before adding a new one."
        )

    # Total cap: every row of the user counts, whatever its type.
    overall_query = (
        select(func.count())
        .select_from(WatchlistEntry)
        .where(
            WatchlistEntry.user_id == user_id,
        )
    )
    overall_total = (await db.execute(overall_query)).scalar() or 0
    max_total = DEFAULT_AUTO_SIZE + MAX_MANUAL
    if not overall_total < max_total:
        raise ValidationError(
            f"Watchlist cap reached ({max_total}). "
            "Remove an entry before adding a new one."
        )

    created = WatchlistEntry(
        user_id=user_id,
        ticker_id=ticker.id,
        entry_type="manual",
        added_at=datetime.now(timezone.utc),
    )
    db.add(created)
    await db.commit()
    # Reload the instance after the commit before handing it back.
    await db.refresh(created)
    return created
|
|
|
|
|
|
async def remove_entry(
    db: AsyncSession,
    user_id: int,
    symbol: str,
) -> None:
    """Delete the user's watchlist entry for ``symbol`` and commit.

    The lookup does not filter on entry type, so both manual and auto entries
    can be removed.

    Raises NotFoundError if the symbol is unknown or the ticker is not on the
    user's watchlist.
    """
    ticker = await _get_ticker(db, symbol)

    lookup = select(WatchlistEntry).where(
        WatchlistEntry.user_id == user_id,
        WatchlistEntry.ticker_id == ticker.id,
    )
    target = (await db.execute(lookup)).scalar_one_or_none()
    if target is None:
        raise NotFoundError(f"Ticker not on watchlist: {ticker.symbol}")

    await db.delete(target)
    await db.commit()
|
|
|
|
|
|
async def _enrich_entry(
    db: AsyncSession,
    entry: WatchlistEntry,
    symbol: str,
) -> dict:
    """Assemble the enriched dict for one watchlist entry.

    Alongside the entry's own fields, the dict carries the ticker's composite
    score, its per-dimension scores, the R:R ratio and direction of its
    highest-R:R trade setup, and its SR levels ordered by strength. Values
    with no backing row are reported as None (or an empty list).
    """
    ticker_id = entry.ticker_id

    # Composite score row for the ticker, if one exists.
    composite_query = select(CompositeScore).where(
        CompositeScore.ticker_id == ticker_id
    )
    composite = (await db.execute(composite_query)).scalar_one_or_none()

    # Per-dimension scores.
    dimension_query = select(DimensionScore).where(
        DimensionScore.ticker_id == ticker_id
    )
    dimension_rows = (await db.execute(dimension_query)).scalars().all()
    dimensions = []
    for row in dimension_rows:
        dimensions.append({"dimension": row.dimension, "score": row.score})

    # Trade setup with the highest R:R ratio for this ticker.
    setup_query = (
        select(TradeSetup)
        .where(TradeSetup.ticker_id == ticker_id)
        .order_by(TradeSetup.rr_ratio.desc())
        .limit(1)
    )
    best_setup = (await db.execute(setup_query)).scalar_one_or_none()

    # SR levels for the ticker, strongest first.
    level_query = (
        select(SRLevel)
        .where(SRLevel.ticker_id == ticker_id)
        .order_by(SRLevel.strength.desc())
    )
    level_rows = (await db.execute(level_query)).scalars().all()
    levels = []
    for level in level_rows:
        levels.append(
            {
                "price_level": level.price_level,
                "type": level.type,
                "strength": level.strength,
            }
        )

    # Start from the "nothing found" shape, then fill in what exists.
    enriched = {
        "symbol": symbol,
        "entry_type": entry.entry_type,
        "composite_score": None,
        "dimensions": dimensions,
        "rr_ratio": None,
        "rr_direction": None,
        "sr_levels": levels,
        "added_at": entry.added_at,
    }
    if composite:
        enriched["composite_score"] = composite.score
    if best_setup:
        enriched["rr_ratio"] = best_setup.rr_ratio
        enriched["rr_direction"] = best_setup.direction
    return enriched
|
|
|
|
|
|
async def get_watchlist(
    db: AsyncSession,
    user_id: int,
    sort_by: str = "composite",
) -> list[dict]:
    """Get user's watchlist with enriched data.

    Runs auto_populate first (and commits) so the auto entries are current,
    then enriches each entry with scores, R:R, and SR levels and returns the
    entries sorted in descending order of the requested value.

    sort_by: "composite", "rr", or a dimension name
    (e.g. "technical", "sr_quality", "sentiment", "fundamental", "momentum").

    Entries that lack the requested value (no composite score, no trade
    setup, no matching dimension, or a stored score of None) sort last. An
    unrecognised sort_by therefore leaves the entries in query order.
    """
    # Auto-populate to refresh auto entries
    await auto_populate(db, user_id)
    await db.commit()

    # Fetch all entries with ticker symbol
    stmt = (
        select(WatchlistEntry, Ticker.symbol)
        .join(Ticker, WatchlistEntry.ticker_id == Ticker.id)
        .where(WatchlistEntry.user_id == user_id)
    )
    rows = (await db.execute(stmt)).all()

    entries: list[dict] = [
        await _enrich_entry(db, entry, symbol) for entry, symbol in rows
    ]

    def _sort_key(e: dict) -> float:
        """Return the value to sort on, mapping missing/None values to -1."""
        if sort_by == "composite":
            value = e["composite_score"]
        elif sort_by == "rr":
            value = e["rr_ratio"]
        else:
            # First dimension whose name matches sort_by, if any.
            value = next(
                (d["score"] for d in e["dimensions"] if d["dimension"] == sort_by),
                None,
            )
        # A None score must not reach the comparison: None is not orderable
        # against numbers and would make sort() raise TypeError.
        return value if value is not None else -1

    entries.sort(key=_sort_key, reverse=True)
    return entries
|