first commit
Some checks failed
Deploy / lint (push) Failing after 7s
Deploy / test (push) Has been skipped
Deploy / deploy (push) Has been skipped

This commit is contained in:
Dennis Thiessen
2026-02-20 17:31:01 +01:00
commit 61ab24490d
160 changed files with 17034 additions and 0 deletions

1
app/services/__init__.py Normal file
View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,238 @@
"""Admin service: user management, system settings, data cleanup, job control."""
from datetime import datetime, timedelta, timezone
from passlib.hash import bcrypt
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.exceptions import DuplicateError, NotFoundError, ValidationError
from app.models.fundamental import FundamentalData
from app.models.ohlcv import OHLCVRecord
from app.models.sentiment import SentimentScore
from app.models.settings import SystemSetting
from app.models.user import User
# ---------------------------------------------------------------------------
# User management
# ---------------------------------------------------------------------------
async def list_users(db: AsyncSession) -> list[User]:
"""Return all users ordered by id."""
result = await db.execute(select(User).order_by(User.id))
return list(result.scalars().all())
async def create_user(
db: AsyncSession,
username: str,
password: str,
role: str = "user",
has_access: bool = False,
) -> User:
"""Create a new user account (admin action)."""
result = await db.execute(select(User).where(User.username == username))
if result.scalar_one_or_none() is not None:
raise DuplicateError(f"Username already exists: {username}")
user = User(
username=username,
password_hash=bcrypt.hash(password),
role=role,
has_access=has_access,
)
db.add(user)
await db.commit()
await db.refresh(user)
return user
async def set_user_access(db: AsyncSession, user_id: int, has_access: bool) -> User:
"""Grant or revoke API access for a user."""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
raise NotFoundError(f"User not found: {user_id}")
user.has_access = has_access
await db.commit()
await db.refresh(user)
return user
async def reset_password(db: AsyncSession, user_id: int, new_password: str) -> User:
"""Reset a user's password."""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
raise NotFoundError(f"User not found: {user_id}")
user.password_hash = bcrypt.hash(new_password)
await db.commit()
await db.refresh(user)
return user
# ---------------------------------------------------------------------------
# Registration toggle
# ---------------------------------------------------------------------------
async def toggle_registration(db: AsyncSession, enabled: bool) -> SystemSetting:
"""Enable or disable user registration via SystemSetting."""
result = await db.execute(
select(SystemSetting).where(SystemSetting.key == "registration_enabled")
)
setting = result.scalar_one_or_none()
value = str(enabled).lower()
if setting is None:
setting = SystemSetting(key="registration_enabled", value=value)
db.add(setting)
else:
setting.value = value
await db.commit()
await db.refresh(setting)
return setting
# ---------------------------------------------------------------------------
# System settings CRUD
# ---------------------------------------------------------------------------
async def list_settings(db: AsyncSession) -> list[SystemSetting]:
"""Return all system settings."""
result = await db.execute(select(SystemSetting).order_by(SystemSetting.key))
return list(result.scalars().all())
async def update_setting(db: AsyncSession, key: str, value: str) -> SystemSetting:
"""Create or update a system setting."""
result = await db.execute(
select(SystemSetting).where(SystemSetting.key == key)
)
setting = result.scalar_one_or_none()
if setting is None:
setting = SystemSetting(key=key, value=value)
db.add(setting)
else:
setting.value = value
await db.commit()
await db.refresh(setting)
return setting
# ---------------------------------------------------------------------------
# Data cleanup
# ---------------------------------------------------------------------------
async def cleanup_data(db: AsyncSession, older_than_days: int) -> dict[str, int]:
"""Delete OHLCV, sentiment, and fundamental records older than N days.
Preserves tickers, users, and latest scores.
Returns a dict with counts of deleted records per table.
"""
cutoff = datetime.now(timezone.utc) - timedelta(days=older_than_days)
counts: dict[str, int] = {}
# OHLCV — date column is a date, compare with cutoff date
result = await db.execute(
delete(OHLCVRecord).where(OHLCVRecord.date < cutoff.date())
)
counts["ohlcv"] = result.rowcount # type: ignore[assignment]
# Sentiment — timestamp is datetime
result = await db.execute(
delete(SentimentScore).where(SentimentScore.timestamp < cutoff)
)
counts["sentiment"] = result.rowcount # type: ignore[assignment]
# Fundamentals — fetched_at is datetime
result = await db.execute(
delete(FundamentalData).where(FundamentalData.fetched_at < cutoff)
)
counts["fundamentals"] = result.rowcount # type: ignore[assignment]
await db.commit()
return counts
# ---------------------------------------------------------------------------
# Job control (placeholder — scheduler is Task 12.1)
# ---------------------------------------------------------------------------
VALID_JOB_NAMES = {"data_collector", "sentiment_collector", "fundamental_collector", "rr_scanner"}
JOB_LABELS = {
"data_collector": "Data Collector (OHLCV)",
"sentiment_collector": "Sentiment Collector",
"fundamental_collector": "Fundamental Collector",
"rr_scanner": "R:R Scanner",
}
async def list_jobs(db: AsyncSession) -> list[dict]:
"""Return status of all scheduled jobs."""
from app.scheduler import scheduler
jobs_out = []
for name in sorted(VALID_JOB_NAMES):
# Check enabled setting
key = f"job_{name}_enabled"
result = await db.execute(
select(SystemSetting).where(SystemSetting.key == key)
)
setting = result.scalar_one_or_none()
enabled = setting.value == "true" if setting else True # default enabled
# Get scheduler job info
job = scheduler.get_job(name)
next_run = None
if job and job.next_run_time:
next_run = job.next_run_time.isoformat()
jobs_out.append({
"name": name,
"label": JOB_LABELS.get(name, name),
"enabled": enabled,
"next_run_at": next_run,
"registered": job is not None,
})
return jobs_out
async def trigger_job(db: AsyncSession, job_name: str) -> dict[str, str]:
"""Trigger a manual job run via the scheduler.
Runs the job immediately (in addition to its regular schedule).
"""
if job_name not in VALID_JOB_NAMES:
raise ValidationError(f"Unknown job: {job_name}. Valid jobs: {', '.join(sorted(VALID_JOB_NAMES))}")
from app.scheduler import scheduler
job = scheduler.get_job(job_name)
if job is None:
return {"job": job_name, "status": "not_found", "message": f"Job '{job_name}' is not registered in the scheduler"}
job.modify(next_run_time=None) # Reset, then trigger immediately
from datetime import datetime, timezone
job.modify(next_run_time=datetime.now(timezone.utc))
return {"job": job_name, "status": "triggered", "message": f"Job '{job_name}' triggered for immediate execution"}
async def toggle_job(db: AsyncSession, job_name: str, enabled: bool) -> SystemSetting:
"""Enable or disable a scheduled job by storing state in SystemSetting.
Actual scheduler integration happens in Task 12.1.
"""
if job_name not in VALID_JOB_NAMES:
raise ValidationError(f"Unknown job: {job_name}. Valid jobs: {', '.join(sorted(VALID_JOB_NAMES))}")
key = f"job_{job_name}_enabled"
return await update_setting(db, key, str(enabled).lower())

View File

@@ -0,0 +1,66 @@
"""Auth service: registration, login, and JWT token generation."""
from datetime import datetime, timedelta, timezone
from jose import jwt
from passlib.hash import bcrypt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.dependencies import JWT_ALGORITHM
from app.exceptions import AuthenticationError, AuthorizationError, DuplicateError
from app.models.settings import SystemSetting
from app.models.user import User
async def register(db: AsyncSession, username: str, password: str) -> User:
"""Register a new user.
Checks if registration is enabled via SystemSetting, rejects duplicates,
and creates a user with role='user' and has_access=False.
"""
# Check registration toggle
result = await db.execute(
select(SystemSetting).where(SystemSetting.key == "registration_enabled")
)
setting = result.scalar_one_or_none()
if setting is not None and setting.value.lower() == "false":
raise AuthorizationError("Registration is closed")
# Check duplicate username
result = await db.execute(select(User).where(User.username == username))
if result.scalar_one_or_none() is not None:
raise DuplicateError(f"Username already exists: {username}")
user = User(
username=username,
password_hash=bcrypt.hash(password),
role="user",
has_access=False,
)
db.add(user)
await db.commit()
await db.refresh(user)
return user
async def login(db: AsyncSession, username: str, password: str) -> str:
"""Authenticate user and return a JWT access token.
Returns the same error message for wrong username or wrong password
to avoid leaking which field is incorrect.
"""
result = await db.execute(select(User).where(User.username == username))
user = result.scalar_one_or_none()
if user is None or not bcrypt.verify(password, user.password_hash):
raise AuthenticationError("Invalid credentials")
payload = {
"sub": str(user.id),
"role": user.role,
"exp": datetime.now(timezone.utc) + timedelta(minutes=settings.jwt_expiry_minutes),
}
token = jwt.encode(payload, settings.jwt_secret, algorithm=JWT_ALGORITHM)
return token

View File

@@ -0,0 +1,101 @@
"""Fundamental data service.
Stores fundamental data (P/E, revenue growth, earnings surprise, market cap)
and marks the fundamental dimension score as stale on new data.
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.exceptions import NotFoundError
from app.models.fundamental import FundamentalData
from app.models.score import DimensionScore
from app.models.ticker import Ticker
logger = logging.getLogger(__name__)
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
"""Look up a ticker by symbol."""
normalised = symbol.strip().upper()
result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
ticker = result.scalar_one_or_none()
if ticker is None:
raise NotFoundError(f"Ticker not found: {normalised}")
return ticker
async def store_fundamental(
db: AsyncSession,
symbol: str,
pe_ratio: float | None = None,
revenue_growth: float | None = None,
earnings_surprise: float | None = None,
market_cap: float | None = None,
) -> FundamentalData:
"""Store or update fundamental data for a ticker.
Keeps a single latest snapshot per ticker. On new data, marks the
fundamental dimension score as stale (if one exists).
"""
ticker = await _get_ticker(db, symbol)
# Check for existing record
result = await db.execute(
select(FundamentalData).where(FundamentalData.ticker_id == ticker.id)
)
existing = result.scalar_one_or_none()
now = datetime.now(timezone.utc)
if existing is not None:
existing.pe_ratio = pe_ratio
existing.revenue_growth = revenue_growth
existing.earnings_surprise = earnings_surprise
existing.market_cap = market_cap
existing.fetched_at = now
record = existing
else:
record = FundamentalData(
ticker_id=ticker.id,
pe_ratio=pe_ratio,
revenue_growth=revenue_growth,
earnings_surprise=earnings_surprise,
market_cap=market_cap,
fetched_at=now,
)
db.add(record)
# Mark fundamental dimension score as stale if it exists
# TODO: Use DimensionScore service when built
dim_result = await db.execute(
select(DimensionScore).where(
DimensionScore.ticker_id == ticker.id,
DimensionScore.dimension == "fundamental",
)
)
dim_score = dim_result.scalar_one_or_none()
if dim_score is not None:
dim_score.is_stale = True
await db.commit()
await db.refresh(record)
return record
async def get_fundamental(
db: AsyncSession,
symbol: str,
) -> FundamentalData | None:
"""Get the latest fundamental data for a ticker."""
ticker = await _get_ticker(db, symbol)
result = await db.execute(
select(FundamentalData).where(FundamentalData.ticker_id == ticker.id)
)
return result.scalar_one_or_none()

View File

@@ -0,0 +1,509 @@
"""Technical Analysis service.
Computes indicators from OHLCV data. Each indicator function is a pure
function that takes a list of OHLCV-like records and returns raw values
plus a normalized 0-100 score. The service layer handles DB fetching,
caching, and minimum-data validation.
"""
from __future__ import annotations
from datetime import date
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.cache import indicator_cache
from app.exceptions import ValidationError
from app.services.price_service import query_ohlcv
# ---------------------------------------------------------------------------
# Minimum data requirements per indicator
# ---------------------------------------------------------------------------
MIN_BARS: dict[str, int] = {
"adx": 28,
"ema": 0, # dynamic: period + 1
"rsi": 15,
"atr": 15,
"volume_profile": 20,
"pivot_points": 5,
}
DEFAULT_PERIODS: dict[str, int] = {
"adx": 14,
"ema": 20,
"rsi": 14,
"atr": 14,
}
# ---------------------------------------------------------------------------
# Pure computation helpers
# ---------------------------------------------------------------------------
def _ema(values: list[float], period: int) -> list[float]:
"""Compute EMA series. Returns list same length as *values*."""
if len(values) < period:
return []
k = 2.0 / (period + 1)
ema_vals: list[float] = [sum(values[:period]) / period]
for v in values[period:]:
ema_vals.append(v * k + ema_vals[-1] * (1 - k))
return ema_vals
def compute_adx(
highs: list[float],
lows: list[float],
closes: list[float],
period: int = 14,
) -> dict[str, Any]:
"""Compute ADX from high/low/close arrays.
Returns dict with ``adx``, ``plus_di``, ``minus_di``, ``score``.
"""
n = len(closes)
if n < 2 * period:
raise ValidationError(
f"ADX requires at least {2 * period} bars, got {n}"
)
# True Range, +DM, -DM
tr_list: list[float] = []
plus_dm: list[float] = []
minus_dm: list[float] = []
for i in range(1, n):
h, l, pc = highs[i], lows[i], closes[i - 1]
tr_list.append(max(h - l, abs(h - pc), abs(l - pc)))
up = highs[i] - highs[i - 1]
down = lows[i - 1] - lows[i]
plus_dm.append(up if up > down and up > 0 else 0.0)
minus_dm.append(down if down > up and down > 0 else 0.0)
# Smoothed TR, +DM, -DM (Wilder smoothing)
def _smooth(vals: list[float], p: int) -> list[float]:
s = [sum(vals[:p])]
for v in vals[p:]:
s.append(s[-1] - s[-1] / p + v)
return s
s_tr = _smooth(tr_list, period)
s_plus = _smooth(plus_dm, period)
s_minus = _smooth(minus_dm, period)
# +DI, -DI, DX
dx_list: list[float] = []
plus_di_last = 0.0
minus_di_last = 0.0
for i in range(len(s_tr)):
tr_v = s_tr[i] if s_tr[i] != 0 else 1e-10
pdi = 100.0 * s_plus[i] / tr_v
mdi = 100.0 * s_minus[i] / tr_v
denom = pdi + mdi if (pdi + mdi) != 0 else 1e-10
dx_list.append(100.0 * abs(pdi - mdi) / denom)
plus_di_last = pdi
minus_di_last = mdi
# ADX = smoothed DX
if len(dx_list) < period:
adx_val = sum(dx_list) / len(dx_list) if dx_list else 0.0
else:
adx_vals = _smooth(dx_list, period)
adx_val = adx_vals[-1]
score = max(0.0, min(100.0, adx_val))
return {
"adx": round(adx_val, 4),
"plus_di": round(plus_di_last, 4),
"minus_di": round(minus_di_last, 4),
"score": round(score, 4),
}
def compute_ema(
closes: list[float],
period: int = 20,
) -> dict[str, Any]:
"""Compute EMA for *closes* with given *period*.
Score: normalized position of latest close relative to EMA.
Above EMA → higher score, below → lower.
"""
min_bars = period + 1
if len(closes) < min_bars:
raise ValidationError(
f"EMA({period}) requires at least {min_bars} bars, got {len(closes)}"
)
ema_vals = _ema(closes, period)
latest_ema = ema_vals[-1]
latest_close = closes[-1]
# Score: 50 = at EMA, 100 = 5%+ above, 0 = 5%+ below
if latest_ema == 0:
pct = 0.0
else:
pct = (latest_close - latest_ema) / latest_ema * 100.0
score = max(0.0, min(100.0, 50.0 + pct * 10.0))
return {
"ema": round(latest_ema, 4),
"period": period,
"latest_close": round(latest_close, 4),
"score": round(score, 4),
}
def compute_rsi(
closes: list[float],
period: int = 14,
) -> dict[str, Any]:
"""Compute RSI. Score = RSI value (already 0-100)."""
n = len(closes)
if n < period + 1:
raise ValidationError(
f"RSI requires at least {period + 1} bars, got {n}"
)
deltas = [closes[i] - closes[i - 1] for i in range(1, n)]
gains = [d if d > 0 else 0.0 for d in deltas]
losses = [-d if d < 0 else 0.0 for d in deltas]
avg_gain = sum(gains[:period]) / period
avg_loss = sum(losses[:period]) / period
for i in range(period, len(deltas)):
avg_gain = (avg_gain * (period - 1) + gains[i]) / period
avg_loss = (avg_loss * (period - 1) + losses[i]) / period
if avg_loss == 0:
rsi = 100.0
else:
rs = avg_gain / avg_loss
rsi = 100.0 - 100.0 / (1.0 + rs)
score = max(0.0, min(100.0, rsi))
return {
"rsi": round(rsi, 4),
"period": period,
"score": round(score, 4),
}
def compute_atr(
highs: list[float],
lows: list[float],
closes: list[float],
period: int = 14,
) -> dict[str, Any]:
"""Compute ATR. Score = normalized inverse (lower ATR = higher score)."""
n = len(closes)
if n < period + 1:
raise ValidationError(
f"ATR requires at least {period + 1} bars, got {n}"
)
tr_list: list[float] = []
for i in range(1, n):
h, l, pc = highs[i], lows[i], closes[i - 1]
tr_list.append(max(h - l, abs(h - pc), abs(l - pc)))
# Wilder smoothing
atr = sum(tr_list[:period]) / period
for tr in tr_list[period:]:
atr = (atr * (period - 1) + tr) / period
# Score: inverse normalized. ATR as % of price; lower = higher score.
latest_close = closes[-1]
if latest_close == 0:
atr_pct = 0.0
else:
atr_pct = atr / latest_close * 100.0
# 0% ATR → 100 score, 10%+ ATR → 0 score
score = max(0.0, min(100.0, 100.0 - atr_pct * 10.0))
return {
"atr": round(atr, 4),
"period": period,
"atr_percent": round(atr_pct, 4),
"score": round(score, 4),
}
def compute_volume_profile(
highs: list[float],
lows: list[float],
closes: list[float],
volumes: list[int],
num_bins: int = 20,
) -> dict[str, Any]:
"""Compute Volume Profile: POC, Value Area, HVN, LVN.
Score: proximity of latest close to POC (closer = higher).
"""
n = len(closes)
if n < 20:
raise ValidationError(
f"Volume Profile requires at least 20 bars, got {n}"
)
price_min = min(lows)
price_max = max(highs)
if price_max == price_min:
price_max = price_min + 1.0 # avoid zero-width range
bin_width = (price_max - price_min) / num_bins
bins: list[float] = [0.0] * num_bins
bin_prices: list[float] = [
price_min + (i + 0.5) * bin_width for i in range(num_bins)
]
for i in range(n):
# Distribute volume across bins the bar spans
bar_low, bar_high = lows[i], highs[i]
for b in range(num_bins):
bl = price_min + b * bin_width
bh = bl + bin_width
if bar_high >= bl and bar_low <= bh:
bins[b] += volumes[i]
total_vol = sum(bins)
if total_vol == 0:
total_vol = 1.0
# POC = bin with highest volume
poc_idx = bins.index(max(bins))
poc = round(bin_prices[poc_idx], 4)
# Value Area: 70% of total volume around POC
sorted_bins = sorted(range(num_bins), key=lambda i: bins[i], reverse=True)
va_vol = 0.0
va_indices: list[int] = []
for idx in sorted_bins:
va_vol += bins[idx]
va_indices.append(idx)
if va_vol >= total_vol * 0.7:
break
va_low = round(price_min + min(va_indices) * bin_width, 4)
va_high = round(price_min + (max(va_indices) + 1) * bin_width, 4)
# HVN / LVN: bins above/below average volume
avg_vol = total_vol / num_bins
hvn = [round(bin_prices[i], 4) for i in range(num_bins) if bins[i] > avg_vol]
lvn = [round(bin_prices[i], 4) for i in range(num_bins) if bins[i] < avg_vol]
# Score: proximity of latest close to POC
latest = closes[-1]
price_range = price_max - price_min
if price_range == 0:
score = 100.0
else:
dist_pct = abs(latest - poc) / price_range
score = max(0.0, min(100.0, 100.0 * (1.0 - dist_pct)))
return {
"poc": poc,
"value_area_low": va_low,
"value_area_high": va_high,
"hvn": hvn,
"lvn": lvn,
"score": round(score, 4),
}
def compute_pivot_points(
highs: list[float],
lows: list[float],
closes: list[float],
window: int = 2,
) -> dict[str, Any]:
"""Detect swing highs/lows as pivot points.
A swing high at index *i* means highs[i] >= all highs in [i-window, i+window].
Score: based on number of pivots near current price.
"""
n = len(closes)
if n < 5:
raise ValidationError(
f"Pivot Points requires at least 5 bars, got {n}"
)
swing_highs: list[float] = []
swing_lows: list[float] = []
for i in range(window, n - window):
# Swing high
if all(highs[i] >= highs[j] for j in range(i - window, i + window + 1)):
swing_highs.append(round(highs[i], 4))
# Swing low
if all(lows[i] <= lows[j] for j in range(i - window, i + window + 1)):
swing_lows.append(round(lows[i], 4))
all_pivots = swing_highs + swing_lows
latest = closes[-1]
# Score: fraction of pivots within 2% of current price → 0-100
if not all_pivots or latest == 0:
score = 0.0
else:
near = sum(1 for p in all_pivots if abs(p - latest) / latest <= 0.02)
score = min(100.0, (near / max(len(all_pivots), 1)) * 100.0)
return {
"swing_highs": swing_highs,
"swing_lows": swing_lows,
"pivot_count": len(all_pivots),
"score": round(score, 4),
}
def compute_ema_cross(
closes: list[float],
short_period: int = 20,
long_period: int = 50,
tolerance: float = 1e-6,
) -> dict[str, Any]:
"""Compare short EMA vs long EMA.
Returns signal: bullish (short > long), bearish (short < long),
neutral (within tolerance).
"""
min_bars = long_period + 1
if len(closes) < min_bars:
raise ValidationError(
f"EMA Cross requires at least {min_bars} bars, got {len(closes)}"
)
short_ema_vals = _ema(closes, short_period)
long_ema_vals = _ema(closes, long_period)
short_ema = short_ema_vals[-1]
long_ema = long_ema_vals[-1]
diff = short_ema - long_ema
if abs(diff) <= tolerance:
signal = "neutral"
elif diff > 0:
signal = "bullish"
else:
signal = "bearish"
return {
"short_ema": round(short_ema, 4),
"long_ema": round(long_ema, 4),
"short_period": short_period,
"long_period": long_period,
"signal": signal,
}
# ---------------------------------------------------------------------------
# Supported indicator types
# ---------------------------------------------------------------------------
INDICATOR_TYPES = {"adx", "ema", "rsi", "atr", "volume_profile", "pivot_points"}
# ---------------------------------------------------------------------------
# Service-layer functions (DB + cache + validation)
# ---------------------------------------------------------------------------
def _extract_ohlcv(records: list) -> tuple[
list[float], list[float], list[float], list[float], list[int]
]:
"""Extract parallel arrays from OHLCVRecord list."""
opens = [float(r.open) for r in records]
highs = [float(r.high) for r in records]
lows = [float(r.low) for r in records]
closes = [float(r.close) for r in records]
volumes = [int(r.volume) for r in records]
return opens, highs, lows, closes, volumes
async def get_indicator(
db: AsyncSession,
symbol: str,
indicator_type: str,
start_date: date | None = None,
end_date: date | None = None,
period: int | None = None,
) -> dict[str, Any]:
"""Compute a single indicator for *symbol*.
Checks cache first; stores result after computing.
"""
indicator_type = indicator_type.lower()
if indicator_type not in INDICATOR_TYPES:
raise ValidationError(
f"Unknown indicator type: {indicator_type}. "
f"Supported: {', '.join(sorted(INDICATOR_TYPES))}"
)
cache_key = (symbol.upper(), str(start_date), str(end_date), indicator_type)
cached = indicator_cache.get(cache_key)
if cached is not None:
return cached
records = await query_ohlcv(db, symbol, start_date, end_date)
_, highs, lows, closes, volumes = _extract_ohlcv(records)
n = len(records)
if indicator_type == "adx":
p = period or DEFAULT_PERIODS["adx"]
result = compute_adx(highs, lows, closes, period=p)
elif indicator_type == "ema":
p = period or DEFAULT_PERIODS["ema"]
result = compute_ema(closes, period=p)
elif indicator_type == "rsi":
p = period or DEFAULT_PERIODS["rsi"]
result = compute_rsi(closes, period=p)
elif indicator_type == "atr":
p = period or DEFAULT_PERIODS["atr"]
result = compute_atr(highs, lows, closes, period=p)
elif indicator_type == "volume_profile":
result = compute_volume_profile(highs, lows, closes, volumes)
elif indicator_type == "pivot_points":
result = compute_pivot_points(highs, lows, closes)
else:
raise ValidationError(f"Unknown indicator type: {indicator_type}")
response = {
"indicator_type": indicator_type,
"values": {k: v for k, v in result.items() if k != "score"},
"score": result["score"],
"bars_used": n,
}
indicator_cache.set(cache_key, response)
return response
async def get_ema_cross(
db: AsyncSession,
symbol: str,
start_date: date | None = None,
end_date: date | None = None,
short_period: int = 20,
long_period: int = 50,
) -> dict[str, Any]:
"""Compute EMA cross signal for *symbol*."""
cache_key = (
symbol.upper(),
str(start_date),
str(end_date),
f"ema_cross_{short_period}_{long_period}",
)
cached = indicator_cache.get(cache_key)
if cached is not None:
return cached
records = await query_ohlcv(db, symbol, start_date, end_date)
_, _, _, closes, _ = _extract_ohlcv(records)
result = compute_ema_cross(closes, short_period, long_period)
indicator_cache.set(cache_key, result)
return result

View File

@@ -0,0 +1,172 @@
"""Ingestion Pipeline service: fetch from provider, validate, upsert into Price Store.
Handles rate-limit resume via IngestionProgress and provider error isolation.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import date, timedelta
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.exceptions import NotFoundError, ProviderError, RateLimitError
from app.models.settings import IngestionProgress
from app.models.ticker import Ticker
from app.providers.protocol import MarketDataProvider
from app.services import price_service
logger = logging.getLogger(__name__)
@dataclass
class IngestionResult:
"""Result of an ingestion run."""
symbol: str
records_ingested: int
last_date: date | None
status: str # "complete" | "partial" | "error"
message: str | None = None
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
"""Look up ticker by symbol. Raises NotFoundError if missing."""
normalised = symbol.strip().upper()
result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
ticker = result.scalar_one_or_none()
if ticker is None:
raise NotFoundError(f"Ticker not found: {normalised}")
return ticker
async def _get_progress(db: AsyncSession, ticker_id: int) -> IngestionProgress | None:
"""Get the IngestionProgress record for a ticker, if any."""
result = await db.execute(
select(IngestionProgress).where(IngestionProgress.ticker_id == ticker_id)
)
return result.scalar_one_or_none()
async def _update_progress(
db: AsyncSession, ticker_id: int, last_date: date
) -> None:
"""Create or update the IngestionProgress record for a ticker."""
progress = await _get_progress(db, ticker_id)
if progress is None:
progress = IngestionProgress(ticker_id=ticker_id, last_ingested_date=last_date)
db.add(progress)
else:
progress.last_ingested_date = last_date
await db.commit()
async def fetch_and_ingest(
db: AsyncSession,
provider: MarketDataProvider,
symbol: str,
start_date: date | None = None,
end_date: date | None = None,
) -> IngestionResult:
"""Fetch OHLCV data from provider and upsert into Price Store.
- Resolves start_date from IngestionProgress if not provided (resume).
- Defaults end_date to today.
- Tracks last_ingested_date after each successful upsert.
- On RateLimitError from provider: returns partial progress.
- On ProviderError: returns error, no data modification.
"""
ticker = await _get_ticker(db, symbol)
# Resolve end_date
if end_date is None:
end_date = date.today()
# Resolve start_date: use progress resume or default to 1 year ago
if start_date is None:
progress = await _get_progress(db, ticker.id)
if progress is not None:
start_date = progress.last_ingested_date + timedelta(days=1)
else:
start_date = end_date - timedelta(days=365)
# If start > end, nothing to fetch
if start_date > end_date:
return IngestionResult(
symbol=ticker.symbol,
records_ingested=0,
last_date=None,
status="complete",
message="Already up to date",
)
# Fetch from provider
try:
records = await provider.fetch_ohlcv(ticker.symbol, start_date, end_date)
except RateLimitError:
# No data fetched at all — return partial with 0 records
return IngestionResult(
symbol=ticker.symbol,
records_ingested=0,
last_date=None,
status="partial",
message="Rate limited before any records fetched. Resume available.",
)
except ProviderError as exc:
logger.error("Provider error for %s: %s", ticker.symbol, exc)
return IngestionResult(
symbol=ticker.symbol,
records_ingested=0,
last_date=None,
status="error",
message=str(exc),
)
# Sort records by date to ensure ordered ingestion
records.sort(key=lambda r: r.date)
ingested_count = 0
last_ingested: date | None = None
for record in records:
try:
await price_service.upsert_ohlcv(
db,
symbol=ticker.symbol,
record_date=record.date,
open_=record.open,
high=record.high,
low=record.low,
close=record.close,
volume=record.volume,
)
ingested_count += 1
last_ingested = record.date
# Update progress after each successful upsert
await _update_progress(db, ticker.id, record.date)
except RateLimitError:
# Mid-ingestion rate limit — return partial progress
logger.warning(
"Rate limited during ingestion for %s after %d records",
ticker.symbol,
ingested_count,
)
return IngestionResult(
symbol=ticker.symbol,
records_ingested=ingested_count,
last_date=last_ingested,
status="partial",
message=f"Rate limited. Ingested {ingested_count} records. Resume available.",
)
return IngestionResult(
symbol=ticker.symbol,
records_ingested=ingested_count,
last_date=last_ingested,
status="complete",
message=f"Successfully ingested {ingested_count} records",
)

View File

@@ -0,0 +1,110 @@
"""Price Store service: upsert and query OHLCV records."""
from datetime import date, datetime
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.exceptions import NotFoundError, ValidationError
from app.models.ohlcv import OHLCVRecord
from app.models.ticker import Ticker
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
"""Look up a ticker by symbol. Raises NotFoundError if missing."""
normalised = symbol.strip().upper()
result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
ticker = result.scalar_one_or_none()
if ticker is None:
raise NotFoundError(f"Ticker not found: {normalised}")
return ticker
def _validate_ohlcv(
high: float, low: float, open_: float, close: float, volume: int, record_date: date
) -> None:
"""Business-rule validation for an OHLCV record."""
if high < low:
raise ValidationError("Validation error: high must be >= low")
if any(p < 0 for p in (open_, high, low, close)):
raise ValidationError("Validation error: prices must be >= 0")
if volume < 0:
raise ValidationError("Validation error: volume must be >= 0")
if record_date > date.today():
raise ValidationError("Validation error: date must not be in the future")
async def upsert_ohlcv(
db: AsyncSession,
symbol: str,
record_date: date,
open_: float,
high: float,
low: float,
close: float,
volume: int,
) -> OHLCVRecord:
"""Insert or update an OHLCV record for (ticker, date).
Validates business rules, resolves ticker, then uses
ON CONFLICT DO UPDATE on the (ticker_id, date) unique constraint.
"""
_validate_ohlcv(high, low, open_, close, volume, record_date)
ticker = await _get_ticker(db, symbol)
stmt = pg_insert(OHLCVRecord).values(
ticker_id=ticker.id,
date=record_date,
open=open_,
high=high,
low=low,
close=close,
volume=volume,
created_at=datetime.utcnow(),
)
stmt = stmt.on_conflict_do_update(
constraint="uq_ohlcv_ticker_date",
set_={
"open": stmt.excluded.open,
"high": stmt.excluded.high,
"low": stmt.excluded.low,
"close": stmt.excluded.close,
"volume": stmt.excluded.volume,
"created_at": stmt.excluded.created_at,
},
)
stmt = stmt.returning(OHLCVRecord)
result = await db.execute(stmt)
await db.commit()
record = result.scalar_one()
# TODO: Invalidate LRU cache entries for this ticker (Task 7.1)
# TODO: Mark composite score as stale for this ticker (Task 10.1)
return record
async def query_ohlcv(
db: AsyncSession,
symbol: str,
start_date: date | None = None,
end_date: date | None = None,
) -> list[OHLCVRecord]:
"""Query OHLCV records for a ticker, optionally filtered by date range.
Returns records sorted by date ascending.
Raises NotFoundError if the ticker does not exist.
"""
ticker = await _get_ticker(db, symbol)
stmt = select(OHLCVRecord).where(OHLCVRecord.ticker_id == ticker.id)
if start_date is not None:
stmt = stmt.where(OHLCVRecord.date >= start_date)
if end_date is not None:
stmt = stmt.where(OHLCVRecord.date <= end_date)
stmt = stmt.order_by(OHLCVRecord.date.asc())
result = await db.execute(stmt)
return list(result.scalars().all())

View File

@@ -0,0 +1,241 @@
"""R:R Scanner service.
Scans tracked tickers for asymmetric risk-reward trade setups.
Long: target = nearest SR above, stop = entry - ATR × multiplier.
Short: target = nearest SR below, stop = entry + ATR × multiplier.
Filters by configurable R:R threshold (default 3:1).
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.exceptions import NotFoundError
from app.models.score import CompositeScore
from app.models.sr_level import SRLevel
from app.models.ticker import Ticker
from app.models.trade_setup import TradeSetup
from app.services.indicator_service import _extract_ohlcv, compute_atr
from app.services.price_service import query_ohlcv
logger = logging.getLogger(__name__)
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
normalised = symbol.strip().upper()
result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
ticker = result.scalar_one_or_none()
if ticker is None:
raise NotFoundError(f"Ticker not found: {normalised}")
return ticker
async def scan_ticker(
db: AsyncSession,
symbol: str,
rr_threshold: float = 3.0,
atr_multiplier: float = 1.5,
) -> list[TradeSetup]:
"""Scan a single ticker for trade setups meeting the R:R threshold.
1. Fetch OHLCV data and compute ATR.
2. Fetch SR levels.
3. Compute long and short setups.
4. Filter by R:R threshold.
5. Delete old setups for this ticker and persist new ones.
Returns list of persisted TradeSetup models.
"""
ticker = await _get_ticker(db, symbol)
# Fetch OHLCV
records = await query_ohlcv(db, symbol)
if not records or len(records) < 15:
logger.info(
"Skipping %s: insufficient OHLCV data (%d bars, need 15+)",
symbol, len(records),
)
# Clear any stale setups
await db.execute(
delete(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
)
return []
_, highs, lows, closes, _ = _extract_ohlcv(records)
entry_price = closes[-1]
# Compute ATR
try:
atr_result = compute_atr(highs, lows, closes)
atr_value = atr_result["atr"]
except Exception:
logger.info("Skipping %s: cannot compute ATR", symbol)
await db.execute(
delete(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
)
return []
if atr_value <= 0:
logger.info("Skipping %s: ATR is zero or negative", symbol)
await db.execute(
delete(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
)
return []
# Fetch SR levels from DB (already computed by sr_service)
sr_result = await db.execute(
select(SRLevel).where(SRLevel.ticker_id == ticker.id)
)
sr_levels = list(sr_result.scalars().all())
if not sr_levels:
logger.info("Skipping %s: no SR levels available", symbol)
await db.execute(
delete(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
)
return []
levels_above = sorted(
[lv for lv in sr_levels if lv.price_level > entry_price],
key=lambda lv: lv.price_level,
)
levels_below = sorted(
[lv for lv in sr_levels if lv.price_level < entry_price],
key=lambda lv: lv.price_level,
reverse=True,
)
# Get composite score for this ticker
comp_result = await db.execute(
select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
)
comp = comp_result.scalar_one_or_none()
composite_score = comp.score if comp else 0.0
now = datetime.now(timezone.utc)
setups: list[TradeSetup] = []
# Long setup: target = nearest SR above, stop = entry - ATR × multiplier
if levels_above:
target = levels_above[0].price_level
stop = entry_price - (atr_value * atr_multiplier)
reward = target - entry_price
risk = entry_price - stop
if risk > 0 and reward > 0:
rr = reward / risk
if rr >= rr_threshold:
setups.append(TradeSetup(
ticker_id=ticker.id,
direction="long",
entry_price=round(entry_price, 4),
stop_loss=round(stop, 4),
target=round(target, 4),
rr_ratio=round(rr, 4),
composite_score=round(composite_score, 4),
detected_at=now,
))
# Short setup: target = nearest SR below, stop = entry + ATR × multiplier
if levels_below:
target = levels_below[0].price_level
stop = entry_price + (atr_value * atr_multiplier)
reward = entry_price - target
risk = stop - entry_price
if risk > 0 and reward > 0:
rr = reward / risk
if rr >= rr_threshold:
setups.append(TradeSetup(
ticker_id=ticker.id,
direction="short",
entry_price=round(entry_price, 4),
stop_loss=round(stop, 4),
target=round(target, 4),
rr_ratio=round(rr, 4),
composite_score=round(composite_score, 4),
detected_at=now,
))
# Delete old setups for this ticker, persist new ones
await db.execute(
delete(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
)
for setup in setups:
db.add(setup)
await db.commit()
# Refresh to get IDs
for s in setups:
await db.refresh(s)
return setups
async def scan_all_tickers(
db: AsyncSession,
rr_threshold: float = 3.0,
atr_multiplier: float = 1.5,
) -> list[TradeSetup]:
"""Scan all tracked tickers for trade setups.
Processes each ticker independently — one failure doesn't stop others.
Returns all setups found across all tickers.
"""
result = await db.execute(select(Ticker).order_by(Ticker.symbol))
tickers = list(result.scalars().all())
all_setups: list[TradeSetup] = []
for ticker in tickers:
try:
setups = await scan_ticker(
db, ticker.symbol, rr_threshold, atr_multiplier
)
all_setups.extend(setups)
except Exception:
logger.exception("Error scanning ticker %s", ticker.symbol)
return all_setups
async def get_trade_setups(
db: AsyncSession,
direction: str | None = None,
) -> list[dict]:
"""Get all stored trade setups, optionally filtered by direction.
Returns dicts sorted by R:R desc, secondary composite desc.
Each dict includes the ticker symbol.
"""
stmt = (
select(TradeSetup, Ticker.symbol)
.join(Ticker, TradeSetup.ticker_id == Ticker.id)
)
if direction is not None:
stmt = stmt.where(TradeSetup.direction == direction.lower())
stmt = stmt.order_by(
TradeSetup.rr_ratio.desc(),
TradeSetup.composite_score.desc(),
)
result = await db.execute(stmt)
rows = result.all()
return [
{
"id": setup.id,
"symbol": symbol,
"direction": setup.direction,
"entry_price": setup.entry_price,
"stop_loss": setup.stop_loss,
"target": setup.target,
"rr_ratio": setup.rr_ratio,
"composite_score": setup.composite_score,
"detected_at": setup.detected_at,
}
for setup, symbol in rows
]

View File

@@ -0,0 +1,584 @@
"""Scoring Engine service.
Computes dimension scores (technical, sr_quality, sentiment, fundamental,
momentum) each 0-100, composite score as weighted average of available
dimensions with re-normalized weights, staleness marking/recomputation
on demand, and weight update triggers full recomputation.
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.exceptions import NotFoundError, ValidationError
from app.models.score import CompositeScore, DimensionScore
from app.models.settings import SystemSetting
from app.models.ticker import Ticker
logger = logging.getLogger(__name__)
DIMENSIONS = ["technical", "sr_quality", "sentiment", "fundamental", "momentum"]
DEFAULT_WEIGHTS: dict[str, float] = {
"technical": 0.25,
"sr_quality": 0.20,
"sentiment": 0.15,
"fundamental": 0.20,
"momentum": 0.20,
}
SCORING_WEIGHTS_KEY = "scoring_weights"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
normalised = symbol.strip().upper()
result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
ticker = result.scalar_one_or_none()
if ticker is None:
raise NotFoundError(f"Ticker not found: {normalised}")
return ticker
async def _get_weights(db: AsyncSession) -> dict[str, float]:
"""Load scoring weights from SystemSetting, falling back to defaults."""
result = await db.execute(
select(SystemSetting).where(SystemSetting.key == SCORING_WEIGHTS_KEY)
)
setting = result.scalar_one_or_none()
if setting is not None:
try:
return json.loads(setting.value)
except (json.JSONDecodeError, TypeError):
logger.warning("Invalid scoring weights in DB, using defaults")
return dict(DEFAULT_WEIGHTS)
async def _save_weights(db: AsyncSession, weights: dict[str, float]) -> None:
"""Persist scoring weights to SystemSetting."""
result = await db.execute(
select(SystemSetting).where(SystemSetting.key == SCORING_WEIGHTS_KEY)
)
setting = result.scalar_one_or_none()
now = datetime.now(timezone.utc)
if setting is not None:
setting.value = json.dumps(weights)
setting.updated_at = now
else:
setting = SystemSetting(
key=SCORING_WEIGHTS_KEY,
value=json.dumps(weights),
updated_at=now,
)
db.add(setting)
# ---------------------------------------------------------------------------
# Dimension score computation
# ---------------------------------------------------------------------------
async def _compute_technical_score(db: AsyncSession, symbol: str) -> float | None:
"""Compute technical dimension score from ADX, EMA, RSI."""
from app.services.indicator_service import (
compute_adx,
compute_ema,
compute_rsi,
_extract_ohlcv,
)
from app.services.price_service import query_ohlcv
records = await query_ohlcv(db, symbol)
if not records:
return None
_, highs, lows, closes, _ = _extract_ohlcv(records)
scores: list[tuple[float, float]] = [] # (weight, score)
# ADX (weight 0.4) — needs 28+ bars
try:
adx_result = compute_adx(highs, lows, closes)
scores.append((0.4, adx_result["score"]))
except Exception:
pass
# EMA (weight 0.3) — needs period+1 bars
try:
ema_result = compute_ema(closes)
scores.append((0.3, ema_result["score"]))
except Exception:
pass
# RSI (weight 0.3) — needs 15+ bars
try:
rsi_result = compute_rsi(closes)
scores.append((0.3, rsi_result["score"]))
except Exception:
pass
if not scores:
return None
total_weight = sum(w for w, _ in scores)
if total_weight == 0:
return None
weighted = sum(w * s for w, s in scores) / total_weight
return max(0.0, min(100.0, weighted))
async def _compute_sr_quality_score(db: AsyncSession, symbol: str) -> float | None:
"""Compute S/R quality dimension score.
Based on number of strong levels, proximity to current price, avg strength.
"""
from app.services.price_service import query_ohlcv
from app.services.sr_service import get_sr_levels
records = await query_ohlcv(db, symbol)
if not records:
return None
current_price = float(records[-1].close)
if current_price <= 0:
return None
try:
levels = await get_sr_levels(db, symbol)
except Exception:
return None
if not levels:
return None
# Factor 1: Number of strong levels (strength >= 50) — max 40 pts
strong_count = sum(1 for lv in levels if lv.strength >= 50)
count_score = min(40.0, strong_count * 10.0)
# Factor 2: Proximity of nearest level to current price — max 30 pts
distances = [
abs(lv.price_level - current_price) / current_price for lv in levels
]
nearest_dist = min(distances) if distances else 1.0
# Closer = higher score. 0% distance = 30, 5%+ = 0
proximity_score = max(0.0, min(30.0, 30.0 * (1.0 - nearest_dist / 0.05)))
# Factor 3: Average strength — max 30 pts
avg_strength = sum(lv.strength for lv in levels) / len(levels)
strength_score = min(30.0, avg_strength * 0.3)
total = count_score + proximity_score + strength_score
return max(0.0, min(100.0, total))
async def _compute_sentiment_score(db: AsyncSession, symbol: str) -> float | None:
"""Compute sentiment dimension score via sentiment service."""
from app.services.sentiment_service import compute_sentiment_dimension_score
try:
return await compute_sentiment_dimension_score(db, symbol)
except Exception:
return None
async def _compute_fundamental_score(db: AsyncSession, symbol: str) -> float | None:
"""Compute fundamental dimension score.
Normalized composite of P/E (lower is better), revenue growth
(higher is better), earnings surprise (higher is better).
"""
from app.services.fundamental_service import get_fundamental
fund = await get_fundamental(db, symbol)
if fund is None:
return None
scores: list[float] = []
# P/E: lower is better. 0-15 = 100, 15-30 = 50-100, 30+ = 0-50
if fund.pe_ratio is not None and fund.pe_ratio > 0:
pe_score = max(0.0, min(100.0, 100.0 - (fund.pe_ratio - 15.0) * (100.0 / 30.0)))
scores.append(pe_score)
# Revenue growth: higher is better. 0% = 50, 20%+ = 100, -20% = 0
if fund.revenue_growth is not None:
rg_score = max(0.0, min(100.0, 50.0 + fund.revenue_growth * 2.5))
scores.append(rg_score)
# Earnings surprise: higher is better. 0% = 50, 10%+ = 100, -10% = 0
if fund.earnings_surprise is not None:
es_score = max(0.0, min(100.0, 50.0 + fund.earnings_surprise * 5.0))
scores.append(es_score)
if not scores:
return None
return sum(scores) / len(scores)
async def _compute_momentum_score(db: AsyncSession, symbol: str) -> float | None:
"""Compute momentum dimension score.
Rate of change of price over 5-day and 20-day lookback periods.
"""
from app.services.price_service import query_ohlcv
records = await query_ohlcv(db, symbol)
if not records or len(records) < 6:
return None
closes = [float(r.close) for r in records]
latest = closes[-1]
scores: list[tuple[float, float]] = [] # (weight, score)
# 5-day ROC (weight 0.5)
if len(closes) >= 6 and closes[-6] > 0:
roc_5 = (latest - closes[-6]) / closes[-6] * 100.0
# Map: -10% → 0, 0% → 50, +10% → 100
score_5 = max(0.0, min(100.0, 50.0 + roc_5 * 5.0))
scores.append((0.5, score_5))
# 20-day ROC (weight 0.5)
if len(closes) >= 21 and closes[-21] > 0:
roc_20 = (latest - closes[-21]) / closes[-21] * 100.0
score_20 = max(0.0, min(100.0, 50.0 + roc_20 * 5.0))
scores.append((0.5, score_20))
if not scores:
return None
total_weight = sum(w for w, _ in scores)
if total_weight == 0:
return None
weighted = sum(w * s for w, s in scores) / total_weight
return max(0.0, min(100.0, weighted))
_DIMENSION_COMPUTERS = {
"technical": _compute_technical_score,
"sr_quality": _compute_sr_quality_score,
"sentiment": _compute_sentiment_score,
"fundamental": _compute_fundamental_score,
"momentum": _compute_momentum_score,
}
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def compute_dimension_score(
db: AsyncSession, symbol: str, dimension: str
) -> float | None:
"""Compute a single dimension score for a ticker.
Returns the score (0-100) or None if insufficient data.
Persists the result to the DimensionScore table.
"""
if dimension not in _DIMENSION_COMPUTERS:
raise ValidationError(
f"Unknown dimension: {dimension}. Valid: {', '.join(DIMENSIONS)}"
)
ticker = await _get_ticker(db, symbol)
score_val = await _DIMENSION_COMPUTERS[dimension](db, symbol)
now = datetime.now(timezone.utc)
# Upsert dimension score
result = await db.execute(
select(DimensionScore).where(
DimensionScore.ticker_id == ticker.id,
DimensionScore.dimension == dimension,
)
)
existing = result.scalar_one_or_none()
if score_val is not None:
score_val = max(0.0, min(100.0, score_val))
if existing is not None:
if score_val is not None:
existing.score = score_val
existing.is_stale = False
existing.computed_at = now
else:
# Can't compute — mark stale
existing.is_stale = True
elif score_val is not None:
dim = DimensionScore(
ticker_id=ticker.id,
dimension=dimension,
score=score_val,
is_stale=False,
computed_at=now,
)
db.add(dim)
return score_val
async def compute_all_dimensions(
db: AsyncSession, symbol: str
) -> dict[str, float | None]:
"""Compute all dimension scores for a ticker. Returns dimension → score map."""
results: dict[str, float | None] = {}
for dim in DIMENSIONS:
results[dim] = await compute_dimension_score(db, symbol, dim)
return results
async def compute_composite_score(
db: AsyncSession,
symbol: str,
weights: dict[str, float] | None = None,
) -> tuple[float | None, list[str]]:
"""Compute composite score from available dimension scores.
Returns (composite_score, missing_dimensions).
Missing dimensions are excluded and weights re-normalized.
"""
ticker = await _get_ticker(db, symbol)
if weights is None:
weights = await _get_weights(db)
# Get current dimension scores
result = await db.execute(
select(DimensionScore).where(DimensionScore.ticker_id == ticker.id)
)
dim_scores = {ds.dimension: ds for ds in result.scalars().all()}
available: list[tuple[str, float, float]] = [] # (dim, weight, score)
missing: list[str] = []
for dim in DIMENSIONS:
w = weights.get(dim, 0.0)
if w <= 0:
continue
ds = dim_scores.get(dim)
if ds is not None and not ds.is_stale and ds.score is not None:
available.append((dim, w, ds.score))
else:
missing.append(dim)
if not available:
return None, missing
# Re-normalize weights
total_weight = sum(w for _, w, _ in available)
if total_weight == 0:
return None, missing
composite = sum(w * s for _, w, s in available) / total_weight
composite = max(0.0, min(100.0, composite))
# Persist composite score
now = datetime.now(timezone.utc)
comp_result = await db.execute(
select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
)
existing = comp_result.scalar_one_or_none()
if existing is not None:
existing.score = composite
existing.is_stale = False
existing.weights_json = json.dumps(weights)
existing.computed_at = now
else:
comp = CompositeScore(
ticker_id=ticker.id,
score=composite,
is_stale=False,
weights_json=json.dumps(weights),
computed_at=now,
)
db.add(comp)
return composite, missing
async def get_score(
db: AsyncSession, symbol: str
) -> dict:
"""Get composite + all dimension scores for a ticker.
Recomputes stale dimensions on demand, then recomputes composite.
Returns a dict suitable for ScoreResponse.
"""
ticker = await _get_ticker(db, symbol)
weights = await _get_weights(db)
# Check for stale dimension scores and recompute them
result = await db.execute(
select(DimensionScore).where(DimensionScore.ticker_id == ticker.id)
)
dim_scores = {ds.dimension: ds for ds in result.scalars().all()}
for dim in DIMENSIONS:
ds = dim_scores.get(dim)
if ds is None or ds.is_stale:
await compute_dimension_score(db, symbol, dim)
# Check composite staleness
comp_result = await db.execute(
select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
)
comp = comp_result.scalar_one_or_none()
if comp is None or comp.is_stale:
await compute_composite_score(db, symbol, weights)
await db.commit()
# Re-fetch everything fresh
result = await db.execute(
select(DimensionScore).where(DimensionScore.ticker_id == ticker.id)
)
dim_scores_list = list(result.scalars().all())
comp_result = await db.execute(
select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
)
comp = comp_result.scalar_one_or_none()
dimensions = []
missing = []
for dim in DIMENSIONS:
found = next((ds for ds in dim_scores_list if ds.dimension == dim), None)
if found is not None:
dimensions.append({
"dimension": found.dimension,
"score": found.score,
"is_stale": found.is_stale,
"computed_at": found.computed_at,
})
else:
missing.append(dim)
return {
"symbol": ticker.symbol,
"composite_score": comp.score if comp else None,
"composite_stale": comp.is_stale if comp else False,
"weights": weights,
"dimensions": dimensions,
"missing_dimensions": missing,
"computed_at": comp.computed_at if comp else None,
}
async def get_rankings(db: AsyncSession) -> dict:
"""Get all tickers ranked by composite score descending.
Returns dict suitable for RankingResponse.
"""
weights = await _get_weights(db)
# Get all tickers
result = await db.execute(select(Ticker).order_by(Ticker.symbol))
tickers = list(result.scalars().all())
rankings: list[dict] = []
for ticker in tickers:
# Get composite score
comp_result = await db.execute(
select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
)
comp = comp_result.scalar_one_or_none()
# If no composite or stale, recompute
if comp is None or comp.is_stale:
# Recompute stale dimensions first
dim_result = await db.execute(
select(DimensionScore).where(
DimensionScore.ticker_id == ticker.id
)
)
dim_scores = {ds.dimension: ds for ds in dim_result.scalars().all()}
for dim in DIMENSIONS:
ds = dim_scores.get(dim)
if ds is None or ds.is_stale:
await compute_dimension_score(db, ticker.symbol, dim)
await compute_composite_score(db, ticker.symbol, weights)
await db.commit()
# Re-fetch
comp_result = await db.execute(
select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
)
comp = comp_result.scalar_one_or_none()
if comp is None:
continue
dim_result = await db.execute(
select(DimensionScore).where(
DimensionScore.ticker_id == ticker.id
)
)
dims = [
{
"dimension": ds.dimension,
"score": ds.score,
"is_stale": ds.is_stale,
"computed_at": ds.computed_at,
}
for ds in dim_result.scalars().all()
]
rankings.append({
"symbol": ticker.symbol,
"composite_score": comp.score,
"dimensions": dims,
})
# Sort by composite score descending
rankings.sort(key=lambda r: r["composite_score"], reverse=True)
return {
"rankings": rankings,
"weights": weights,
}
async def update_weights(
db: AsyncSession, weights: dict[str, float]
) -> dict[str, float]:
"""Update scoring weights and recompute all composite scores.
Validates that all weights are positive and dimensions are valid.
Returns the new weights.
"""
# Validate
for dim, w in weights.items():
if dim not in DIMENSIONS:
raise ValidationError(
f"Unknown dimension: {dim}. Valid: {', '.join(DIMENSIONS)}"
)
if w < 0:
raise ValidationError(f"Weight for {dim} must be non-negative, got {w}")
# Ensure all dimensions have a weight (default 0 for unspecified)
full_weights = {dim: weights.get(dim, 0.0) for dim in DIMENSIONS}
# Persist
await _save_weights(db, full_weights)
# Recompute all composite scores
result = await db.execute(select(Ticker))
tickers = list(result.scalars().all())
for ticker in tickers:
await compute_composite_score(db, ticker.symbol, full_weights)
await db.commit()
return full_weights

View File

@@ -0,0 +1,131 @@
"""Sentiment service.
Stores sentiment records and computes the sentiment dimension score
using a time-decay weighted average over a configurable lookback window.
"""
from __future__ import annotations
import math
from datetime import datetime, timedelta, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.exceptions import NotFoundError
from app.models.sentiment import SentimentScore
from app.models.ticker import Ticker
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
"""Look up a ticker by symbol."""
normalised = symbol.strip().upper()
result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
ticker = result.scalar_one_or_none()
if ticker is None:
raise NotFoundError(f"Ticker not found: {normalised}")
return ticker
async def store_sentiment(
db: AsyncSession,
symbol: str,
classification: str,
confidence: int,
source: str,
timestamp: datetime | None = None,
) -> SentimentScore:
"""Store a new sentiment record for a ticker."""
ticker = await _get_ticker(db, symbol)
if timestamp is None:
timestamp = datetime.now(timezone.utc)
record = SentimentScore(
ticker_id=ticker.id,
classification=classification,
confidence=confidence,
source=source,
timestamp=timestamp,
)
db.add(record)
await db.commit()
await db.refresh(record)
return record
async def get_sentiment_scores(
db: AsyncSession,
symbol: str,
lookback_hours: float = 24,
) -> list[SentimentScore]:
"""Get recent sentiment records within the lookback window."""
ticker = await _get_ticker(db, symbol)
cutoff = datetime.now(timezone.utc) - timedelta(hours=lookback_hours)
result = await db.execute(
select(SentimentScore)
.where(
SentimentScore.ticker_id == ticker.id,
SentimentScore.timestamp >= cutoff,
)
.order_by(SentimentScore.timestamp.desc())
)
return list(result.scalars().all())
def _classification_to_base_score(classification: str, confidence: int) -> float:
"""Map classification + confidence to a base score (0-100).
bullish → confidence (high confidence = high score)
bearish → 100 - confidence (high confidence bearish = low score)
neutral → 50
"""
cl = classification.lower()
if cl == "bullish":
return float(confidence)
elif cl == "bearish":
return float(100 - confidence)
else:
return 50.0
async def compute_sentiment_dimension_score(
db: AsyncSession,
symbol: str,
lookback_hours: float = 24,
decay_rate: float = 0.1,
) -> float | None:
"""Compute the sentiment dimension score using time-decay weighted average.
Returns a score in [0, 100] or None if no scores exist in the window.
Algorithm:
1. For each score in the lookback window, compute base_score from
classification + confidence.
2. Apply time decay: weight = exp(-decay_rate * hours_since_score).
3. Weighted average: sum(base_score * weight) / sum(weight).
"""
scores = await get_sentiment_scores(db, symbol, lookback_hours)
if not scores:
return None
now = datetime.now(timezone.utc)
weighted_sum = 0.0
weight_total = 0.0
for score in scores:
ts = score.timestamp
if ts.tzinfo is None:
ts = ts.replace(tzinfo=timezone.utc)
hours_since = (now - ts).total_seconds() / 3600.0
weight = math.exp(-decay_rate * hours_since)
base = _classification_to_base_score(score.classification, score.confidence)
weighted_sum += base * weight
weight_total += weight
if weight_total == 0:
return None
result = weighted_sum / weight_total
return max(0.0, min(100.0, result))

274
app/services/sr_service.py Normal file
View File

@@ -0,0 +1,274 @@
"""S/R Detector service.
Detects support/resistance levels from Volume Profile (HVN/LVN) and
Pivot Points (swing highs/lows), assigns strength scores, merges nearby
levels, tags as support/resistance, and persists to DB.
"""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.exceptions import NotFoundError, ValidationError
from app.models.sr_level import SRLevel
from app.models.ticker import Ticker
from app.services.indicator_service import (
_extract_ohlcv,
compute_pivot_points,
compute_volume_profile,
)
from app.services.price_service import query_ohlcv
DEFAULT_TOLERANCE = 0.005 # 0.5%
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
"""Look up a ticker by symbol."""
normalised = symbol.strip().upper()
result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
ticker = result.scalar_one_or_none()
if ticker is None:
raise NotFoundError(f"Ticker not found: {normalised}")
return ticker
def _count_price_touches(
price_level: float,
highs: list[float],
lows: list[float],
closes: list[float],
tolerance: float = DEFAULT_TOLERANCE,
) -> int:
"""Count how many bars touched/respected a price level within tolerance."""
count = 0
tol = price_level * tolerance if price_level != 0 else tolerance
for i in range(len(closes)):
# A bar "touches" the level if the level is within the bar's range
# (within tolerance)
if lows[i] - tol <= price_level <= highs[i] + tol:
count += 1
return count
def _strength_from_touches(touches: int, total_bars: int) -> int:
"""Convert touch count to a 0-100 strength score.
More touches relative to total bars = higher strength.
Cap at 100.
"""
if total_bars == 0:
return 0
# Scale: each touch contributes proportionally, with a multiplier
# so that a level touched ~20% of bars gets score ~100
raw = (touches / total_bars) * 500.0
return max(0, min(100, int(round(raw))))
def _extract_candidate_levels(
highs: list[float],
lows: list[float],
closes: list[float],
volumes: list[int],
) -> list[tuple[float, str]]:
"""Extract candidate S/R levels from Volume Profile and Pivot Points.
Returns list of (price_level, detection_method) tuples.
"""
candidates: list[tuple[float, str]] = []
# Volume Profile: HVN and LVN as candidate levels
try:
vp = compute_volume_profile(highs, lows, closes, volumes)
for price in vp.get("hvn", []):
candidates.append((price, "volume_profile"))
for price in vp.get("lvn", []):
candidates.append((price, "volume_profile"))
except ValidationError:
pass # Not enough data for volume profile
# Pivot Points: swing highs and lows
try:
pp = compute_pivot_points(highs, lows, closes)
for price in pp.get("swing_highs", []):
candidates.append((price, "pivot_point"))
for price in pp.get("swing_lows", []):
candidates.append((price, "pivot_point"))
except ValidationError:
pass # Not enough data for pivot points
return candidates
def _merge_levels(
levels: list[dict],
tolerance: float = DEFAULT_TOLERANCE,
) -> list[dict]:
"""Merge levels within tolerance into consolidated levels.
Levels from different methods within tolerance are merged.
Merged levels combine strength scores (capped at 100) and get
detection_method = "merged".
"""
if not levels:
return []
# Sort by price
sorted_levels = sorted(levels, key=lambda x: x["price_level"])
merged: list[dict] = []
for level in sorted_levels:
if not merged:
merged.append(dict(level))
continue
last = merged[-1]
ref_price = last["price_level"]
tol = ref_price * tolerance if ref_price != 0 else tolerance
if abs(level["price_level"] - ref_price) <= tol:
# Merge: average price, combine strength, mark as merged
combined_strength = min(100, last["strength"] + level["strength"])
avg_price = (last["price_level"] + level["price_level"]) / 2.0
method = (
"merged"
if last["detection_method"] != level["detection_method"]
else last["detection_method"]
)
last["price_level"] = round(avg_price, 4)
last["strength"] = combined_strength
last["detection_method"] = method
else:
merged.append(dict(level))
return merged
def _tag_levels(
levels: list[dict],
current_price: float,
) -> list[dict]:
"""Tag each level as 'support' or 'resistance' relative to current price."""
for level in levels:
if level["price_level"] < current_price:
level["type"] = "support"
else:
level["type"] = "resistance"
return levels
def detect_sr_levels(
highs: list[float],
lows: list[float],
closes: list[float],
volumes: list[int],
tolerance: float = DEFAULT_TOLERANCE,
) -> list[dict]:
"""Detect, score, merge, and tag S/R levels from OHLCV data.
Returns list of dicts with keys: price_level, type, strength,
detection_method — sorted by strength descending.
"""
if not closes:
return []
candidates = _extract_candidate_levels(highs, lows, closes, volumes)
if not candidates:
return []
total_bars = len(closes)
current_price = closes[-1]
# Build level dicts with strength scores
raw_levels: list[dict] = []
for price, method in candidates:
touches = _count_price_touches(price, highs, lows, closes, tolerance)
strength = _strength_from_touches(touches, total_bars)
raw_levels.append({
"price_level": price,
"strength": strength,
"detection_method": method,
"type": "", # will be tagged after merge
})
# Merge nearby levels
merged = _merge_levels(raw_levels, tolerance)
# Tag as support/resistance
tagged = _tag_levels(merged, current_price)
# Sort by strength descending
tagged.sort(key=lambda x: x["strength"], reverse=True)
return tagged
async def recalculate_sr_levels(
db: AsyncSession,
symbol: str,
tolerance: float = DEFAULT_TOLERANCE,
) -> list[SRLevel]:
"""Recalculate S/R levels for a ticker and persist to DB.
1. Fetch OHLCV data
2. Detect levels
3. Delete old levels for ticker
4. Insert new levels
5. Return new levels sorted by strength desc
"""
ticker = await _get_ticker(db, symbol)
records = await query_ohlcv(db, symbol)
if not records:
# No OHLCV data — clear any existing levels
await db.execute(
delete(SRLevel).where(SRLevel.ticker_id == ticker.id)
)
await db.commit()
return []
_, highs, lows, closes, volumes = _extract_ohlcv(records)
levels = detect_sr_levels(highs, lows, closes, volumes, tolerance)
# Delete old levels
await db.execute(
delete(SRLevel).where(SRLevel.ticker_id == ticker.id)
)
# Insert new levels
now = datetime.utcnow()
new_models: list[SRLevel] = []
for lvl in levels:
model = SRLevel(
ticker_id=ticker.id,
price_level=lvl["price_level"],
type=lvl["type"],
strength=lvl["strength"],
detection_method=lvl["detection_method"],
created_at=now,
)
db.add(model)
new_models.append(model)
await db.commit()
# Refresh to get IDs
for m in new_models:
await db.refresh(m)
return new_models
async def get_sr_levels(
db: AsyncSession,
symbol: str,
tolerance: float = DEFAULT_TOLERANCE,
) -> list[SRLevel]:
"""Get S/R levels for a ticker, recalculating on every request (MVP).
Returns levels sorted by strength descending.
"""
return await recalculate_sr_levels(db, symbol, tolerance)

View File

@@ -0,0 +1,57 @@
"""Ticker Registry service: add, delete, and list tracked tickers."""
import re
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.exceptions import DuplicateError, NotFoundError, ValidationError
from app.models.ticker import Ticker
async def add_ticker(db: AsyncSession, symbol: str) -> Ticker:
"""Add a new ticker after validation.
Validates: non-empty, uppercase alphanumeric. Auto-uppercases input.
Raises DuplicateError if symbol already tracked.
"""
stripped = symbol.strip()
if not stripped:
raise ValidationError("Ticker symbol must not be empty or whitespace-only")
normalised = stripped.upper()
if not re.fullmatch(r"[A-Z0-9]+", normalised):
raise ValidationError(
f"Ticker symbol must be alphanumeric: {normalised}"
)
result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
if result.scalar_one_or_none() is not None:
raise DuplicateError(f"Ticker already exists: {normalised}")
ticker = Ticker(symbol=normalised)
db.add(ticker)
await db.commit()
await db.refresh(ticker)
return ticker
async def delete_ticker(db: AsyncSession, symbol: str) -> None:
"""Delete a ticker and cascade all associated data.
Raises NotFoundError if the symbol is not tracked.
"""
normalised = symbol.strip().upper()
result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
ticker = result.scalar_one_or_none()
if ticker is None:
raise NotFoundError(f"Ticker not found: {normalised}")
await db.delete(ticker)
await db.commit()
async def list_tickers(db: AsyncSession) -> list[Ticker]:
"""Return all tracked tickers sorted alphabetically by symbol."""
result = await db.execute(select(Ticker).order_by(Ticker.symbol.asc()))
return list(result.scalars().all())

View File

@@ -0,0 +1,288 @@
"""Watchlist service.
Auto-populates top-X tickers by composite score (default 10), supports
manual add/remove (tagged, not subject to auto-population), enforces
cap (auto + 10 manual, default max 20), and updates auto entries on
score recomputation.
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.exceptions import DuplicateError, NotFoundError, ValidationError
from app.models.score import CompositeScore, DimensionScore
from app.models.sr_level import SRLevel
from app.models.ticker import Ticker
from app.models.trade_setup import TradeSetup
from app.models.watchlist import WatchlistEntry
logger = logging.getLogger(__name__)
DEFAULT_AUTO_SIZE = 10
MAX_MANUAL = 10
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
normalised = symbol.strip().upper()
result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
ticker = result.scalar_one_or_none()
if ticker is None:
raise NotFoundError(f"Ticker not found: {normalised}")
return ticker
async def auto_populate(
db: AsyncSession,
user_id: int,
top_x: int = DEFAULT_AUTO_SIZE,
) -> None:
"""Auto-populate watchlist with top-X tickers by composite score.
Replaces existing auto entries. Manual entries are untouched.
"""
# Get top-X tickers by composite score (non-stale, descending)
stmt = (
select(CompositeScore)
.where(CompositeScore.is_stale == False) # noqa: E712
.order_by(CompositeScore.score.desc())
.limit(top_x)
)
result = await db.execute(stmt)
top_scores = list(result.scalars().all())
top_ticker_ids = {cs.ticker_id for cs in top_scores}
# Delete existing auto entries for this user
await db.execute(
delete(WatchlistEntry).where(
WatchlistEntry.user_id == user_id,
WatchlistEntry.entry_type == "auto",
)
)
# Get manual ticker_ids so we don't duplicate
manual_result = await db.execute(
select(WatchlistEntry.ticker_id).where(
WatchlistEntry.user_id == user_id,
WatchlistEntry.entry_type == "manual",
)
)
manual_ticker_ids = {row[0] for row in manual_result.all()}
now = datetime.now(timezone.utc)
for ticker_id in top_ticker_ids:
if ticker_id in manual_ticker_ids:
continue # Already on watchlist as manual
entry = WatchlistEntry(
user_id=user_id,
ticker_id=ticker_id,
entry_type="auto",
added_at=now,
)
db.add(entry)
await db.flush()
async def add_manual_entry(
db: AsyncSession,
user_id: int,
symbol: str,
) -> WatchlistEntry:
"""Add a manual watchlist entry.
Raises DuplicateError if already on watchlist.
Raises ValidationError if manual cap exceeded.
"""
ticker = await _get_ticker(db, symbol)
# Check if already on watchlist
existing = await db.execute(
select(WatchlistEntry).where(
WatchlistEntry.user_id == user_id,
WatchlistEntry.ticker_id == ticker.id,
)
)
if existing.scalar_one_or_none() is not None:
raise DuplicateError(f"Ticker already on watchlist: {ticker.symbol}")
# Count current manual entries
count_result = await db.execute(
select(func.count()).select_from(WatchlistEntry).where(
WatchlistEntry.user_id == user_id,
WatchlistEntry.entry_type == "manual",
)
)
manual_count = count_result.scalar() or 0
if manual_count >= MAX_MANUAL:
raise ValidationError(
f"Manual watchlist cap reached ({MAX_MANUAL}). "
"Remove an entry before adding a new one."
)
# Check total cap
total_result = await db.execute(
select(func.count()).select_from(WatchlistEntry).where(
WatchlistEntry.user_id == user_id,
)
)
total_count = total_result.scalar() or 0
max_total = DEFAULT_AUTO_SIZE + MAX_MANUAL
if total_count >= max_total:
raise ValidationError(
f"Watchlist cap reached ({max_total}). "
"Remove an entry before adding a new one."
)
entry = WatchlistEntry(
user_id=user_id,
ticker_id=ticker.id,
entry_type="manual",
added_at=datetime.now(timezone.utc),
)
db.add(entry)
await db.commit()
await db.refresh(entry)
return entry
async def remove_entry(
db: AsyncSession,
user_id: int,
symbol: str,
) -> None:
"""Remove a watchlist entry (manual or auto)."""
ticker = await _get_ticker(db, symbol)
result = await db.execute(
select(WatchlistEntry).where(
WatchlistEntry.user_id == user_id,
WatchlistEntry.ticker_id == ticker.id,
)
)
entry = result.scalar_one_or_none()
if entry is None:
raise NotFoundError(f"Ticker not on watchlist: {ticker.symbol}")
await db.delete(entry)
await db.commit()
async def _enrich_entry(
db: AsyncSession,
entry: WatchlistEntry,
symbol: str,
) -> dict:
"""Build enriched watchlist entry dict with scores, R:R, and SR levels."""
ticker_id = entry.ticker_id
# Composite score
comp_result = await db.execute(
select(CompositeScore).where(CompositeScore.ticker_id == ticker_id)
)
comp = comp_result.scalar_one_or_none()
# Dimension scores
dim_result = await db.execute(
select(DimensionScore).where(DimensionScore.ticker_id == ticker_id)
)
dims = [
{"dimension": ds.dimension, "score": ds.score}
for ds in dim_result.scalars().all()
]
# Best trade setup (highest R:R) for this ticker
setup_result = await db.execute(
select(TradeSetup)
.where(TradeSetup.ticker_id == ticker_id)
.order_by(TradeSetup.rr_ratio.desc())
.limit(1)
)
setup = setup_result.scalar_one_or_none()
# Active SR levels
sr_result = await db.execute(
select(SRLevel)
.where(SRLevel.ticker_id == ticker_id)
.order_by(SRLevel.strength.desc())
)
sr_levels = [
{
"price_level": lv.price_level,
"type": lv.type,
"strength": lv.strength,
}
for lv in sr_result.scalars().all()
]
return {
"symbol": symbol,
"entry_type": entry.entry_type,
"composite_score": comp.score if comp else None,
"dimensions": dims,
"rr_ratio": setup.rr_ratio if setup else None,
"rr_direction": setup.direction if setup else None,
"sr_levels": sr_levels,
"added_at": entry.added_at,
}
async def get_watchlist(
db: AsyncSession,
user_id: int,
sort_by: str = "composite",
) -> list[dict]:
"""Get user's watchlist with enriched data.
Runs auto_populate first to ensure auto entries are current,
then enriches each entry with scores, R:R, and SR levels.
sort_by: "composite", "rr", or a dimension name
(e.g. "technical", "sr_quality", "sentiment", "fundamental", "momentum").
"""
# Auto-populate to refresh auto entries
await auto_populate(db, user_id)
await db.commit()
# Fetch all entries with ticker symbol
stmt = (
select(WatchlistEntry, Ticker.symbol)
.join(Ticker, WatchlistEntry.ticker_id == Ticker.id)
.where(WatchlistEntry.user_id == user_id)
)
result = await db.execute(stmt)
rows = result.all()
entries: list[dict] = []
for entry, symbol in rows:
enriched = await _enrich_entry(db, entry, symbol)
entries.append(enriched)
# Sort
if sort_by == "composite":
entries.sort(
key=lambda e: e["composite_score"] if e["composite_score"] is not None else -1,
reverse=True,
)
elif sort_by == "rr":
entries.sort(
key=lambda e: e["rr_ratio"] if e["rr_ratio"] is not None else -1,
reverse=True,
)
else:
# Sort by a specific dimension score
def _dim_sort_key(e: dict) -> float:
for d in e["dimensions"]:
if d["dimension"] == sort_by:
return d["score"]
return -1.0
entries.sort(key=_dim_sort_key, reverse=True)
return entries