first commit
This commit is contained in:
288
app/services/watchlist_service.py
Normal file
288
app/services/watchlist_service.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""Watchlist service.
|
||||
|
||||
Auto-populates top-X tickers by composite score (default 10), supports
|
||||
manual add/remove (tagged, not subject to auto-population), enforces
|
||||
cap (auto + 10 manual, default max 20), and updates auto entries on
|
||||
score recomputation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import DuplicateError, NotFoundError, ValidationError
|
||||
from app.models.score import CompositeScore, DimensionScore
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.ticker import Ticker
|
||||
from app.models.trade_setup import TradeSetup
|
||||
from app.models.watchlist import WatchlistEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_AUTO_SIZE = 10
|
||||
MAX_MANUAL = 10
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
|
||||
normalised = symbol.strip().upper()
|
||||
result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
|
||||
ticker = result.scalar_one_or_none()
|
||||
if ticker is None:
|
||||
raise NotFoundError(f"Ticker not found: {normalised}")
|
||||
return ticker
|
||||
|
||||
|
||||
async def auto_populate(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
top_x: int = DEFAULT_AUTO_SIZE,
|
||||
) -> None:
|
||||
"""Auto-populate watchlist with top-X tickers by composite score.
|
||||
|
||||
Replaces existing auto entries. Manual entries are untouched.
|
||||
"""
|
||||
# Get top-X tickers by composite score (non-stale, descending)
|
||||
stmt = (
|
||||
select(CompositeScore)
|
||||
.where(CompositeScore.is_stale == False) # noqa: E712
|
||||
.order_by(CompositeScore.score.desc())
|
||||
.limit(top_x)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
top_scores = list(result.scalars().all())
|
||||
top_ticker_ids = {cs.ticker_id for cs in top_scores}
|
||||
|
||||
# Delete existing auto entries for this user
|
||||
await db.execute(
|
||||
delete(WatchlistEntry).where(
|
||||
WatchlistEntry.user_id == user_id,
|
||||
WatchlistEntry.entry_type == "auto",
|
||||
)
|
||||
)
|
||||
|
||||
# Get manual ticker_ids so we don't duplicate
|
||||
manual_result = await db.execute(
|
||||
select(WatchlistEntry.ticker_id).where(
|
||||
WatchlistEntry.user_id == user_id,
|
||||
WatchlistEntry.entry_type == "manual",
|
||||
)
|
||||
)
|
||||
manual_ticker_ids = {row[0] for row in manual_result.all()}
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
for ticker_id in top_ticker_ids:
|
||||
if ticker_id in manual_ticker_ids:
|
||||
continue # Already on watchlist as manual
|
||||
entry = WatchlistEntry(
|
||||
user_id=user_id,
|
||||
ticker_id=ticker_id,
|
||||
entry_type="auto",
|
||||
added_at=now,
|
||||
)
|
||||
db.add(entry)
|
||||
|
||||
await db.flush()
|
||||
|
||||
|
||||
async def add_manual_entry(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
symbol: str,
|
||||
) -> WatchlistEntry:
|
||||
"""Add a manual watchlist entry.
|
||||
|
||||
Raises DuplicateError if already on watchlist.
|
||||
Raises ValidationError if manual cap exceeded.
|
||||
"""
|
||||
ticker = await _get_ticker(db, symbol)
|
||||
|
||||
# Check if already on watchlist
|
||||
existing = await db.execute(
|
||||
select(WatchlistEntry).where(
|
||||
WatchlistEntry.user_id == user_id,
|
||||
WatchlistEntry.ticker_id == ticker.id,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise DuplicateError(f"Ticker already on watchlist: {ticker.symbol}")
|
||||
|
||||
# Count current manual entries
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(WatchlistEntry).where(
|
||||
WatchlistEntry.user_id == user_id,
|
||||
WatchlistEntry.entry_type == "manual",
|
||||
)
|
||||
)
|
||||
manual_count = count_result.scalar() or 0
|
||||
|
||||
if manual_count >= MAX_MANUAL:
|
||||
raise ValidationError(
|
||||
f"Manual watchlist cap reached ({MAX_MANUAL}). "
|
||||
"Remove an entry before adding a new one."
|
||||
)
|
||||
|
||||
# Check total cap
|
||||
total_result = await db.execute(
|
||||
select(func.count()).select_from(WatchlistEntry).where(
|
||||
WatchlistEntry.user_id == user_id,
|
||||
)
|
||||
)
|
||||
total_count = total_result.scalar() or 0
|
||||
max_total = DEFAULT_AUTO_SIZE + MAX_MANUAL
|
||||
|
||||
if total_count >= max_total:
|
||||
raise ValidationError(
|
||||
f"Watchlist cap reached ({max_total}). "
|
||||
"Remove an entry before adding a new one."
|
||||
)
|
||||
|
||||
entry = WatchlistEntry(
|
||||
user_id=user_id,
|
||||
ticker_id=ticker.id,
|
||||
entry_type="manual",
|
||||
added_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(entry)
|
||||
await db.commit()
|
||||
await db.refresh(entry)
|
||||
return entry
|
||||
|
||||
|
||||
async def remove_entry(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
symbol: str,
|
||||
) -> None:
|
||||
"""Remove a watchlist entry (manual or auto)."""
|
||||
ticker = await _get_ticker(db, symbol)
|
||||
|
||||
result = await db.execute(
|
||||
select(WatchlistEntry).where(
|
||||
WatchlistEntry.user_id == user_id,
|
||||
WatchlistEntry.ticker_id == ticker.id,
|
||||
)
|
||||
)
|
||||
entry = result.scalar_one_or_none()
|
||||
if entry is None:
|
||||
raise NotFoundError(f"Ticker not on watchlist: {ticker.symbol}")
|
||||
|
||||
await db.delete(entry)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _enrich_entry(
|
||||
db: AsyncSession,
|
||||
entry: WatchlistEntry,
|
||||
symbol: str,
|
||||
) -> dict:
|
||||
"""Build enriched watchlist entry dict with scores, R:R, and SR levels."""
|
||||
ticker_id = entry.ticker_id
|
||||
|
||||
# Composite score
|
||||
comp_result = await db.execute(
|
||||
select(CompositeScore).where(CompositeScore.ticker_id == ticker_id)
|
||||
)
|
||||
comp = comp_result.scalar_one_or_none()
|
||||
|
||||
# Dimension scores
|
||||
dim_result = await db.execute(
|
||||
select(DimensionScore).where(DimensionScore.ticker_id == ticker_id)
|
||||
)
|
||||
dims = [
|
||||
{"dimension": ds.dimension, "score": ds.score}
|
||||
for ds in dim_result.scalars().all()
|
||||
]
|
||||
|
||||
# Best trade setup (highest R:R) for this ticker
|
||||
setup_result = await db.execute(
|
||||
select(TradeSetup)
|
||||
.where(TradeSetup.ticker_id == ticker_id)
|
||||
.order_by(TradeSetup.rr_ratio.desc())
|
||||
.limit(1)
|
||||
)
|
||||
setup = setup_result.scalar_one_or_none()
|
||||
|
||||
# Active SR levels
|
||||
sr_result = await db.execute(
|
||||
select(SRLevel)
|
||||
.where(SRLevel.ticker_id == ticker_id)
|
||||
.order_by(SRLevel.strength.desc())
|
||||
)
|
||||
sr_levels = [
|
||||
{
|
||||
"price_level": lv.price_level,
|
||||
"type": lv.type,
|
||||
"strength": lv.strength,
|
||||
}
|
||||
for lv in sr_result.scalars().all()
|
||||
]
|
||||
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"entry_type": entry.entry_type,
|
||||
"composite_score": comp.score if comp else None,
|
||||
"dimensions": dims,
|
||||
"rr_ratio": setup.rr_ratio if setup else None,
|
||||
"rr_direction": setup.direction if setup else None,
|
||||
"sr_levels": sr_levels,
|
||||
"added_at": entry.added_at,
|
||||
}
|
||||
|
||||
|
||||
async def get_watchlist(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
sort_by: str = "composite",
|
||||
) -> list[dict]:
|
||||
"""Get user's watchlist with enriched data.
|
||||
|
||||
Runs auto_populate first to ensure auto entries are current,
|
||||
then enriches each entry with scores, R:R, and SR levels.
|
||||
|
||||
sort_by: "composite", "rr", or a dimension name
|
||||
(e.g. "technical", "sr_quality", "sentiment", "fundamental", "momentum").
|
||||
"""
|
||||
# Auto-populate to refresh auto entries
|
||||
await auto_populate(db, user_id)
|
||||
await db.commit()
|
||||
|
||||
# Fetch all entries with ticker symbol
|
||||
stmt = (
|
||||
select(WatchlistEntry, Ticker.symbol)
|
||||
.join(Ticker, WatchlistEntry.ticker_id == Ticker.id)
|
||||
.where(WatchlistEntry.user_id == user_id)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
rows = result.all()
|
||||
|
||||
entries: list[dict] = []
|
||||
for entry, symbol in rows:
|
||||
enriched = await _enrich_entry(db, entry, symbol)
|
||||
entries.append(enriched)
|
||||
|
||||
# Sort
|
||||
if sort_by == "composite":
|
||||
entries.sort(
|
||||
key=lambda e: e["composite_score"] if e["composite_score"] is not None else -1,
|
||||
reverse=True,
|
||||
)
|
||||
elif sort_by == "rr":
|
||||
entries.sort(
|
||||
key=lambda e: e["rr_ratio"] if e["rr_ratio"] is not None else -1,
|
||||
reverse=True,
|
||||
)
|
||||
else:
|
||||
# Sort by a specific dimension score
|
||||
def _dim_sort_key(e: dict) -> float:
|
||||
for d in e["dimensions"]:
|
||||
if d["dimension"] == sort_by:
|
||||
return d["score"]
|
||||
return -1.0
|
||||
|
||||
entries.sort(key=_dim_sort_key, reverse=True)
|
||||
|
||||
return entries
|
||||
Reference in New Issue
Block a user