first commit
This commit is contained in:
1
app/services/__init__.py
Normal file
1
app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
238
app/services/admin_service.py
Normal file
238
app/services/admin_service.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Admin service: user management, system settings, data cleanup, job control."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from passlib.hash import bcrypt
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import DuplicateError, NotFoundError, ValidationError
|
||||
from app.models.fundamental import FundamentalData
|
||||
from app.models.ohlcv import OHLCVRecord
|
||||
from app.models.sentiment import SentimentScore
|
||||
from app.models.settings import SystemSetting
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def list_users(db: AsyncSession) -> list[User]:
    """Fetch every user account, sorted by ascending id."""
    rows = await db.execute(select(User).order_by(User.id))
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
async def create_user(
    db: AsyncSession,
    username: str,
    password: str,
    role: str = "user",
    has_access: bool = False,
) -> User:
    """Create a new user account (admin action).

    Raises DuplicateError when the username is already taken.
    """
    # Reject duplicate usernames up front.
    clash = (
        await db.execute(select(User).where(User.username == username))
    ).scalar_one_or_none()
    if clash is not None:
        raise DuplicateError(f"Username already exists: {username}")

    new_user = User(
        username=username,
        password_hash=bcrypt.hash(password),
        role=role,
        has_access=has_access,
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)
    return new_user
|
||||
|
||||
|
||||
async def set_user_access(db: AsyncSession, user_id: int, has_access: bool) -> User:
    """Grant or revoke API access for a user.

    Raises NotFoundError when no user with *user_id* exists.
    """
    target = (
        await db.execute(select(User).where(User.id == user_id))
    ).scalar_one_or_none()
    if target is None:
        raise NotFoundError(f"User not found: {user_id}")

    target.has_access = has_access
    await db.commit()
    await db.refresh(target)
    return target
|
||||
|
||||
|
||||
async def reset_password(db: AsyncSession, user_id: int, new_password: str) -> User:
    """Reset a user's password to *new_password* (stored bcrypt-hashed).

    Raises NotFoundError when no user with *user_id* exists.
    """
    target = (
        await db.execute(select(User).where(User.id == user_id))
    ).scalar_one_or_none()
    if target is None:
        raise NotFoundError(f"User not found: {user_id}")

    target.password_hash = bcrypt.hash(new_password)
    await db.commit()
    await db.refresh(target)
    return target
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registration toggle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def toggle_registration(db: AsyncSession, enabled: bool) -> SystemSetting:
    """Enable or disable user registration via SystemSetting.

    Stores the flag under the "registration_enabled" key as "true"/"false".
    Delegates to update_setting() so the create-or-update logic lives in
    exactly one place (the original duplicated that upsert verbatim).
    """
    return await update_setting(db, "registration_enabled", str(enabled).lower())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System settings CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def list_settings(db: AsyncSession) -> list[SystemSetting]:
    """Return every SystemSetting row, sorted by key."""
    rows = await db.execute(select(SystemSetting).order_by(SystemSetting.key))
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
async def update_setting(db: AsyncSession, key: str, value: str) -> SystemSetting:
    """Create or update a system setting (upsert by *key*)."""
    row = (
        await db.execute(select(SystemSetting).where(SystemSetting.key == key))
    ).scalar_one_or_none()

    if row is None:
        # First write for this key: insert a fresh row.
        row = SystemSetting(key=key, value=value)
        db.add(row)
    else:
        row.value = value

    await db.commit()
    await db.refresh(row)
    return row
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def cleanup_data(db: AsyncSession, older_than_days: int) -> dict[str, int]:
    """Delete OHLCV, sentiment, and fundamental records older than N days.

    Preserves tickers, users, and latest scores.
    Returns a dict with counts of deleted records per table.
    """
    cutoff = datetime.now(timezone.utc) - timedelta(days=older_than_days)

    # (label, delete-statement) pairs. OHLCV stores a plain date column,
    # the other two store datetimes — hence cutoff.date() for OHLCV only.
    statements = (
        ("ohlcv", delete(OHLCVRecord).where(OHLCVRecord.date < cutoff.date())),
        ("sentiment", delete(SentimentScore).where(SentimentScore.timestamp < cutoff)),
        ("fundamentals", delete(FundamentalData).where(FundamentalData.fetched_at < cutoff)),
    )

    deleted: dict[str, int] = {}
    for label, stmt in statements:
        outcome = await db.execute(stmt)
        deleted[label] = outcome.rowcount  # type: ignore[assignment]

    await db.commit()
    return deleted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Job control (placeholder — scheduler is Task 12.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Job names accepted by trigger_job/toggle_job; list_jobs also uses these
# as the job ids it looks up on the scheduler, so they must match the ids
# registered there.
VALID_JOB_NAMES = {"data_collector", "sentiment_collector", "fundamental_collector", "rr_scanner"}

# Human-readable display labels for the admin UI, keyed by job name.
JOB_LABELS = {
    "data_collector": "Data Collector (OHLCV)",
    "sentiment_collector": "Sentiment Collector",
    "fundamental_collector": "Fundamental Collector",
    "rr_scanner": "R:R Scanner",
}
|
||||
|
||||
|
||||
async def list_jobs(db: AsyncSession) -> list[dict]:
    """Return status of all scheduled jobs.

    For each job in VALID_JOB_NAMES reports: the enabled flag (from the
    "job_<name>_enabled" SystemSetting, defaulting to enabled when the
    setting is absent), the next scheduled run time, and whether the job
    is registered with the scheduler.
    """
    from app.scheduler import scheduler

    names = sorted(VALID_JOB_NAMES)

    # One query for all job settings instead of one DB round trip per job.
    result = await db.execute(
        select(SystemSetting).where(
            SystemSetting.key.in_([f"job_{name}_enabled" for name in names])
        )
    )
    settings_by_key = {s.key: s for s in result.scalars().all()}

    jobs_out = []
    for name in names:
        setting = settings_by_key.get(f"job_{name}_enabled")
        enabled = setting.value == "true" if setting else True  # default enabled

        # Scheduler-side info: registration and next run time.
        job = scheduler.get_job(name)
        next_run = None
        if job and job.next_run_time:
            next_run = job.next_run_time.isoformat()

        jobs_out.append({
            "name": name,
            "label": JOB_LABELS.get(name, name),
            "enabled": enabled,
            "next_run_at": next_run,
            "registered": job is not None,
        })

    return jobs_out
|
||||
|
||||
|
||||
async def trigger_job(db: AsyncSession, job_name: str) -> dict[str, str]:
    """Trigger a manual job run via the scheduler.

    Runs the job immediately (in addition to its regular schedule).

    Raises:
        ValidationError: if *job_name* is not in VALID_JOB_NAMES.
    """
    if job_name not in VALID_JOB_NAMES:
        raise ValidationError(f"Unknown job: {job_name}. Valid jobs: {', '.join(sorted(VALID_JOB_NAMES))}")

    from app.scheduler import scheduler

    job = scheduler.get_job(job_name)
    if job is None:
        return {"job": job_name, "status": "not_found", "message": f"Job '{job_name}' is not registered in the scheduler"}

    # Moving next_run_time to "now" makes the scheduler fire the job on its
    # next wakeup. The previous modify(next_run_time=None) pre-step paused
    # the job first, which was redundant; the local re-import of
    # datetime/timezone (shadowing the module-level imports) is also gone.
    job.modify(next_run_time=datetime.now(timezone.utc))

    return {"job": job_name, "status": "triggered", "message": f"Job '{job_name}' triggered for immediate execution"}
|
||||
|
||||
|
||||
async def toggle_job(db: AsyncSession, job_name: str, enabled: bool) -> SystemSetting:
    """Enable or disable a scheduled job by storing state in SystemSetting.

    Actual scheduler integration happens in Task 12.1.

    Raises:
        ValidationError: if *job_name* is not a recognised job.
    """
    if job_name not in VALID_JOB_NAMES:
        raise ValidationError(f"Unknown job: {job_name}. Valid jobs: {', '.join(sorted(VALID_JOB_NAMES))}")

    return await update_setting(db, f"job_{job_name}_enabled", str(enabled).lower())
|
||||
66
app/services/auth_service.py
Normal file
66
app/services/auth_service.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Auth service: registration, login, and JWT token generation."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from jose import jwt
|
||||
from passlib.hash import bcrypt
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.dependencies import JWT_ALGORITHM
|
||||
from app.exceptions import AuthenticationError, AuthorizationError, DuplicateError
|
||||
from app.models.settings import SystemSetting
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
async def register(db: AsyncSession, username: str, password: str) -> User:
    """Register a new user.

    Checks if registration is enabled via SystemSetting, rejects duplicates,
    and creates a user with role='user' and has_access=False.
    """
    # Registration can be switched off by an admin; an absent setting
    # means registration is open.
    toggle = (
        await db.execute(
            select(SystemSetting).where(SystemSetting.key == "registration_enabled")
        )
    ).scalar_one_or_none()
    if toggle is not None and toggle.value.lower() == "false":
        raise AuthorizationError("Registration is closed")

    # Usernames must be unique.
    clash = (
        await db.execute(select(User).where(User.username == username))
    ).scalar_one_or_none()
    if clash is not None:
        raise DuplicateError(f"Username already exists: {username}")

    new_user = User(
        username=username,
        password_hash=bcrypt.hash(password),
        role="user",
        has_access=False,
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)
    return new_user
|
||||
|
||||
|
||||
async def login(db: AsyncSession, username: str, password: str) -> str:
    """Authenticate user and return a JWT access token.

    Returns the same error message for wrong username or wrong password
    to avoid leaking which field is incorrect.
    """
    found = (
        await db.execute(select(User).where(User.username == username))
    ).scalar_one_or_none()

    if found is None or not bcrypt.verify(password, found.password_hash):
        raise AuthenticationError("Invalid credentials")

    # Short-lived token carrying the user id and role.
    claims = {
        "sub": str(found.id),
        "role": found.role,
        "exp": datetime.now(timezone.utc) + timedelta(minutes=settings.jwt_expiry_minutes),
    }
    return jwt.encode(claims, settings.jwt_secret, algorithm=JWT_ALGORITHM)
|
||||
101
app/services/fundamental_service.py
Normal file
101
app/services/fundamental_service.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Fundamental data service.
|
||||
|
||||
Stores fundamental data (P/E, revenue growth, earnings surprise, market cap)
|
||||
and marks the fundamental dimension score as stale on new data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError
|
||||
from app.models.fundamental import FundamentalData
|
||||
from app.models.score import DimensionScore
|
||||
from app.models.ticker import Ticker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Look up a ticker by its (whitespace-stripped, upper-cased) symbol.

    Raises NotFoundError when the symbol is unknown.
    """
    wanted = symbol.strip().upper()
    query = select(Ticker).where(Ticker.symbol == wanted)
    found = (await db.execute(query)).scalar_one_or_none()
    if found is None:
        raise NotFoundError(f"Ticker not found: {wanted}")
    return found
|
||||
|
||||
|
||||
async def store_fundamental(
    db: AsyncSession,
    symbol: str,
    pe_ratio: float | None = None,
    revenue_growth: float | None = None,
    earnings_surprise: float | None = None,
    market_cap: float | None = None,
) -> FundamentalData:
    """Store or update fundamental data for a ticker.

    Keeps a single latest snapshot per ticker. On new data, marks the
    fundamental dimension score as stale (if one exists).
    """
    ticker = await _get_ticker(db, symbol)

    # A ticker has at most one snapshot row: update in place if present.
    existing = (
        await db.execute(
            select(FundamentalData).where(FundamentalData.ticker_id == ticker.id)
        )
    ).scalar_one_or_none()

    fetched_at = datetime.now(timezone.utc)

    if existing is None:
        record = FundamentalData(
            ticker_id=ticker.id,
            pe_ratio=pe_ratio,
            revenue_growth=revenue_growth,
            earnings_surprise=earnings_surprise,
            market_cap=market_cap,
            fetched_at=fetched_at,
        )
        db.add(record)
    else:
        record = existing
        record.pe_ratio = pe_ratio
        record.revenue_growth = revenue_growth
        record.earnings_surprise = earnings_surprise
        record.market_cap = market_cap
        record.fetched_at = fetched_at

    # Mark fundamental dimension score as stale if it exists.
    # TODO: Use DimensionScore service when built
    stale_target = (
        await db.execute(
            select(DimensionScore).where(
                DimensionScore.ticker_id == ticker.id,
                DimensionScore.dimension == "fundamental",
            )
        )
    ).scalar_one_or_none()
    if stale_target is not None:
        stale_target.is_stale = True

    await db.commit()
    await db.refresh(record)
    return record
|
||||
|
||||
|
||||
async def get_fundamental(
    db: AsyncSession,
    symbol: str,
) -> FundamentalData | None:
    """Get the latest fundamental data for a ticker.

    Raises NotFoundError (via _get_ticker) when the symbol is unknown;
    returns None when the ticker exists but has no snapshot yet.
    """
    ticker = await _get_ticker(db, symbol)
    query = select(FundamentalData).where(FundamentalData.ticker_id == ticker.id)
    return (await db.execute(query)).scalar_one_or_none()
|
||||
509
app/services/indicator_service.py
Normal file
509
app/services/indicator_service.py
Normal file
@@ -0,0 +1,509 @@
|
||||
"""Technical Analysis service.
|
||||
|
||||
Computes indicators from OHLCV data. Each indicator function is a pure
|
||||
function that takes a list of OHLCV-like records and returns raw values
|
||||
plus a normalized 0-100 score. The service layer handles DB fetching,
|
||||
caching, and minimum-data validation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.cache import indicator_cache
|
||||
from app.exceptions import ValidationError
|
||||
from app.services.price_service import query_ohlcv
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimum data requirements per indicator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Minimum number of bars each indicator needs before it can be computed.
# The compute_* functions enforce these via ValidationError.
MIN_BARS: dict[str, int] = {
    "adx": 28,
    "ema": 0,  # dynamic: period + 1
    "rsi": 15,
    "atr": 15,
    "volume_profile": 20,
    "pivot_points": 5,
}

# Default lookback period per indicator when the caller supplies none.
DEFAULT_PERIODS: dict[str, int] = {
    "adx": 14,
    "ema": 20,
    "rsi": 14,
    "atr": 14,
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure computation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ema(values: list[float], period: int) -> list[float]:
|
||||
"""Compute EMA series. Returns list same length as *values*."""
|
||||
if len(values) < period:
|
||||
return []
|
||||
k = 2.0 / (period + 1)
|
||||
ema_vals: list[float] = [sum(values[:period]) / period]
|
||||
for v in values[period:]:
|
||||
ema_vals.append(v * k + ema_vals[-1] * (1 - k))
|
||||
return ema_vals
|
||||
|
||||
|
||||
def compute_adx(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    period: int = 14,
) -> dict[str, Any]:
    """Compute ADX from high/low/close arrays.

    Uses Wilder-style smoothing throughout. Requires at least 2 * period bars.

    Returns dict with ``adx``, ``plus_di``, ``minus_di``, ``score``.

    Raises:
        ValidationError: if fewer than 2 * period bars are supplied.
    """
    n = len(closes)
    if n < 2 * period:
        raise ValidationError(
            f"ADX requires at least {2 * period} bars, got {n}"
        )

    # True Range, +DM, -DM — one entry per bar transition (n - 1 entries).
    tr_list: list[float] = []
    plus_dm: list[float] = []
    minus_dm: list[float] = []
    for i in range(1, n):
        h, l, pc = highs[i], lows[i], closes[i - 1]
        tr_list.append(max(h - l, abs(h - pc), abs(l - pc)))
        up = highs[i] - highs[i - 1]
        down = lows[i - 1] - lows[i]
        # Only the dominant, positive directional move counts per bar.
        plus_dm.append(up if up > down and up > 0 else 0.0)
        minus_dm.append(down if down > up and down > 0 else 0.0)

    # Smoothed TR, +DM, -DM (Wilder smoothing)
    def _smooth(vals: list[float], p: int) -> list[float]:
        # Seed with the plain sum of the first p values, then apply the
        # running rule s' = s - s/p + v for each subsequent value.
        s = [sum(vals[:p])]
        for v in vals[p:]:
            s.append(s[-1] - s[-1] / p + v)
        return s

    s_tr = _smooth(tr_list, period)
    s_plus = _smooth(plus_dm, period)
    s_minus = _smooth(minus_dm, period)

    # +DI, -DI, DX — the 1e-10 fallbacks guard against division by zero.
    dx_list: list[float] = []
    plus_di_last = 0.0
    minus_di_last = 0.0
    for i in range(len(s_tr)):
        tr_v = s_tr[i] if s_tr[i] != 0 else 1e-10
        pdi = 100.0 * s_plus[i] / tr_v
        mdi = 100.0 * s_minus[i] / tr_v
        denom = pdi + mdi if (pdi + mdi) != 0 else 1e-10
        dx_list.append(100.0 * abs(pdi - mdi) / denom)
        # Keep the most recent DI values for the returned payload.
        plus_di_last = pdi
        minus_di_last = mdi

    # ADX = smoothed DX; fall back to a plain average when there are not
    # enough DX values to smooth over a full period.
    if len(dx_list) < period:
        adx_val = sum(dx_list) / len(dx_list) if dx_list else 0.0
    else:
        adx_vals = _smooth(dx_list, period)
        adx_val = adx_vals[-1]

    # Score is simply the ADX clamped to [0, 100].
    score = max(0.0, min(100.0, adx_val))

    return {
        "adx": round(adx_val, 4),
        "plus_di": round(plus_di_last, 4),
        "minus_di": round(minus_di_last, 4),
        "score": round(score, 4),
    }
|
||||
|
||||
|
||||
def compute_ema(
    closes: list[float],
    period: int = 20,
) -> dict[str, Any]:
    """Compute EMA for *closes* with given *period*.

    Score: normalized position of latest close relative to EMA.
    Above EMA → higher score, below → lower.
    """
    required = period + 1
    if len(closes) < required:
        raise ValidationError(
            f"EMA({period}) requires at least {required} bars, got {len(closes)}"
        )

    series = _ema(closes, period)
    current_ema = series[-1]
    current_close = closes[-1]

    # 50 = price sits on the EMA; each 1% deviation moves the score 10
    # points, saturating at ±5%.
    if current_ema == 0:
        pct_above = 0.0
    else:
        pct_above = (current_close - current_ema) / current_ema * 100.0
    score = max(0.0, min(100.0, 50.0 + pct_above * 10.0))

    return {
        "ema": round(current_ema, 4),
        "period": period,
        "latest_close": round(current_close, 4),
        "score": round(score, 4),
    }
|
||||
|
||||
|
||||
def compute_rsi(
    closes: list[float],
    period: int = 14,
) -> dict[str, Any]:
    """Compute RSI with Wilder smoothing. Score = RSI value (already 0-100)."""
    n = len(closes)
    if n < period + 1:
        raise ValidationError(
            f"RSI requires at least {period + 1} bars, got {n}"
        )

    # Split each close-to-close move into its gain/loss component.
    gains: list[float] = []
    losses: list[float] = []
    for prev, curr in zip(closes, closes[1:]):
        change = curr - prev
        gains.append(change if change > 0 else 0.0)
        losses.append(-change if change < 0 else 0.0)

    avg_gain = sum(gains[:period]) / period
    avg_loss = sum(losses[:period]) / period

    # Wilder smoothing over the remaining moves.
    for idx in range(period, len(gains)):
        avg_gain = (avg_gain * (period - 1) + gains[idx]) / period
        avg_loss = (avg_loss * (period - 1) + losses[idx]) / period

    if avg_loss == 0:
        rsi = 100.0  # no losses at all → maximally overbought
    else:
        rsi = 100.0 - 100.0 / (1.0 + avg_gain / avg_loss)

    score = max(0.0, min(100.0, rsi))

    return {
        "rsi": round(rsi, 4),
        "period": period,
        "score": round(score, 4),
    }
|
||||
|
||||
|
||||
def compute_atr(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    period: int = 14,
) -> dict[str, Any]:
    """Compute ATR. Score = normalized inverse (lower ATR = higher score)."""
    n = len(closes)
    if n < period + 1:
        raise ValidationError(
            f"ATR requires at least {period + 1} bars, got {n}"
        )

    # True range per bar transition (n - 1 entries).
    true_ranges = [
        max(
            highs[i] - lows[i],
            abs(highs[i] - closes[i - 1]),
            abs(lows[i] - closes[i - 1]),
        )
        for i in range(1, n)
    ]

    # Seed with the simple average, then apply Wilder smoothing.
    atr = sum(true_ranges[:period]) / period
    for tr in true_ranges[period:]:
        atr = (atr * (period - 1) + tr) / period

    # Express ATR as a percentage of the latest close; 0% maps to a
    # score of 100 and 10%+ maps to 0.
    last_close = closes[-1]
    if last_close == 0:
        atr_pct = 0.0
    else:
        atr_pct = atr / last_close * 100.0
    score = max(0.0, min(100.0, 100.0 - atr_pct * 10.0))

    return {
        "atr": round(atr, 4),
        "period": period,
        "atr_percent": round(atr_pct, 4),
        "score": round(score, 4),
    }
|
||||
|
||||
|
||||
def compute_volume_profile(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    volumes: list[int],
    num_bins: int = 20,
) -> dict[str, Any]:
    """Compute Volume Profile: POC, Value Area, HVN, LVN.

    Volume is histogrammed into *num_bins* equal price bins spanning
    [min(lows), max(highs)]. A bar's full volume is credited to every bin
    it overlaps (no proration), so summed bin volume can exceed total
    traded volume.

    Score: proximity of latest close to POC (closer = higher).

    Raises:
        ValidationError: if fewer than 20 bars are supplied.
    """
    n = len(closes)
    if n < 20:
        raise ValidationError(
            f"Volume Profile requires at least 20 bars, got {n}"
        )

    price_min = min(lows)
    price_max = max(highs)
    if price_max == price_min:
        price_max = price_min + 1.0  # avoid zero-width range

    bin_width = (price_max - price_min) / num_bins
    bins: list[float] = [0.0] * num_bins
    # Midpoint price of each bin, used for POC/HVN/LVN reporting.
    bin_prices: list[float] = [
        price_min + (i + 0.5) * bin_width for i in range(num_bins)
    ]

    for i in range(n):
        # Distribute volume across bins the bar spans
        bar_low, bar_high = lows[i], highs[i]
        for b in range(num_bins):
            bl = price_min + b * bin_width
            bh = bl + bin_width
            if bar_high >= bl and bar_low <= bh:
                bins[b] += volumes[i]

    total_vol = sum(bins)
    if total_vol == 0:
        total_vol = 1.0  # avoid division by zero below

    # POC = bin with highest volume
    poc_idx = bins.index(max(bins))
    poc = round(bin_prices[poc_idx], 4)

    # Value Area: 70% of total volume around POC — greedily take the
    # highest-volume bins until 70% of binned volume is covered.
    sorted_bins = sorted(range(num_bins), key=lambda i: bins[i], reverse=True)
    va_vol = 0.0
    va_indices: list[int] = []
    for idx in sorted_bins:
        va_vol += bins[idx]
        va_indices.append(idx)
        if va_vol >= total_vol * 0.7:
            break
    # Report the value area as the full price span of the selected bins.
    va_low = round(price_min + min(va_indices) * bin_width, 4)
    va_high = round(price_min + (max(va_indices) + 1) * bin_width, 4)

    # HVN / LVN: bins above/below average volume
    avg_vol = total_vol / num_bins
    hvn = [round(bin_prices[i], 4) for i in range(num_bins) if bins[i] > avg_vol]
    lvn = [round(bin_prices[i], 4) for i in range(num_bins) if bins[i] < avg_vol]

    # Score: proximity of latest close to POC
    latest = closes[-1]
    price_range = price_max - price_min
    if price_range == 0:
        # NOTE(review): unreachable after the price_max adjustment above;
        # kept as a defensive guard.
        score = 100.0
    else:
        dist_pct = abs(latest - poc) / price_range
        score = max(0.0, min(100.0, 100.0 * (1.0 - dist_pct)))

    return {
        "poc": poc,
        "value_area_low": va_low,
        "value_area_high": va_high,
        "hvn": hvn,
        "lvn": lvn,
        "score": round(score, 4),
    }
|
||||
|
||||
|
||||
def compute_pivot_points(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    window: int = 2,
) -> dict[str, Any]:
    """Detect swing highs/lows as pivot points.

    A swing high at index *i* means highs[i] >= all highs in [i-window, i+window].
    Score: based on number of pivots near current price.
    """
    n = len(closes)
    if n < 5:
        raise ValidationError(
            f"Pivot Points requires at least 5 bars, got {n}"
        )

    swing_highs: list[float] = []
    swing_lows: list[float] = []
    # Only interior bars can be pivots: each needs *window* neighbours
    # on both sides.
    for i in range(window, n - window):
        neighbourhood = range(i - window, i + window + 1)
        if all(highs[i] >= highs[j] for j in neighbourhood):
            swing_highs.append(round(highs[i], 4))
        if all(lows[i] <= lows[j] for j in neighbourhood):
            swing_lows.append(round(lows[i], 4))

    all_pivots = swing_highs + swing_lows
    latest = closes[-1]

    # Score: fraction of pivots within 2% of current price → 0-100
    if not all_pivots or latest == 0:
        score = 0.0
    else:
        near_count = sum(1 for p in all_pivots if abs(p - latest) / latest <= 0.02)
        score = min(100.0, (near_count / max(len(all_pivots), 1)) * 100.0)

    return {
        "swing_highs": swing_highs,
        "swing_lows": swing_lows,
        "pivot_count": len(all_pivots),
        "score": round(score, 4),
    }
|
||||
|
||||
|
||||
def compute_ema_cross(
    closes: list[float],
    short_period: int = 20,
    long_period: int = 50,
    tolerance: float = 1e-6,
) -> dict[str, Any]:
    """Compare short EMA vs long EMA.

    Returns signal: bullish (short > long), bearish (short < long),
    neutral (within tolerance).
    """
    required = long_period + 1
    if len(closes) < required:
        raise ValidationError(
            f"EMA Cross requires at least {required} bars, got {len(closes)}"
        )

    fast = _ema(closes, short_period)[-1]
    slow = _ema(closes, long_period)[-1]

    gap = fast - slow
    if abs(gap) <= tolerance:
        signal = "neutral"
    else:
        signal = "bullish" if gap > 0 else "bearish"

    return {
        "short_ema": round(fast, 4),
        "long_ema": round(slow, 4),
        "short_period": short_period,
        "long_period": long_period,
        "signal": signal,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Supported indicator types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Indicator names accepted by get_indicator (input is lower-cased first).
INDICATOR_TYPES = {"adx", "ema", "rsi", "atr", "volume_profile", "pivot_points"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service-layer functions (DB + cache + validation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _extract_ohlcv(records: list) -> tuple[
|
||||
list[float], list[float], list[float], list[float], list[int]
|
||||
]:
|
||||
"""Extract parallel arrays from OHLCVRecord list."""
|
||||
opens = [float(r.open) for r in records]
|
||||
highs = [float(r.high) for r in records]
|
||||
lows = [float(r.low) for r in records]
|
||||
closes = [float(r.close) for r in records]
|
||||
volumes = [int(r.volume) for r in records]
|
||||
return opens, highs, lows, closes, volumes
|
||||
|
||||
|
||||
async def get_indicator(
    db: AsyncSession,
    symbol: str,
    indicator_type: str,
    start_date: date | None = None,
    end_date: date | None = None,
    period: int | None = None,
) -> dict[str, Any]:
    """Compute a single indicator for *symbol*.

    Checks cache first; stores result after computing.

    Raises:
        ValidationError: for an unknown indicator type, or (from the
        compute_* functions) when there are too few bars.
    """
    indicator_type = indicator_type.lower()
    if indicator_type not in INDICATOR_TYPES:
        raise ValidationError(
            f"Unknown indicator type: {indicator_type}. "
            f"Supported: {', '.join(sorted(INDICATOR_TYPES))}"
        )

    # Resolve the effective period up front so it can be part of the cache
    # key. BUG FIX: the original key omitted *period*, so e.g. EMA(20) and
    # EMA(50) for the same symbol/range collided and returned stale results.
    # (volume_profile / pivot_points ignore the period; 0 is a placeholder.)
    effective_period = period or DEFAULT_PERIODS.get(indicator_type, 0)
    cache_key = (
        symbol.upper(),
        str(start_date),
        str(end_date),
        indicator_type,
        effective_period,
    )
    cached = indicator_cache.get(cache_key)
    if cached is not None:
        return cached

    records = await query_ohlcv(db, symbol, start_date, end_date)
    _, highs, lows, closes, volumes = _extract_ohlcv(records)
    n = len(records)

    if indicator_type == "adx":
        result = compute_adx(highs, lows, closes, period=effective_period)
    elif indicator_type == "ema":
        result = compute_ema(closes, period=effective_period)
    elif indicator_type == "rsi":
        result = compute_rsi(closes, period=effective_period)
    elif indicator_type == "atr":
        result = compute_atr(highs, lows, closes, period=effective_period)
    elif indicator_type == "volume_profile":
        result = compute_volume_profile(highs, lows, closes, volumes)
    else:  # pivot_points — membership was validated above
        result = compute_pivot_points(highs, lows, closes)

    response = {
        "indicator_type": indicator_type,
        "values": {k: v for k, v in result.items() if k != "score"},
        "score": result["score"],
        "bars_used": n,
    }

    indicator_cache.set(cache_key, response)
    return response
|
||||
|
||||
|
||||
async def get_ema_cross(
    db: AsyncSession,
    symbol: str,
    start_date: date | None = None,
    end_date: date | None = None,
    short_period: int = 20,
    long_period: int = 50,
) -> dict[str, Any]:
    """Compute EMA cross signal for *symbol* (cached per symbol/range/periods)."""
    key = (
        symbol.upper(),
        str(start_date),
        str(end_date),
        f"ema_cross_{short_period}_{long_period}",
    )
    hit = indicator_cache.get(key)
    if hit is not None:
        return hit

    bars = await query_ohlcv(db, symbol, start_date, end_date)
    _, _, _, closes, _ = _extract_ohlcv(bars)

    outcome = compute_ema_cross(closes, short_period, long_period)
    indicator_cache.set(key, outcome)
    return outcome
|
||||
172
app/services/ingestion_service.py
Normal file
172
app/services/ingestion_service.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""Ingestion Pipeline service: fetch from provider, validate, upsert into Price Store.
|
||||
|
||||
Handles rate-limit resume via IngestionProgress and provider error isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, timedelta
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError, ProviderError, RateLimitError
|
||||
from app.models.settings import IngestionProgress
|
||||
from app.models.ticker import Ticker
|
||||
from app.providers.protocol import MarketDataProvider
|
||||
from app.services import price_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class IngestionResult:
    """Result of an ingestion run."""

    # Ticker symbol the run was executed for.
    symbol: str
    # Number of OHLCV records successfully upserted in this run.
    records_ingested: int
    # Date of the last record upserted, or None if nothing was ingested.
    last_date: date | None
    status: str  # "complete" | "partial" | "error"
    # Optional human-readable detail (error text, resume hint, summary).
    message: str | None = None
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Resolve *symbol* (whitespace-stripped, upper-cased) to its Ticker row.

    Raises:
        NotFoundError: if no ticker with that symbol exists.
    """
    sym = symbol.strip().upper()
    row = (
        await db.execute(select(Ticker).where(Ticker.symbol == sym))
    ).scalar_one_or_none()
    if row is None:
        raise NotFoundError(f"Ticker not found: {sym}")
    return row
|
||||
|
||||
|
||||
async def _get_progress(db: AsyncSession, ticker_id: int) -> IngestionProgress | None:
    """Return the ticker's IngestionProgress row, or None if none exists."""
    stmt = select(IngestionProgress).where(IngestionProgress.ticker_id == ticker_id)
    return (await db.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def _update_progress(
    db: AsyncSession, ticker_id: int, last_date: date
) -> None:
    """Upsert the ticker's IngestionProgress to *last_date* and commit."""
    progress = await _get_progress(db, ticker_id)
    if progress is not None:
        progress.last_ingested_date = last_date
    else:
        db.add(
            IngestionProgress(ticker_id=ticker_id, last_ingested_date=last_date)
        )
    await db.commit()
|
||||
|
||||
|
||||
async def fetch_and_ingest(
    db: AsyncSession,
    provider: MarketDataProvider,
    symbol: str,
    start_date: date | None = None,
    end_date: date | None = None,
) -> IngestionResult:
    """Fetch OHLCV data from provider and upsert into Price Store.

    - Resolves start_date from IngestionProgress if not provided (resume).
    - Defaults end_date to today.
    - Tracks last_ingested_date after each successful upsert.
    - On RateLimitError from provider: returns partial progress.
    - On ProviderError: returns error, no data modification.

    Args:
        db: Async database session.
        provider: Market data source implementing MarketDataProvider.
        symbol: Ticker symbol (must already be tracked).
        start_date: First date to fetch (inclusive); resumes from progress
            + 1 day, or defaults to end_date - 365 days, when omitted.
        end_date: Last date to fetch (inclusive); defaults to today.

    Returns:
        IngestionResult describing how far the run got.

    Raises:
        NotFoundError: if the ticker does not exist.
    """
    ticker = await _get_ticker(db, symbol)

    # Resolve end_date
    if end_date is None:
        end_date = date.today()

    # Resolve start_date: use progress resume or default to 1 year ago
    if start_date is None:
        progress = await _get_progress(db, ticker.id)
        if progress is not None:
            start_date = progress.last_ingested_date + timedelta(days=1)
        else:
            start_date = end_date - timedelta(days=365)

    # If start > end, nothing to fetch
    if start_date > end_date:
        return IngestionResult(
            symbol=ticker.symbol,
            records_ingested=0,
            last_date=None,
            status="complete",
            message="Already up to date",
        )

    # Fetch from provider
    try:
        records = await provider.fetch_ohlcv(ticker.symbol, start_date, end_date)
    except RateLimitError:
        # No data fetched at all — return partial with 0 records
        return IngestionResult(
            symbol=ticker.symbol,
            records_ingested=0,
            last_date=None,
            status="partial",
            message="Rate limited before any records fetched. Resume available.",
        )
    except ProviderError as exc:
        logger.error("Provider error for %s: %s", ticker.symbol, exc)
        return IngestionResult(
            symbol=ticker.symbol,
            records_ingested=0,
            last_date=None,
            status="error",
            message=str(exc),
        )

    # Sort records by date to ensure ordered ingestion
    records.sort(key=lambda r: r.date)

    ingested_count = 0
    last_ingested: date | None = None

    for record in records:
        try:
            await price_service.upsert_ohlcv(
                db,
                symbol=ticker.symbol,
                record_date=record.date,
                open_=record.open,
                high=record.high,
                low=record.low,
                close=record.close,
                volume=record.volume,
            )
            ingested_count += 1
            last_ingested = record.date

            # Update progress after each successful upsert
            # (per-record commit keeps resume accurate even on a crash).
            await _update_progress(db, ticker.id, record.date)

        except RateLimitError:
            # Mid-ingestion rate limit — return partial progress
            # NOTE(review): upsert_ohlcv does not obviously raise
            # RateLimitError — confirm which call this guards against.
            logger.warning(
                "Rate limited during ingestion for %s after %d records",
                ticker.symbol,
                ingested_count,
            )
            return IngestionResult(
                symbol=ticker.symbol,
                records_ingested=ingested_count,
                last_date=last_ingested,
                status="partial",
                message=f"Rate limited. Ingested {ingested_count} records. Resume available.",
            )

    return IngestionResult(
        symbol=ticker.symbol,
        records_ingested=ingested_count,
        last_date=last_ingested,
        status="complete",
        message=f"Successfully ingested {ingested_count} records",
    )
|
||||
110
app/services/price_service.py
Normal file
110
app/services/price_service.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Price Store service: upsert and query OHLCV records."""
|
||||
|
||||
from datetime import date, datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError, ValidationError
|
||||
from app.models.ohlcv import OHLCVRecord
|
||||
from app.models.ticker import Ticker
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Resolve *symbol* to a Ticker row; raise NotFoundError when absent."""
    sym = symbol.strip().upper()
    found = (
        await db.execute(select(Ticker).where(Ticker.symbol == sym))
    ).scalar_one_or_none()
    if found is None:
        raise NotFoundError(f"Ticker not found: {sym}")
    return found
|
||||
|
||||
|
||||
def _validate_ohlcv(
|
||||
high: float, low: float, open_: float, close: float, volume: int, record_date: date
|
||||
) -> None:
|
||||
"""Business-rule validation for an OHLCV record."""
|
||||
if high < low:
|
||||
raise ValidationError("Validation error: high must be >= low")
|
||||
if any(p < 0 for p in (open_, high, low, close)):
|
||||
raise ValidationError("Validation error: prices must be >= 0")
|
||||
if volume < 0:
|
||||
raise ValidationError("Validation error: volume must be >= 0")
|
||||
if record_date > date.today():
|
||||
raise ValidationError("Validation error: date must not be in the future")
|
||||
|
||||
|
||||
async def upsert_ohlcv(
    db: AsyncSession,
    symbol: str,
    record_date: date,
    open_: float,
    high: float,
    low: float,
    close: float,
    volume: int,
) -> OHLCVRecord:
    """Insert or update an OHLCV record for (ticker, date).

    Validates business rules, resolves ticker, then uses
    ON CONFLICT DO UPDATE on the (ticker_id, date) unique constraint.

    Args:
        db: Async database session.
        symbol: Ticker symbol (must already be tracked).
        record_date: Trading date of the bar.
        open_: Opening price (trailing underscore avoids shadowing builtin).
        high: High price (must be >= low).
        low: Low price.
        close: Closing price.
        volume: Traded volume (non-negative).

    Returns:
        The persisted OHLCVRecord row (committed).

    Raises:
        ValidationError: if the record violates business rules.
        NotFoundError: if the ticker does not exist.
    """
    _validate_ohlcv(high, low, open_, close, volume, record_date)
    ticker = await _get_ticker(db, symbol)

    # PostgreSQL-dialect insert so ON CONFLICT is available.
    # NOTE(review): datetime.utcnow() is deprecated in 3.12 and returns a
    # naive timestamp — consider datetime.now(timezone.utc) if the column
    # is timezone-aware; confirm against the model definition.
    stmt = pg_insert(OHLCVRecord).values(
        ticker_id=ticker.id,
        date=record_date,
        open=open_,
        high=high,
        low=low,
        close=close,
        volume=volume,
        created_at=datetime.utcnow(),
    )
    # On duplicate (ticker_id, date), overwrite all price fields and the
    # timestamp with the incoming values.
    stmt = stmt.on_conflict_do_update(
        constraint="uq_ohlcv_ticker_date",
        set_={
            "open": stmt.excluded.open,
            "high": stmt.excluded.high,
            "low": stmt.excluded.low,
            "close": stmt.excluded.close,
            "volume": stmt.excluded.volume,
            "created_at": stmt.excluded.created_at,
        },
    )
    # RETURNING gives us the final row in the same round-trip.
    stmt = stmt.returning(OHLCVRecord)
    result = await db.execute(stmt)
    await db.commit()

    record = result.scalar_one()

    # TODO: Invalidate LRU cache entries for this ticker (Task 7.1)
    # TODO: Mark composite score as stale for this ticker (Task 10.1)

    return record
|
||||
|
||||
|
||||
async def query_ohlcv(
    db: AsyncSession,
    symbol: str,
    start_date: date | None = None,
    end_date: date | None = None,
) -> list[OHLCVRecord]:
    """Return a ticker's OHLCV records in ascending date order.

    Both bounds are inclusive and either may be omitted.
    Raises NotFoundError if the ticker does not exist.
    """
    ticker = await _get_ticker(db, symbol)

    conditions = [OHLCVRecord.ticker_id == ticker.id]
    if start_date is not None:
        conditions.append(OHLCVRecord.date >= start_date)
    if end_date is not None:
        conditions.append(OHLCVRecord.date <= end_date)

    stmt = (
        select(OHLCVRecord)
        .where(*conditions)
        .order_by(OHLCVRecord.date.asc())
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
|
||||
241
app/services/rr_scanner_service.py
Normal file
241
app/services/rr_scanner_service.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""R:R Scanner service.
|
||||
|
||||
Scans tracked tickers for asymmetric risk-reward trade setups.
|
||||
Long: target = nearest SR above, stop = entry - ATR × multiplier.
|
||||
Short: target = nearest SR below, stop = entry + ATR × multiplier.
|
||||
Filters by configurable R:R threshold (default 3:1).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError
|
||||
from app.models.score import CompositeScore
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.ticker import Ticker
|
||||
from app.models.trade_setup import TradeSetup
|
||||
from app.services.indicator_service import _extract_ohlcv, compute_atr
|
||||
from app.services.price_service import query_ohlcv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Fetch the Ticker row for *symbol* (normalised to upper case).

    Raises NotFoundError if the ticker is not tracked.
    """
    sym = symbol.strip().upper()
    row = (
        await db.execute(select(Ticker).where(Ticker.symbol == sym))
    ).scalar_one_or_none()
    if row is None:
        raise NotFoundError(f"Ticker not found: {sym}")
    return row
|
||||
|
||||
|
||||
async def _clear_setups(db: AsyncSession, ticker_id: int) -> None:
    """Delete all stored TradeSetup rows for *ticker_id* and commit.

    The commit matters: the original early-exit paths issued the DELETE
    but never committed, so with an async session the stale setups were
    silently rolled back when the session closed.
    """
    await db.execute(delete(TradeSetup).where(TradeSetup.ticker_id == ticker_id))
    await db.commit()


def _candidate_setup(
    ticker_id: int,
    direction: str,
    entry: float,
    stop: float,
    target: float,
    reward: float,
    risk: float,
    rr_threshold: float,
    composite_score: float,
    detected_at: datetime,
) -> TradeSetup | None:
    """Build a TradeSetup if reward/risk are positive and R:R meets threshold.

    Returns None when the geometry is degenerate (non-positive reward or
    risk) or the ratio falls below *rr_threshold*.
    """
    if risk <= 0 or reward <= 0:
        return None
    rr = reward / risk
    if rr < rr_threshold:
        return None
    return TradeSetup(
        ticker_id=ticker_id,
        direction=direction,
        entry_price=round(entry, 4),
        stop_loss=round(stop, 4),
        target=round(target, 4),
        rr_ratio=round(rr, 4),
        composite_score=round(composite_score, 4),
        detected_at=detected_at,
    )


async def scan_ticker(
    db: AsyncSession,
    symbol: str,
    rr_threshold: float = 3.0,
    atr_multiplier: float = 1.5,
) -> list[TradeSetup]:
    """Scan a single ticker for trade setups meeting the R:R threshold.

    1. Fetch OHLCV data and compute ATR.
    2. Fetch SR levels.
    3. Compute long and short setups.
    4. Filter by R:R threshold.
    5. Delete old setups for this ticker and persist new ones.

    Args:
        db: Async database session.
        symbol: Ticker symbol (case-insensitive).
        rr_threshold: Minimum reward/risk ratio to keep a setup.
        atr_multiplier: Stop distance expressed in ATR units.

    Returns list of persisted TradeSetup models.

    Raises:
        NotFoundError: if the ticker does not exist.
    """
    ticker = await _get_ticker(db, symbol)

    # Fetch OHLCV; ATR needs at least 15 bars.
    records = await query_ohlcv(db, symbol)
    if not records or len(records) < 15:
        logger.info(
            "Skipping %s: insufficient OHLCV data (%d bars, need 15+)",
            symbol, len(records),
        )
        # Clear any stale setups (committed — see _clear_setups).
        await _clear_setups(db, ticker.id)
        return []

    _, highs, lows, closes, _ = _extract_ohlcv(records)
    entry_price = closes[-1]

    # Compute ATR; skip the ticker when it cannot be computed or is degenerate.
    try:
        atr_value = compute_atr(highs, lows, closes)["atr"]
    except Exception:
        logger.info("Skipping %s: cannot compute ATR", symbol)
        await _clear_setups(db, ticker.id)
        return []

    if atr_value <= 0:
        logger.info("Skipping %s: ATR is zero or negative", symbol)
        await _clear_setups(db, ticker.id)
        return []

    # Fetch SR levels from DB (already computed by sr_service)
    sr_result = await db.execute(
        select(SRLevel).where(SRLevel.ticker_id == ticker.id)
    )
    sr_levels = list(sr_result.scalars().all())

    if not sr_levels:
        logger.info("Skipping %s: no SR levels available", symbol)
        await _clear_setups(db, ticker.id)
        return []

    levels_above = sorted(
        (lv for lv in sr_levels if lv.price_level > entry_price),
        key=lambda lv: lv.price_level,
    )
    levels_below = sorted(
        (lv for lv in sr_levels if lv.price_level < entry_price),
        key=lambda lv: lv.price_level,
        reverse=True,
    )

    # Composite score is attached to each setup for ranking; 0.0 if absent.
    comp_result = await db.execute(
        select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
    )
    comp = comp_result.scalar_one_or_none()
    composite_score = comp.score if comp else 0.0

    now = datetime.now(timezone.utc)
    stop_distance = atr_value * atr_multiplier
    setups: list[TradeSetup] = []

    # Long setup: target = nearest SR above, stop = entry - ATR × multiplier
    if levels_above:
        target = levels_above[0].price_level
        stop = entry_price - stop_distance
        candidate = _candidate_setup(
            ticker_id=ticker.id,
            direction="long",
            entry=entry_price,
            stop=stop,
            target=target,
            reward=target - entry_price,
            risk=entry_price - stop,
            rr_threshold=rr_threshold,
            composite_score=composite_score,
            detected_at=now,
        )
        if candidate is not None:
            setups.append(candidate)

    # Short setup: target = nearest SR below, stop = entry + ATR × multiplier
    if levels_below:
        target = levels_below[0].price_level
        stop = entry_price + stop_distance
        candidate = _candidate_setup(
            ticker_id=ticker.id,
            direction="short",
            entry=entry_price,
            stop=stop,
            target=target,
            reward=entry_price - target,
            risk=stop - entry_price,
            rr_threshold=rr_threshold,
            composite_score=composite_score,
            detected_at=now,
        )
        if candidate is not None:
            setups.append(candidate)

    # Delete old setups for this ticker, persist new ones in one commit.
    await db.execute(
        delete(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
    )
    for setup in setups:
        db.add(setup)

    await db.commit()

    # Refresh to get IDs
    for s in setups:
        await db.refresh(s)

    return setups
|
||||
|
||||
|
||||
async def scan_all_tickers(
    db: AsyncSession,
    rr_threshold: float = 3.0,
    atr_multiplier: float = 1.5,
) -> list[TradeSetup]:
    """Scan every tracked ticker for trade setups.

    Tickers are processed in symbol order, each in isolation: an exception
    while scanning one is logged and does not abort the remaining scans.
    Returns the setups found across all tickers.
    """
    ticker_rows = await db.execute(select(Ticker).order_by(Ticker.symbol))

    found: list[TradeSetup] = []
    for ticker in ticker_rows.scalars().all():
        try:
            found.extend(
                await scan_ticker(db, ticker.symbol, rr_threshold, atr_multiplier)
            )
        except Exception:
            logger.exception("Error scanning ticker %s", ticker.symbol)

    return found
|
||||
|
||||
|
||||
async def get_trade_setups(
    db: AsyncSession,
    direction: str | None = None,
) -> list[dict]:
    """Return all stored trade setups as dicts, best R:R first.

    Optionally filters by *direction* (lower-cased before comparison).
    Secondary sort key is composite score, descending. Each dict carries
    the ticker symbol alongside the setup fields.
    """
    query = select(TradeSetup, Ticker.symbol).join(
        Ticker, TradeSetup.ticker_id == Ticker.id
    )
    if direction is not None:
        query = query.where(TradeSetup.direction == direction.lower())
    query = query.order_by(
        TradeSetup.rr_ratio.desc(),
        TradeSetup.composite_score.desc(),
    )

    rows = (await db.execute(query)).all()

    def _as_dict(setup: TradeSetup, symbol: str) -> dict:
        # Flatten one (setup, symbol) row into the API shape.
        return {
            "id": setup.id,
            "symbol": symbol,
            "direction": setup.direction,
            "entry_price": setup.entry_price,
            "stop_loss": setup.stop_loss,
            "target": setup.target,
            "rr_ratio": setup.rr_ratio,
            "composite_score": setup.composite_score,
            "detected_at": setup.detected_at,
        }

    return [_as_dict(setup, symbol) for setup, symbol in rows]
|
||||
584
app/services/scoring_service.py
Normal file
584
app/services/scoring_service.py
Normal file
@@ -0,0 +1,584 @@
|
||||
"""Scoring Engine service.
|
||||
|
||||
Computes dimension scores (technical, sr_quality, sentiment, fundamental,
|
||||
momentum) each 0-100, composite score as weighted average of available
|
||||
dimensions with re-normalized weights, staleness marking/recomputation
|
||||
on demand, and weight update triggers full recomputation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError, ValidationError
|
||||
from app.models.score import CompositeScore, DimensionScore
|
||||
from app.models.settings import SystemSetting
|
||||
from app.models.ticker import Ticker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Dimension names, in the order they are computed and reported.
DIMENSIONS = ["technical", "sr_quality", "sentiment", "fundamental", "momentum"]

# Fallback weights used when no scoring-weights SystemSetting exists
# (or when the stored value fails to parse). Values sum to 1.0.
DEFAULT_WEIGHTS: dict[str, float] = {
    "technical": 0.25,
    "sr_quality": 0.20,
    "sentiment": 0.15,
    "fundamental": 0.20,
    "momentum": 0.20,
}

# SystemSetting key under which custom scoring weights are stored as JSON.
SCORING_WEIGHTS_KEY = "scoring_weights"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Look up a Ticker by (normalised) symbol, raising NotFoundError if missing."""
    sym = symbol.strip().upper()
    row = (
        await db.execute(select(Ticker).where(Ticker.symbol == sym))
    ).scalar_one_or_none()
    if row is None:
        raise NotFoundError(f"Ticker not found: {sym}")
    return row
|
||||
|
||||
|
||||
async def _get_weights(db: AsyncSession) -> dict[str, float]:
    """Return scoring weights from the DB setting, or DEFAULT_WEIGHTS.

    Falls back to a copy of the defaults when the setting is absent or its
    stored value is not valid JSON.
    """
    row = (
        await db.execute(
            select(SystemSetting).where(SystemSetting.key == SCORING_WEIGHTS_KEY)
        )
    ).scalar_one_or_none()
    if row is None:
        return dict(DEFAULT_WEIGHTS)
    try:
        return json.loads(row.value)
    except (json.JSONDecodeError, TypeError):
        logger.warning("Invalid scoring weights in DB, using defaults")
        return dict(DEFAULT_WEIGHTS)
|
||||
|
||||
|
||||
async def _save_weights(db: AsyncSession, weights: dict[str, float]) -> None:
    """Write *weights* to the scoring-weights SystemSetting (insert or update).

    Changes are added to the session only; the caller commits.
    """
    payload = json.dumps(weights)
    now = datetime.now(timezone.utc)
    row = (
        await db.execute(
            select(SystemSetting).where(SystemSetting.key == SCORING_WEIGHTS_KEY)
        )
    ).scalar_one_or_none()
    if row is None:
        db.add(
            SystemSetting(
                key=SCORING_WEIGHTS_KEY,
                value=payload,
                updated_at=now,
            )
        )
    else:
        row.value = payload
        row.updated_at = now
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dimension score computation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _compute_technical_score(db: AsyncSession, symbol: str) -> float | None:
    """Technical dimension: weighted blend of ADX, EMA and RSI scores.

    Indicators that cannot be computed (e.g. too few bars) are dropped and
    the remaining weights are re-normalized. Returns None when there is no
    price data or no indicator could be computed. Result is clamped to 0-100.
    """
    from app.services.indicator_service import (
        compute_adx,
        compute_ema,
        compute_rsi,
        _extract_ohlcv,
    )
    from app.services.price_service import query_ohlcv

    records = await query_ohlcv(db, symbol)
    if not records:
        return None

    _, highs, lows, closes, _ = _extract_ohlcv(records)

    # (weight, thunk) pairs; each indicator has its own minimum bar count.
    attempts = (
        (0.4, lambda: compute_adx(highs, lows, closes)),  # ADX: needs 28+ bars
        (0.3, lambda: compute_ema(closes)),               # EMA: needs period+1 bars
        (0.3, lambda: compute_rsi(closes)),               # RSI: needs 15+ bars
    )

    parts: list[tuple[float, float]] = []
    for weight, compute in attempts:
        try:
            parts.append((weight, compute()["score"]))
        except Exception:
            # Best-effort: skip indicators that cannot be computed.
            continue

    if not parts:
        return None

    denom = sum(w for w, _ in parts)
    if denom == 0:
        return None
    blended = sum(w * s for w, s in parts) / denom
    return max(0.0, min(100.0, blended))
|
||||
|
||||
|
||||
async def _compute_sr_quality_score(db: AsyncSession, symbol: str) -> float | None:
    """S/R quality dimension score.

    Three additive factors, clamped to 0-100 overall:
    - up to 40 pts: 10 per strong level (strength >= 50), capped;
    - up to 30 pts: proximity of the nearest level (0% distance -> 30,
      5%+ -> 0);
    - up to 30 pts: average level strength scaled by 0.3.
    Returns None when prices or levels are unavailable or the current
    price is non-positive.
    """
    from app.services.price_service import query_ohlcv
    from app.services.sr_service import get_sr_levels

    records = await query_ohlcv(db, symbol)
    if not records:
        return None

    current_price = float(records[-1].close)
    if current_price <= 0:
        return None

    try:
        levels = await get_sr_levels(db, symbol)
    except Exception:
        # Best-effort: treat any SR-service failure as "no data".
        return None

    if not levels:
        return None

    # Factor 1: number of strong levels.
    strong = sum(1 for lv in levels if lv.strength >= 50)
    count_pts = min(40.0, strong * 10.0)

    # Factor 2: relative distance of the nearest level to the current price.
    nearest = min(
        abs(lv.price_level - current_price) / current_price for lv in levels
    )
    proximity_pts = max(0.0, min(30.0, 30.0 * (1.0 - nearest / 0.05)))

    # Factor 3: average strength.
    mean_strength = sum(lv.strength for lv in levels) / len(levels)
    strength_pts = min(30.0, mean_strength * 0.3)

    return max(0.0, min(100.0, count_pts + proximity_pts + strength_pts))
|
||||
|
||||
|
||||
async def _compute_sentiment_score(db: AsyncSession, symbol: str) -> float | None:
    """Sentiment dimension score, delegated to the sentiment service.

    Any failure in the sentiment service yields None rather than propagating.
    """
    from app.services.sentiment_service import compute_sentiment_dimension_score

    try:
        score = await compute_sentiment_dimension_score(db, symbol)
    except Exception:
        return None
    return score
|
||||
|
||||
|
||||
async def _compute_fundamental_score(db: AsyncSession, symbol: str) -> float | None:
    """Fundamental dimension: average of normalized metric scores.

    Components, each mapped to 0-100 and included only when present:
    - P/E ratio (positive only): lower is better; 15 maps to 100, 45 to 0.
    - Revenue growth: 0% -> 50, +20% -> 100, -20% -> 0.
    - Earnings surprise: 0% -> 50, +10% -> 100, -10% -> 0.
    Returns None when no fundamentals exist or no component is available.
    """
    from app.services.fundamental_service import get_fundamental

    fund = await get_fundamental(db, symbol)
    if fund is None:
        return None

    def clamp(v: float) -> float:
        return max(0.0, min(100.0, v))

    components: list[float] = []

    if fund.pe_ratio is not None and fund.pe_ratio > 0:
        components.append(clamp(100.0 - (fund.pe_ratio - 15.0) * (100.0 / 30.0)))

    if fund.revenue_growth is not None:
        components.append(clamp(50.0 + fund.revenue_growth * 2.5))

    if fund.earnings_surprise is not None:
        components.append(clamp(50.0 + fund.earnings_surprise * 5.0))

    return sum(components) / len(components) if components else None
|
||||
|
||||
|
||||
async def _compute_momentum_score(db: AsyncSession, symbol: str) -> float | None:
    """Momentum dimension: price rate-of-change over 5- and 20-day lookbacks.

    Each ROC maps -10% -> 0, 0% -> 50, +10% -> 100 (clamped) and carries
    weight 0.5; a lookback without enough history (or with a non-positive
    reference close) is dropped and the weights re-normalized. Returns
    None with fewer than 6 bars or when no lookback is usable.
    """
    from app.services.price_service import query_ohlcv

    records = await query_ohlcv(db, symbol)
    if not records or len(records) < 6:
        return None

    closes = [float(r.close) for r in records]
    latest = closes[-1]

    parts: list[tuple[float, float]] = []
    # Offsets 6 and 21 correspond to 5-day and 20-day ROC respectively.
    for offset in (6, 21):
        if len(closes) >= offset and closes[-offset] > 0:
            roc = (latest - closes[-offset]) / closes[-offset] * 100.0
            parts.append((0.5, max(0.0, min(100.0, 50.0 + roc * 5.0))))

    if not parts:
        return None

    denom = sum(w for w, _ in parts)
    if denom == 0:
        return None
    blended = sum(w * s for w, s in parts) / denom
    return max(0.0, min(100.0, blended))
|
||||
|
||||
|
||||
# Dispatch table: dimension name -> async computer(db, symbol) -> score | None.
# Keys must stay in sync with DIMENSIONS.
_DIMENSION_COMPUTERS = {
    "technical": _compute_technical_score,
    "sr_quality": _compute_sr_quality_score,
    "sentiment": _compute_sentiment_score,
    "fundamental": _compute_fundamental_score,
    "momentum": _compute_momentum_score,
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def compute_dimension_score(
    db: AsyncSession, symbol: str, dimension: str
) -> float | None:
    """Compute a single dimension score for a ticker.

    Returns the score (0-100) or None if insufficient data.
    Persists the result to the DimensionScore table.

    Args:
        db: Async database session.
        symbol: Ticker symbol (case-insensitive).
        dimension: One of DIMENSIONS.

    Raises:
        ValidationError: if *dimension* is not a known dimension.
        NotFoundError: if the ticker does not exist.

    Note: changes are added to the session but not committed here.
    """
    if dimension not in _DIMENSION_COMPUTERS:
        raise ValidationError(
            f"Unknown dimension: {dimension}. Valid: {', '.join(DIMENSIONS)}"
        )

    ticker = await _get_ticker(db, symbol)
    score_val = await _DIMENSION_COMPUTERS[dimension](db, symbol)

    now = datetime.now(timezone.utc)

    # Upsert dimension score
    result = await db.execute(
        select(DimensionScore).where(
            DimensionScore.ticker_id == ticker.id,
            DimensionScore.dimension == dimension,
        )
    )
    existing = result.scalar_one_or_none()

    # Clamp to the 0-100 contract before persisting.
    if score_val is not None:
        existing.score if False else None  # noqa: placeholder-free clamp below
        score_val = max(0.0, min(100.0, score_val))

    if existing is not None:
        if score_val is not None:
            existing.score = score_val
            existing.is_stale = False
            existing.computed_at = now
        else:
            # Can't compute — mark stale
            # (keeps the previous score visible but flagged).
            existing.is_stale = True
    elif score_val is not None:
        # No prior row and a score is available: create a fresh record.
        dim = DimensionScore(
            ticker_id=ticker.id,
            dimension=dimension,
            score=score_val,
            is_stale=False,
            computed_at=now,
        )
        db.add(dim)
    # If there is no prior row and no score, nothing is persisted.

    return score_val
|
||||
|
||||
|
||||
async def compute_all_dimensions(
    db: AsyncSession, symbol: str
) -> dict[str, float | None]:
    """Compute and persist every dimension score for *symbol*.

    Returns a mapping of dimension name to its score (or None when the
    dimension could not be computed), in DIMENSIONS order.
    """
    return {
        dim: await compute_dimension_score(db, symbol, dim) for dim in DIMENSIONS
    }
|
||||
|
||||
|
||||
async def compute_composite_score(
    db: AsyncSession,
    symbol: str,
    weights: dict[str, float] | None = None,
) -> tuple[float | None, list[str]]:
    """Compute composite score from available dimension scores.

    Returns (composite_score, missing_dimensions).
    Missing dimensions are excluded and weights re-normalized.

    Args:
        db: Async database session.
        symbol: Ticker symbol (case-insensitive).
        weights: Optional dimension -> weight map; loaded from the
            system settings when omitted.

    Raises:
        NotFoundError: if the ticker does not exist.

    Note: the CompositeScore row is added/updated on the session but not
    committed here.
    """
    ticker = await _get_ticker(db, symbol)

    if weights is None:
        weights = await _get_weights(db)

    # Get current dimension scores
    result = await db.execute(
        select(DimensionScore).where(DimensionScore.ticker_id == ticker.id)
    )
    dim_scores = {ds.dimension: ds for ds in result.scalars().all()}

    available: list[tuple[str, float, float]] = []  # (dim, weight, score)
    missing: list[str] = []

    for dim in DIMENSIONS:
        w = weights.get(dim, 0.0)
        # Zero/negative-weight dimensions are ignored entirely
        # (not even reported as missing).
        if w <= 0:
            continue
        ds = dim_scores.get(dim)
        if ds is not None and not ds.is_stale and ds.score is not None:
            available.append((dim, w, ds.score))
        else:
            missing.append(dim)

    if not available:
        return None, missing

    # Re-normalize weights
    total_weight = sum(w for _, w, _ in available)
    if total_weight == 0:
        return None, missing

    composite = sum(w * s for _, w, s in available) / total_weight
    composite = max(0.0, min(100.0, composite))

    # Persist composite score (the weights used are stored as JSON for audit)
    now = datetime.now(timezone.utc)
    comp_result = await db.execute(
        select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
    )
    existing = comp_result.scalar_one_or_none()

    if existing is not None:
        existing.score = composite
        existing.is_stale = False
        existing.weights_json = json.dumps(weights)
        existing.computed_at = now
    else:
        comp = CompositeScore(
            ticker_id=ticker.id,
            score=composite,
            is_stale=False,
            weights_json=json.dumps(weights),
            computed_at=now,
        )
        db.add(comp)

    return composite, missing
|
||||
|
||||
|
||||
async def get_score(
    db: AsyncSession, symbol: str
) -> dict:
    """Get composite + all dimension scores for a ticker.

    Recomputes stale dimensions on demand, then recomputes composite.
    Returns a dict suitable for ScoreResponse.

    Raises:
        NotFoundError: if *symbol* is not a tracked ticker (via _get_ticker).
    """
    ticker = await _get_ticker(db, symbol)
    weights = await _get_weights(db)

    # Check for stale dimension scores and recompute them
    result = await db.execute(
        select(DimensionScore).where(DimensionScore.ticker_id == ticker.id)
    )
    dim_scores = {ds.dimension: ds for ds in result.scalars().all()}

    for dim in DIMENSIONS:
        ds = dim_scores.get(dim)
        # A missing row is treated the same as a stale one: recompute.
        if ds is None or ds.is_stale:
            await compute_dimension_score(db, symbol, dim)

    # Check composite staleness (checked after the dimension pass so the
    # composite is rebuilt from the freshest dimension values)
    comp_result = await db.execute(
        select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
    )
    comp = comp_result.scalar_one_or_none()

    if comp is None or comp.is_stale:
        await compute_composite_score(db, symbol, weights)

    # Single commit persists every recomputation done above.
    await db.commit()

    # Re-fetch everything fresh so the response reflects committed state
    result = await db.execute(
        select(DimensionScore).where(DimensionScore.ticker_id == ticker.id)
    )
    dim_scores_list = list(result.scalars().all())

    comp_result = await db.execute(
        select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
    )
    comp = comp_result.scalar_one_or_none()

    dimensions = []
    missing = []
    for dim in DIMENSIONS:
        found = next((ds for ds in dim_scores_list if ds.dimension == dim), None)
        if found is not None:
            dimensions.append({
                "dimension": found.dimension,
                "score": found.score,
                "is_stale": found.is_stale,
                "computed_at": found.computed_at,
            })
        else:
            # Still absent after the recompute pass — report as missing.
            missing.append(dim)

    return {
        "symbol": ticker.symbol,
        "composite_score": comp.score if comp else None,
        "composite_stale": comp.is_stale if comp else False,
        "weights": weights,
        "dimensions": dimensions,
        "missing_dimensions": missing,
        "computed_at": comp.computed_at if comp else None,
    }
|
||||
|
||||
|
||||
async def get_rankings(db: AsyncSession) -> dict:
    """Get all tickers ranked by composite score descending.

    Returns dict suitable for RankingResponse.

    NOTE(review): this issues several queries per ticker (and a commit per
    stale ticker) — an N+1 pattern that may warrant batching if the ticker
    universe grows.
    """
    weights = await _get_weights(db)

    # Get all tickers
    result = await db.execute(select(Ticker).order_by(Ticker.symbol))
    tickers = list(result.scalars().all())

    rankings: list[dict] = []
    for ticker in tickers:
        # Get composite score
        comp_result = await db.execute(
            select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
        )
        comp = comp_result.scalar_one_or_none()

        # If no composite or stale, recompute
        if comp is None or comp.is_stale:
            # Recompute stale dimensions first
            dim_result = await db.execute(
                select(DimensionScore).where(
                    DimensionScore.ticker_id == ticker.id
                )
            )
            dim_scores = {ds.dimension: ds for ds in dim_result.scalars().all()}
            for dim in DIMENSIONS:
                ds = dim_scores.get(dim)
                if ds is None or ds.is_stale:
                    await compute_dimension_score(db, ticker.symbol, dim)

            await compute_composite_score(db, ticker.symbol, weights)

            # Commit per ticker so a later failure doesn't lose earlier work.
            await db.commit()

            # Re-fetch
            comp_result = await db.execute(
                select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
            )
            comp = comp_result.scalar_one_or_none()
            # Still no composite (e.g. no dimension data at all): skip ticker.
            if comp is None:
                continue

        dim_result = await db.execute(
            select(DimensionScore).where(
                DimensionScore.ticker_id == ticker.id
            )
        )
        dims = [
            {
                "dimension": ds.dimension,
                "score": ds.score,
                "is_stale": ds.is_stale,
                "computed_at": ds.computed_at,
            }
            for ds in dim_result.scalars().all()
        ]

        rankings.append({
            "symbol": ticker.symbol,
            "composite_score": comp.score,
            "dimensions": dims,
        })

    # Sort by composite score descending
    rankings.sort(key=lambda r: r["composite_score"], reverse=True)

    return {
        "rankings": rankings,
        "weights": weights,
    }
|
||||
|
||||
|
||||
async def update_weights(
    db: AsyncSession, weights: dict[str, float]
) -> dict[str, float]:
    """Update scoring weights and recompute all composite scores.

    Validates that all weights are non-negative and dimensions are valid.
    Returns the new, fully-populated weight mapping.

    Raises:
        ValidationError: on an unknown dimension name or a negative weight.
    """
    # Validate before touching persisted state.
    for dim, w in weights.items():
        if dim not in DIMENSIONS:
            raise ValidationError(
                f"Unknown dimension: {dim}. Valid: {', '.join(DIMENSIONS)}"
            )
        if w < 0:
            raise ValidationError(f"Weight for {dim} must be non-negative, got {w}")

    # Ensure all dimensions have a weight (default 0 for unspecified)
    full_weights = {dim: weights.get(dim, 0.0) for dim in DIMENSIONS}

    # Persist
    await _save_weights(db, full_weights)

    # Recompute all composite scores with the new weights
    result = await db.execute(select(Ticker))
    tickers = list(result.scalars().all())

    for ticker in tickers:
        await compute_composite_score(db, ticker.symbol, full_weights)

    # Single commit covers the saved weights and every recomputed score.
    await db.commit()
    return full_weights
|
||||
131
app/services/sentiment_service.py
Normal file
131
app/services/sentiment_service.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Sentiment service.
|
||||
|
||||
Stores sentiment records and computes the sentiment dimension score
|
||||
using a time-decay weighted average over a configurable lookback window.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError
|
||||
from app.models.sentiment import SentimentScore
|
||||
from app.models.ticker import Ticker
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Resolve *symbol* (case-insensitive) to a tracked Ticker row.

    Raises NotFoundError when the symbol is not tracked.
    """
    sym = symbol.strip().upper()
    lookup = await db.execute(select(Ticker).where(Ticker.symbol == sym))
    found = lookup.scalar_one_or_none()
    if found is None:
        raise NotFoundError(f"Ticker not found: {sym}")
    return found
|
||||
|
||||
|
||||
async def store_sentiment(
    db: AsyncSession,
    symbol: str,
    classification: str,
    confidence: int,
    source: str,
    timestamp: datetime | None = None,
) -> SentimentScore:
    """Persist one sentiment observation for *symbol* and return the row.

    When *timestamp* is omitted the current UTC time is recorded.
    Raises NotFoundError if the symbol is not tracked.
    """
    ticker = await _get_ticker(db, symbol)

    when = timestamp if timestamp is not None else datetime.now(timezone.utc)

    entry = SentimentScore(
        ticker_id=ticker.id,
        classification=classification,
        confidence=confidence,
        source=source,
        timestamp=when,
    )
    db.add(entry)
    await db.commit()
    await db.refresh(entry)
    return entry
|
||||
|
||||
|
||||
async def get_sentiment_scores(
    db: AsyncSession,
    symbol: str,
    lookback_hours: float = 24,
) -> list[SentimentScore]:
    """Return sentiment rows for *symbol* inside the lookback window, newest first."""
    ticker = await _get_ticker(db, symbol)
    cutoff = datetime.now(timezone.utc) - timedelta(hours=lookback_hours)

    stmt = (
        select(SentimentScore)
        .where(
            SentimentScore.ticker_id == ticker.id,
            SentimentScore.timestamp >= cutoff,
        )
        .order_by(SentimentScore.timestamp.desc())
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
def _classification_to_base_score(classification: str, confidence: int) -> float:
|
||||
"""Map classification + confidence to a base score (0-100).
|
||||
|
||||
bullish → confidence (high confidence = high score)
|
||||
bearish → 100 - confidence (high confidence bearish = low score)
|
||||
neutral → 50
|
||||
"""
|
||||
cl = classification.lower()
|
||||
if cl == "bullish":
|
||||
return float(confidence)
|
||||
elif cl == "bearish":
|
||||
return float(100 - confidence)
|
||||
else:
|
||||
return 50.0
|
||||
|
||||
|
||||
async def compute_sentiment_dimension_score(
    db: AsyncSession,
    symbol: str,
    lookback_hours: float = 24,
    decay_rate: float = 0.1,
) -> float | None:
    """Compute the sentiment dimension score using time-decay weighted average.

    Returns a score in [0, 100] or None if no scores exist in the window.

    Algorithm:
    1. For each score in the lookback window, compute base_score from
       classification + confidence.
    2. Apply time decay: weight = exp(-decay_rate * hours_since_score).
    3. Weighted average: sum(base_score * weight) / sum(weight).
    """
    scores = await get_sentiment_scores(db, symbol, lookback_hours)
    if not scores:
        return None

    now = datetime.now(timezone.utc)
    weighted_sum = 0.0
    weight_total = 0.0

    for score in scores:
        ts = score.timestamp
        # Some backends return naive datetimes; assume they are UTC so the
        # aware subtraction below cannot raise — TODO confirm storage is UTC.
        if ts.tzinfo is None:
            ts = ts.replace(tzinfo=timezone.utc)
        hours_since = (now - ts).total_seconds() / 3600.0
        # Exponential decay: the freshest records dominate the average.
        weight = math.exp(-decay_rate * hours_since)
        base = _classification_to_base_score(score.classification, score.confidence)
        weighted_sum += base * weight
        weight_total += weight

    # Defensive guard: exp() is strictly positive, so with a non-empty list
    # this should not trigger, but avoid any possible ZeroDivisionError.
    if weight_total == 0:
        return None

    result = weighted_sum / weight_total
    # Clamp as a safety net; base scores are already within [0, 100].
    return max(0.0, min(100.0, result))
|
||||
274
app/services/sr_service.py
Normal file
274
app/services/sr_service.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""S/R Detector service.
|
||||
|
||||
Detects support/resistance levels from Volume Profile (HVN/LVN) and
|
||||
Pivot Points (swing highs/lows), assigns strength scores, merges nearby
|
||||
levels, tags as support/resistance, and persists to DB.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError, ValidationError
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.ticker import Ticker
|
||||
from app.services.indicator_service import (
|
||||
_extract_ohlcv,
|
||||
compute_pivot_points,
|
||||
compute_volume_profile,
|
||||
)
|
||||
from app.services.price_service import query_ohlcv
|
||||
|
||||
DEFAULT_TOLERANCE = 0.005 # 0.5%
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Fetch the Ticker row for *symbol* (trimmed, upper-cased).

    Raises NotFoundError when the symbol is not tracked.
    """
    wanted = symbol.strip().upper()
    query = await db.execute(select(Ticker).where(Ticker.symbol == wanted))
    row = query.scalar_one_or_none()
    if row is None:
        raise NotFoundError(f"Ticker not found: {wanted}")
    return row
|
||||
|
||||
|
||||
def _count_price_touches(
    price_level: float,
    highs: list[float],
    lows: list[float],
    closes: list[float],
    tolerance: float = DEFAULT_TOLERANCE,
) -> int:
    """Count how many bars touched/respected a price level within tolerance.

    A bar "touches" the level when the level lies inside the bar's
    [low - tol, high + tol] band, where tol is relative to the level
    (falls back to the absolute ``tolerance`` when the level is 0).

    The three series are assumed to be parallel and equal-length; zip()
    stops at the shortest, so a length mismatch can no longer raise
    IndexError (the original index loop over ``len(closes)`` could when
    ``closes`` was longer than ``highs``/``lows``).
    """
    tol = price_level * tolerance if price_level != 0 else tolerance
    return sum(
        1
        for high, low, _close in zip(highs, lows, closes)
        if low - tol <= price_level <= high + tol
    )
|
||||
|
||||
|
||||
def _strength_from_touches(touches: int, total_bars: int) -> int:
|
||||
"""Convert touch count to a 0-100 strength score.
|
||||
|
||||
More touches relative to total bars = higher strength.
|
||||
Cap at 100.
|
||||
"""
|
||||
if total_bars == 0:
|
||||
return 0
|
||||
# Scale: each touch contributes proportionally, with a multiplier
|
||||
# so that a level touched ~20% of bars gets score ~100
|
||||
raw = (touches / total_bars) * 500.0
|
||||
return max(0, min(100, int(round(raw))))
|
||||
|
||||
|
||||
def _extract_candidate_levels(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    volumes: list[int],
) -> list[tuple[float, str]]:
    """Collect candidate S/R price levels from both detectors.

    Returns (price_level, detection_method) tuples, where the method is
    "volume_profile" for HVN/LVN nodes and "pivot_point" for swing
    highs/lows. A detector that raises ValidationError (insufficient
    data) simply contributes nothing.
    """
    found: list[tuple[float, str]] = []

    # Volume Profile: both high- and low-volume nodes are candidates.
    try:
        profile = compute_volume_profile(highs, lows, closes, volumes)
        for key in ("hvn", "lvn"):
            found.extend((price, "volume_profile") for price in profile.get(key, []))
    except ValidationError:
        pass  # not enough data for a volume profile

    # Pivot Points: swing highs and lows.
    try:
        pivots = compute_pivot_points(highs, lows, closes)
        for key in ("swing_highs", "swing_lows"):
            found.extend((price, "pivot_point") for price in pivots.get(key, []))
    except ValidationError:
        pass  # not enough data for pivot points

    return found
|
||||
|
||||
|
||||
def _merge_levels(
    levels: list[dict],
    tolerance: float = DEFAULT_TOLERANCE,
) -> list[dict]:
    """Consolidate price levels that sit within tolerance of each other.

    Levels are walked in ascending price order. Each one either folds into
    the most recently emitted level — averaging the two prices (rounded to
    4 dp), summing strengths capped at 100, and marking the method
    "merged" when the two methods differ — or starts a new consolidated
    level. Tolerance is relative to the emitted level's current price
    (absolute when that price is 0), so merges can chain.
    """
    if not levels:
        return []

    out: list[dict] = []
    for lvl in sorted(levels, key=lambda item: item["price_level"]):
        if out:
            head = out[-1]
            anchor = head["price_level"]
            tol = anchor * tolerance if anchor != 0 else tolerance
            if abs(lvl["price_level"] - anchor) <= tol:
                # Fold this level into the previous consolidated one.
                head["price_level"] = round((anchor + lvl["price_level"]) / 2.0, 4)
                head["strength"] = min(100, head["strength"] + lvl["strength"])
                if head["detection_method"] != lvl["detection_method"]:
                    head["detection_method"] = "merged"
                continue
        out.append(dict(lvl))

    return out
|
||||
|
||||
|
||||
def _tag_levels(
|
||||
levels: list[dict],
|
||||
current_price: float,
|
||||
) -> list[dict]:
|
||||
"""Tag each level as 'support' or 'resistance' relative to current price."""
|
||||
for level in levels:
|
||||
if level["price_level"] < current_price:
|
||||
level["type"] = "support"
|
||||
else:
|
||||
level["type"] = "resistance"
|
||||
return levels
|
||||
|
||||
|
||||
def detect_sr_levels(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    volumes: list[int],
    tolerance: float = DEFAULT_TOLERANCE,
) -> list[dict]:
    """Run the full S/R pipeline: detect, score, merge, tag, rank.

    Returns dicts with keys price_level, type, strength, detection_method,
    sorted by strength descending. Empty OHLCV input yields an empty list.
    """
    if not closes:
        return []

    candidates = _extract_candidate_levels(highs, lows, closes, volumes)
    if not candidates:
        return []

    bar_count = len(closes)
    last_close = closes[-1]

    # Score every candidate by how often price respected it.
    scored: list[dict] = [
        {
            "price_level": price,
            "strength": _strength_from_touches(
                _count_price_touches(price, highs, lows, closes, tolerance),
                bar_count,
            ),
            "detection_method": method,
            "type": "",  # assigned by _tag_levels after merging
        }
        for price, method in candidates
    ]

    # Consolidate near-duplicates, then label relative to the last close.
    tagged = _tag_levels(_merge_levels(scored, tolerance), last_close)

    tagged.sort(key=lambda lvl: lvl["strength"], reverse=True)
    return tagged
|
||||
|
||||
|
||||
async def recalculate_sr_levels(
    db: AsyncSession,
    symbol: str,
    tolerance: float = DEFAULT_TOLERANCE,
) -> list[SRLevel]:
    """Recalculate S/R levels for a ticker and persist to DB.

    1. Fetch OHLCV data
    2. Detect levels
    3. Delete old levels for ticker
    4. Insert new levels
    5. Return new levels sorted by strength desc

    Raises NotFoundError (via _get_ticker) for an untracked symbol.
    """
    # Function-scope import: this module only imports `datetime` at the top.
    from datetime import timezone

    ticker = await _get_ticker(db, symbol)

    records = await query_ohlcv(db, symbol)
    if not records:
        # No OHLCV data — clear any existing levels and report none.
        await db.execute(
            delete(SRLevel).where(SRLevel.ticker_id == ticker.id)
        )
        await db.commit()
        return []

    _, highs, lows, closes, volumes = _extract_ohlcv(records)

    levels = detect_sr_levels(highs, lows, closes, volumes, tolerance)

    # Delete old levels (full replace)
    await db.execute(
        delete(SRLevel).where(SRLevel.ticker_id == ticker.id)
    )

    # Insert new levels. Use a timezone-aware UTC timestamp for consistency
    # with the other services; datetime.utcnow() is naive and deprecated.
    now = datetime.now(timezone.utc)
    new_models: list[SRLevel] = []
    for lvl in levels:
        model = SRLevel(
            ticker_id=ticker.id,
            price_level=lvl["price_level"],
            type=lvl["type"],
            strength=lvl["strength"],
            detection_method=lvl["detection_method"],
            created_at=now,
        )
        db.add(model)
        new_models.append(model)

    await db.commit()

    # Refresh to populate DB-generated fields (ids)
    for m in new_models:
        await db.refresh(m)

    return new_models
|
||||
|
||||
|
||||
async def get_sr_levels(
    db: AsyncSession,
    symbol: str,
    tolerance: float = DEFAULT_TOLERANCE,
) -> list[SRLevel]:
    """Return S/R levels for *symbol*, recalculating on every request (MVP).

    Delegates to recalculate_sr_levels, so the result is always freshly
    computed and sorted by strength descending.
    """
    levels = await recalculate_sr_levels(db, symbol, tolerance)
    return levels
|
||||
57
app/services/ticker_service.py
Normal file
57
app/services/ticker_service.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Ticker Registry service: add, delete, and list tracked tickers."""
|
||||
|
||||
import re
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import DuplicateError, NotFoundError, ValidationError
|
||||
from app.models.ticker import Ticker
|
||||
|
||||
|
||||
async def add_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Validate and register a new ticker symbol.

    Input is trimmed and auto-uppercased; it must be non-empty and purely
    alphanumeric. Raises ValidationError on invalid input and
    DuplicateError when the symbol is already tracked.
    """
    candidate = symbol.strip()
    if not candidate:
        raise ValidationError("Ticker symbol must not be empty or whitespace-only")

    candidate = candidate.upper()
    if re.fullmatch(r"[A-Z0-9]+", candidate) is None:
        raise ValidationError(
            f"Ticker symbol must be alphanumeric: {candidate}"
        )

    existing = await db.execute(select(Ticker).where(Ticker.symbol == candidate))
    if existing.scalar_one_or_none() is not None:
        raise DuplicateError(f"Ticker already exists: {candidate}")

    created = Ticker(symbol=candidate)
    db.add(created)
    await db.commit()
    await db.refresh(created)
    return created
|
||||
|
||||
|
||||
async def delete_ticker(db: AsyncSession, symbol: str) -> None:
    """Delete a ticker and cascade all associated data.

    Raises NotFoundError if the symbol is not tracked.
    """
    sym = symbol.strip().upper()
    lookup = await db.execute(select(Ticker).where(Ticker.symbol == sym))
    found = lookup.scalar_one_or_none()
    if found is None:
        raise NotFoundError(f"Ticker not found: {sym}")

    await db.delete(found)
    await db.commit()
|
||||
|
||||
|
||||
async def list_tickers(db: AsyncSession) -> list[Ticker]:
    """Return every tracked ticker, ordered alphabetically by symbol."""
    rows = await db.execute(select(Ticker).order_by(Ticker.symbol.asc()))
    return list(rows.scalars())
|
||||
288
app/services/watchlist_service.py
Normal file
288
app/services/watchlist_service.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""Watchlist service.
|
||||
|
||||
Auto-populates top-X tickers by composite score (default 10), supports
|
||||
manual add/remove (tagged, not subject to auto-population), enforces
|
||||
cap (auto + 10 manual, default max 20), and updates auto entries on
|
||||
score recomputation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import DuplicateError, NotFoundError, ValidationError
|
||||
from app.models.score import CompositeScore, DimensionScore
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.ticker import Ticker
|
||||
from app.models.trade_setup import TradeSetup
|
||||
from app.models.watchlist import WatchlistEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_AUTO_SIZE = 10
|
||||
MAX_MANUAL = 10
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Look up the tracked Ticker row for *symbol* (trimmed, upper-cased).

    Raises NotFoundError when the symbol is not tracked.
    """
    key = symbol.strip().upper()
    rows = await db.execute(select(Ticker).where(Ticker.symbol == key))
    hit = rows.scalar_one_or_none()
    if hit is None:
        raise NotFoundError(f"Ticker not found: {key}")
    return hit
|
||||
|
||||
|
||||
async def auto_populate(
    db: AsyncSession,
    user_id: int,
    top_x: int = DEFAULT_AUTO_SIZE,
) -> None:
    """Auto-populate watchlist with top-X tickers by composite score.

    Replaces existing auto entries. Manual entries are untouched.

    NOTE: this flushes but does not commit — the caller owns the
    transaction (see get_watchlist).
    """
    # Get top-X tickers by composite score (non-stale, descending)
    stmt = (
        select(CompositeScore)
        .where(CompositeScore.is_stale == False)  # noqa: E712
        .order_by(CompositeScore.score.desc())
        .limit(top_x)
    )
    result = await db.execute(stmt)
    top_scores = list(result.scalars().all())
    top_ticker_ids = {cs.ticker_id for cs in top_scores}

    # Delete existing auto entries for this user (full replace)
    await db.execute(
        delete(WatchlistEntry).where(
            WatchlistEntry.user_id == user_id,
            WatchlistEntry.entry_type == "auto",
        )
    )

    # Get manual ticker_ids so we don't duplicate a ticker the user
    # already pinned manually
    manual_result = await db.execute(
        select(WatchlistEntry.ticker_id).where(
            WatchlistEntry.user_id == user_id,
            WatchlistEntry.entry_type == "manual",
        )
    )
    manual_ticker_ids = {row[0] for row in manual_result.all()}

    now = datetime.now(timezone.utc)
    for ticker_id in top_ticker_ids:
        if ticker_id in manual_ticker_ids:
            continue  # Already on watchlist as manual
        entry = WatchlistEntry(
            user_id=user_id,
            ticker_id=ticker_id,
            entry_type="auto",
            added_at=now,
        )
        db.add(entry)

    # Flush so the new rows are visible to subsequent queries in the same
    # transaction; committing is left to the caller.
    await db.flush()
|
||||
|
||||
|
||||
async def add_manual_entry(
    db: AsyncSession,
    user_id: int,
    symbol: str,
) -> WatchlistEntry:
    """Add a manual watchlist entry.

    Raises NotFoundError if the symbol is not tracked (via _get_ticker).
    Raises DuplicateError if already on watchlist.
    Raises ValidationError if the manual or total cap is exceeded.

    NOTE(review): check-then-insert is not atomic; concurrent requests
    could exceed a cap — confirm whether a DB constraint is needed.
    """
    ticker = await _get_ticker(db, symbol)

    # Check if already on watchlist (either entry type counts)
    existing = await db.execute(
        select(WatchlistEntry).where(
            WatchlistEntry.user_id == user_id,
            WatchlistEntry.ticker_id == ticker.id,
        )
    )
    if existing.scalar_one_or_none() is not None:
        raise DuplicateError(f"Ticker already on watchlist: {ticker.symbol}")

    # Count current manual entries against the manual cap
    count_result = await db.execute(
        select(func.count()).select_from(WatchlistEntry).where(
            WatchlistEntry.user_id == user_id,
            WatchlistEntry.entry_type == "manual",
        )
    )
    manual_count = count_result.scalar() or 0

    if manual_count >= MAX_MANUAL:
        raise ValidationError(
            f"Manual watchlist cap reached ({MAX_MANUAL}). "
            "Remove an entry before adding a new one."
        )

    # Check total cap (auto + manual combined)
    total_result = await db.execute(
        select(func.count()).select_from(WatchlistEntry).where(
            WatchlistEntry.user_id == user_id,
        )
    )
    total_count = total_result.scalar() or 0
    max_total = DEFAULT_AUTO_SIZE + MAX_MANUAL

    if total_count >= max_total:
        raise ValidationError(
            f"Watchlist cap reached ({max_total}). "
            "Remove an entry before adding a new one."
        )

    entry = WatchlistEntry(
        user_id=user_id,
        ticker_id=ticker.id,
        entry_type="manual",
        added_at=datetime.now(timezone.utc),
    )
    db.add(entry)
    await db.commit()
    await db.refresh(entry)
    return entry
|
||||
|
||||
|
||||
async def remove_entry(
    db: AsyncSession,
    user_id: int,
    symbol: str,
) -> None:
    """Delete a user's watchlist entry for *symbol* (manual or auto).

    Raises NotFoundError if the symbol is untracked or not on the
    user's watchlist.
    """
    ticker = await _get_ticker(db, symbol)

    lookup = await db.execute(
        select(WatchlistEntry).where(
            WatchlistEntry.user_id == user_id,
            WatchlistEntry.ticker_id == ticker.id,
        )
    )
    found = lookup.scalar_one_or_none()
    if found is None:
        raise NotFoundError(f"Ticker not on watchlist: {ticker.symbol}")

    await db.delete(found)
    await db.commit()
|
||||
|
||||
|
||||
async def _enrich_entry(
    db: AsyncSession,
    entry: WatchlistEntry,
    symbol: str,
) -> dict:
    """Build enriched watchlist entry dict with scores, R:R, and SR levels.

    *symbol* is passed in by the caller (joined from Ticker) to avoid an
    extra lookup here.
    """
    ticker_id = entry.ticker_id

    # Composite score (may be absent if never computed)
    comp_result = await db.execute(
        select(CompositeScore).where(CompositeScore.ticker_id == ticker_id)
    )
    comp = comp_result.scalar_one_or_none()

    # Dimension scores
    dim_result = await db.execute(
        select(DimensionScore).where(DimensionScore.ticker_id == ticker_id)
    )
    dims = [
        {"dimension": ds.dimension, "score": ds.score}
        for ds in dim_result.scalars().all()
    ]

    # Best trade setup (highest R:R) for this ticker
    setup_result = await db.execute(
        select(TradeSetup)
        .where(TradeSetup.ticker_id == ticker_id)
        .order_by(TradeSetup.rr_ratio.desc())
        .limit(1)
    )
    setup = setup_result.scalar_one_or_none()

    # Active SR levels, strongest first
    sr_result = await db.execute(
        select(SRLevel)
        .where(SRLevel.ticker_id == ticker_id)
        .order_by(SRLevel.strength.desc())
    )
    sr_levels = [
        {
            "price_level": lv.price_level,
            "type": lv.type,
            "strength": lv.strength,
        }
        for lv in sr_result.scalars().all()
    ]

    # Missing composite/setup data is reported as None rather than omitted.
    return {
        "symbol": symbol,
        "entry_type": entry.entry_type,
        "composite_score": comp.score if comp else None,
        "dimensions": dims,
        "rr_ratio": setup.rr_ratio if setup else None,
        "rr_direction": setup.direction if setup else None,
        "sr_levels": sr_levels,
        "added_at": entry.added_at,
    }
|
||||
|
||||
|
||||
async def get_watchlist(
    db: AsyncSession,
    user_id: int,
    sort_by: str = "composite",
) -> list[dict]:
    """Get user's watchlist with enriched data.

    Runs auto_populate first to ensure auto entries are current,
    then enriches each entry with scores, R:R, and SR levels.

    sort_by: "composite", "rr", or a dimension name
    (e.g. "technical", "sr_quality", "sentiment", "fundamental", "momentum").
    Entries missing the chosen metric sort last (sentinel -1).
    """
    # Auto-populate to refresh auto entries; auto_populate only flushes,
    # so commit here to finish that transaction.
    await auto_populate(db, user_id)
    await db.commit()

    # Fetch all entries with ticker symbol in one joined query
    stmt = (
        select(WatchlistEntry, Ticker.symbol)
        .join(Ticker, WatchlistEntry.ticker_id == Ticker.id)
        .where(WatchlistEntry.user_id == user_id)
    )
    result = await db.execute(stmt)
    rows = result.all()

    entries: list[dict] = []
    for entry, symbol in rows:
        enriched = await _enrich_entry(db, entry, symbol)
        entries.append(enriched)

    # Sort descending by the requested metric; None maps to -1 so that
    # entries without data always land at the bottom.
    if sort_by == "composite":
        entries.sort(
            key=lambda e: e["composite_score"] if e["composite_score"] is not None else -1,
            reverse=True,
        )
    elif sort_by == "rr":
        entries.sort(
            key=lambda e: e["rr_ratio"] if e["rr_ratio"] is not None else -1,
            reverse=True,
        )
    else:
        # Sort by a specific dimension score; unknown dimensions behave
        # like missing data (every key is -1, order effectively unchanged).
        def _dim_sort_key(e: dict) -> float:
            for d in e["dimensions"]:
                if d["dimension"] == sort_by:
                    return d["score"]
            return -1.0

        entries.sort(key=_dim_sort_key, reverse=True)

    return entries
|
||||
Reference in New Issue
Block a user