first commit
This commit is contained in:
1
app/__init__.py
Normal file
1
app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
86
app/cache.py
Normal file
86
app/cache.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""LRU cache wrapper with per-ticker invalidation.
|
||||
|
||||
Provides an in-memory cache (max 1000 entries) keyed on
|
||||
(ticker, start_date, end_date, indicator_type). Supports selective
|
||||
invalidation of all entries for a given ticker — needed when new
|
||||
OHLCV data is ingested.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Hashable
|
||||
|
||||
CacheKey = tuple[str, Any, Any, str] # (ticker, start_date, end_date, indicator_type)
|
||||
|
||||
_DEFAULT_MAX_SIZE = 1000
|
||||
|
||||
|
||||
class LRUCache:
    """Least-recently-used cache built on an ``OrderedDict``.

    The dict is kept in recency order: the front holds the
    least-recently-used entry, the back the most-recently-used one.

    Parameters
    ----------
    max_size:
        Capacity limit; inserting beyond it evicts the LRU entry.
        Defaults to 1000.
    """

    def __init__(self, max_size: int = 1000) -> None:
        self._max_size = max_size
        self._store: OrderedDict[Hashable, Any] = OrderedDict()

    # -- public API -----------------------------------------------------

    def get(self, key: CacheKey) -> Any | None:
        """Return the cached value for *key*, or ``None`` on a miss.

        A hit promotes the entry to most-recently-used.
        """
        try:
            value = self._store[key]
        except KeyError:
            return None
        self._store.move_to_end(key)
        return value

    def set(self, key: CacheKey, value: Any) -> None:
        """Insert or overwrite *key* with *value*.

        A brand-new key that would exceed capacity first evicts the
        least-recently-used entry; either way the key ends up
        most-recently-used.
        """
        if key not in self._store and len(self._store) >= self._max_size:
            self._store.popitem(last=False)  # drop the LRU entry
        self._store[key] = value
        self._store.move_to_end(key)

    def invalidate_ticker(self, ticker: str) -> int:
        """Drop every entry whose key's first element equals *ticker*.

        Returns the number of entries removed. Needed when fresh OHLCV
        data for a ticker makes its cached indicators stale.
        """
        doomed = [key for key in self._store if key[0] == ticker]
        for key in doomed:
            del self._store[key]
        return len(doomed)

    def clear(self) -> None:
        """Empty the cache entirely."""
        self._store.clear()

    @property
    def size(self) -> int:
        """Number of entries currently held."""
        return len(self._store)

    @property
    def max_size(self) -> int:
        """Capacity limit."""
        return self._max_size
|
||||
|
||||
|
||||
# Module-level singleton used by the indicator service.
# Shared process-wide; holds up to the default 1000 entries.
indicator_cache = LRUCache()
|
||||
43
app/config.py
Normal file
43
app/config.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application configuration, loaded from the environment / .env file."""

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    # Database (asyncpg driver).
    database_url: str = "postgresql+asyncpg://stock_backend:changeme@localhost:5432/stock_data_backend"

    # Auth — NOTE(review): the jwt_secret default is a placeholder, not a
    # secret; it must be overridden in every real deployment.
    jwt_secret: str = "change-this-to-a-random-secret"
    jwt_expiry_minutes: int = 60

    # OHLCV Provider — Alpaca Markets
    alpaca_api_key: str = ""
    alpaca_api_secret: str = ""

    # Sentiment Provider — Gemini with Search Grounding
    gemini_api_key: str = ""
    gemini_model: str = "gemini-2.0-flash"

    # Fundamentals Provider — Financial Modeling Prep
    fmp_api_key: str = ""

    # Scheduled Jobs
    data_collector_frequency: str = "daily"
    sentiment_poll_interval_minutes: int = 30
    fundamental_fetch_frequency: str = "daily"
    rr_scan_frequency: str = "daily"

    # Scoring Defaults
    default_watchlist_auto_size: int = 10
    default_rr_threshold: float = 3.0

    # Database Pool
    db_pool_size: int = 5
    db_pool_timeout: int = 30

    # Logging
    log_level: str = "INFO"


# Singleton instance imported throughout the app.
settings = Settings()
|
||||
33
app/database.py
Normal file
33
app/database.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config import settings
|
||||
|
||||
# Shared async engine; pool sizing comes from configuration.
engine = create_async_engine(
    settings.database_url,
    pool_size=settings.db_pool_size,
    pool_timeout=settings.db_pool_timeout,
    pool_pre_ping=True,  # validate pooled connections before handing them out
    echo=False,
)

# Session factory; expire_on_commit=False keeps ORM objects usable after
# commit (important across await boundaries in async handlers).
async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base class all ORM models inherit from."""

    pass
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session from the shared factory; the context manager
    closes it when the consumer finishes."""
    async with async_session_factory() as session:
        yield session
|
||||
82
app/dependencies.py
Normal file
82
app/dependencies.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""FastAPI dependency injection factories.
|
||||
|
||||
Provides DB session, current user extraction from JWT, and role/access guards.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_session
|
||||
from app.exceptions import AuthenticationError, AuthorizationError
|
||||
from app.models.user import User
|
||||
|
||||
logger = logging.getLogger(__name__)

# auto_error=False: missing credentials reach get_current_user as None so
# we can raise our own AuthenticationError instead of FastAPI's default.
_bearer_scheme = HTTPBearer(auto_error=False)

# Symmetric signing algorithm used for all issued tokens.
JWT_ALGORITHM = "HS256"
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an async DB session (thin wrapper over ``app.database.get_session``)."""
    async for session in get_session():
        yield session
|
||||
|
||||
|
||||
async def get_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(_bearer_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Extract and validate the bearer JWT, returning the matching User.

    Raises
    ------
    AuthenticationError
        When the Authorization header is missing, the token is invalid or
        expired, the ``sub`` claim is absent or not an integer string, or
        no user with that id exists.
    """
    # Function-scope import keeps the module's import block unchanged.
    # ExpiredSignatureError is a JWTError subclass, so it must be caught
    # before the generic JWTError handler. (Previously expiry was detected
    # by substring-matching the exception message, which is fragile.)
    from jose.exceptions import ExpiredSignatureError

    if credentials is None:
        raise AuthenticationError("Authentication required")

    token = credentials.credentials
    try:
        payload = jwt.decode(
            token,
            settings.jwt_secret,
            algorithms=[JWT_ALGORITHM],
        )
        user_id_str: str | None = payload.get("sub")
        if user_id_str is None:
            raise AuthenticationError("Invalid token: missing subject")
        user_id = int(user_id_str)
    except ExpiredSignatureError as exc:
        raise AuthenticationError("Token expired") from exc
    except JWTError as exc:
        raise AuthenticationError("Invalid token") from exc
    except (ValueError, TypeError) as exc:
        # ``sub`` claim present but not parseable as an integer id.
        raise AuthenticationError("Invalid token: bad subject") from exc

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        raise AuthenticationError("User not found")

    return user
|
||||
|
||||
|
||||
async def require_admin(
    user: User = Depends(get_current_user),
) -> User:
    """Dependency guard: pass through only users with the admin role."""
    if user.role == "admin":
        return user
    raise AuthorizationError("Insufficient permissions")
|
||||
|
||||
|
||||
async def require_access(
    user: User = Depends(get_current_user),
) -> User:
    """Dependency guard: pass through only users whose access flag is set."""
    if user.has_access:
        return user
    raise AuthorizationError("Insufficient permissions")
|
||||
52
app/exceptions.py
Normal file
52
app/exceptions.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Application exception hierarchy.
|
||||
|
||||
All custom exceptions inherit from AppError. The global exception handler
|
||||
in middleware.py catches these and returns the appropriate JSON envelope.
|
||||
"""
|
||||
|
||||
|
||||
class AppError(Exception):
    """Base application error.

    Subclasses override ``status_code`` and ``message``; the global
    exception handler maps both onto the JSON error envelope.
    """

    status_code: int = 500
    message: str = "Internal server error"

    def __init__(self, message: str | None = None):
        # Only shadow the class-level default when explicit text is given,
        # so subclasses keep their declared message otherwise.
        if message is None:
            super().__init__(self.message)
        else:
            self.message = message
            super().__init__(message)
|
||||
|
||||
|
||||
class ValidationError(AppError):
    """Request payload or parameters failed validation (HTTP 400)."""

    status_code = 400
    message = "Validation error"


class NotFoundError(AppError):
    """Requested resource does not exist (HTTP 404)."""

    status_code = 404
    message = "Resource not found"


class DuplicateError(AppError):
    """A resource with the same identity already exists (HTTP 409)."""

    status_code = 409
    message = "Resource already exists"


class AuthenticationError(AppError):
    """Missing or invalid credentials (HTTP 401)."""

    status_code = 401
    message = "Authentication required"


class AuthorizationError(AppError):
    """Authenticated but not allowed to perform the action (HTTP 403)."""

    status_code = 403
    message = "Insufficient permissions"


class ProviderError(AppError):
    """Upstream market-data provider failed or is unreachable (HTTP 502)."""

    status_code = 502
    message = "Market data provider unavailable"


class RateLimitError(AppError):
    """Caller or upstream provider hit a rate limit (HTTP 429)."""

    status_code = 429
    message = "Rate limited"
|
||||
106
app/main.py
Normal file
106
app/main.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""FastAPI application entry point with lifespan management."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI
|
||||
from passlib.hash import bcrypt
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import async_session_factory, engine
|
||||
from app.middleware import register_exception_handlers
|
||||
from app.models.user import User
|
||||
from app.scheduler import configure_scheduler, scheduler
|
||||
from app.routers.admin import router as admin_router
|
||||
from app.routers.auth import router as auth_router
|
||||
from app.routers.health import router as health_router
|
||||
from app.routers.ingestion import router as ingestion_router
|
||||
from app.routers.ohlcv import router as ohlcv_router
|
||||
from app.routers.indicators import router as indicators_router
|
||||
from app.routers.fundamentals import router as fundamentals_router
|
||||
from app.routers.scores import router as scores_router
|
||||
from app.routers.trades import router as trades_router
|
||||
from app.routers.watchlist import router as watchlist_router
|
||||
from app.routers.sentiment import router as sentiment_router
|
||||
from app.routers.sr_levels import router as sr_levels_router
|
||||
from app.routers.tickers import router as tickers_router
|
||||
|
||||
|
||||
def _configure_logging() -> None:
    """Route root logging to stdout as single-line JSON-style records.

    Clears any pre-existing root handlers first so startup does not
    produce duplicate output. Level comes from ``settings.log_level``.

    NOTE(review): %(message)s is not JSON-escaped, so a quote or newline
    in a log message breaks the record's JSON shape — confirm acceptable.
    """
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter(
            '{"time":"%(asctime)s","level":"%(levelname)s",'
            '"logger":"%(name)s","message":"%(message)s"}'
        )
    )
    root = logging.getLogger()
    root.handlers.clear()  # drop handlers installed before startup
    root.addHandler(handler)
    root.setLevel(settings.log_level.upper())
|
||||
|
||||
|
||||
async def _create_default_admin(session: AsyncSession) -> None:
    """Create the default admin account if no admin user exists.

    Runs once at startup. ``.limit(1)`` is required: without it,
    ``scalar_one_or_none()`` raises ``MultipleResultsFound`` as soon as a
    second admin account exists, crashing every subsequent startup.

    NOTE(review): bootstrap credentials are admin/admin — change them
    immediately in any real deployment (or source from configuration).
    """
    result = await session.execute(
        select(User).where(User.role == "admin").limit(1)
    )
    if result.scalar_one_or_none() is None:
        admin = User(
            username="admin",
            password_hash=bcrypt.hash("admin"),
            role="admin",
            has_access=True,
        )
        session.add(admin)
        await session.commit()
        logging.getLogger(__name__).info("Default admin account created")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
    """Manage startup and shutdown lifecycle.

    Everything before ``yield`` runs at startup (logging, admin
    bootstrap, scheduler), everything after runs at shutdown, in order.
    """
    logger = logging.getLogger(__name__)
    _configure_logging()
    logger.info("Starting Stock Data Backend")

    # Seed the default admin account before serving requests.
    async with async_session_factory() as session:
        await _create_default_admin(session)

    configure_scheduler()
    scheduler.start()
    logger.info("Scheduler started")

    yield  # application serves requests here

    # Shutdown: stop background jobs first, then release DB connections.
    scheduler.shutdown(wait=False)
    logger.info("Scheduler stopped")
    await engine.dispose()
    logger.info("Shutting down")
|
||||
|
||||
|
||||
# Application instance; lifespan handles startup/shutdown above.
app = FastAPI(
    title="Stock Data Backend",
    version="0.1.0",
    lifespan=lifespan,
)

# Install the AppError / validation / catch-all JSON envelope handlers.
register_exception_handlers(app)

# All routes are versioned under /api/v1.
app.include_router(health_router, prefix="/api/v1")
app.include_router(auth_router, prefix="/api/v1")
app.include_router(admin_router, prefix="/api/v1")
app.include_router(tickers_router, prefix="/api/v1")
app.include_router(ohlcv_router, prefix="/api/v1")
app.include_router(ingestion_router, prefix="/api/v1")
app.include_router(indicators_router, prefix="/api/v1")
app.include_router(sr_levels_router, prefix="/api/v1")
app.include_router(sentiment_router, prefix="/api/v1")
app.include_router(fundamentals_router, prefix="/api/v1")
app.include_router(scores_router, prefix="/api/v1")
app.include_router(trades_router, prefix="/api/v1")
app.include_router(watchlist_router, prefix="/api/v1")
|
||||
61
app/middleware.py
Normal file
61
app/middleware.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Global exception handlers for the FastAPI application.
|
||||
|
||||
Maps AppError subclasses and other exceptions to JSON envelope responses.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.exceptions import AppError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
    """Register all global exception handlers on the FastAPI app.

    Every handler returns the same JSON envelope:
    ``{"status": "error", "data": None, "error": <text>}``.
    """

    @app.exception_handler(AppError)
    async def app_error_handler(_request: Request, exc: AppError) -> JSONResponse:
        # Status code and message come from the AppError subclass itself.
        return JSONResponse(
            status_code=exc.status_code,
            content={
                "status": "error",
                "data": None,
                "error": exc.message,
            },
        )

    @app.exception_handler(RequestValidationError)
    async def validation_error_handler(
        _request: Request, exc: RequestValidationError
    ) -> JSONResponse:
        # Flatten pydantic's error list into "loc.path: msg; ..." text.
        details = "; ".join(
            f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}"
            for e in exc.errors()
        )
        # Returns 400 (FastAPI's built-in default would be 422).
        return JSONResponse(
            status_code=400,
            content={
                "status": "error",
                "data": None,
                "error": f"Validation error: {details}",
            },
        )

    @app.exception_handler(Exception)
    async def unhandled_error_handler(
        _request: Request, exc: Exception
    ) -> JSONResponse:
        # Log the full traceback but never leak it to the client.
        logger.error("Unhandled exception:\n%s", traceback.format_exc())
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "data": None,
                "error": "Internal server error",
            },
        )
|
||||
25
app/models/__init__.py
Normal file
25
app/models/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from app.models.ticker import Ticker
|
||||
from app.models.ohlcv import OHLCVRecord
|
||||
from app.models.user import User
|
||||
from app.models.sentiment import SentimentScore
|
||||
from app.models.fundamental import FundamentalData
|
||||
from app.models.score import DimensionScore, CompositeScore
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.trade_setup import TradeSetup
|
||||
from app.models.watchlist import WatchlistEntry
|
||||
from app.models.settings import SystemSetting, IngestionProgress
|
||||
|
||||
# Public API of app.models: importing this package defines every ORM
# class, registering all tables on Base.metadata.
__all__ = [
    "Ticker",
    "OHLCVRecord",
    "User",
    "SentimentScore",
    "FundamentalData",
    "DimensionScore",
    "CompositeScore",
    "SRLevel",
    "TradeSetup",
    "WatchlistEntry",
    "SystemSetting",
    "IngestionProgress",
]
|
||||
24
app/models/fundamental.py
Normal file
24
app/models/fundamental.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, Float, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class FundamentalData(Base):
    """One fetched snapshot of fundamental metrics for a ticker."""

    __tablename__ = "fundamental_data"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Row is removed when the parent ticker is deleted (ON DELETE CASCADE).
    ticker_id: Mapped[int] = mapped_column(
        ForeignKey("tickers.id", ondelete="CASCADE"), nullable=False
    )
    # All metrics nullable: the provider may omit any of them.
    pe_ratio: Mapped[float | None] = mapped_column(Float, nullable=True)
    revenue_growth: Mapped[float | None] = mapped_column(Float, nullable=True)
    earnings_surprise: Mapped[float | None] = mapped_column(Float, nullable=True)
    market_cap: Mapped[float | None] = mapped_column(Float, nullable=True)
    # When this snapshot was retrieved.
    fetched_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )

    ticker = relationship("Ticker", back_populates="fundamental_data")
|
||||
30
app/models/ohlcv.py
Normal file
30
app/models/ohlcv.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from sqlalchemy import BigInteger, Date, DateTime, Float, ForeignKey, Index, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class OHLCVRecord(Base):
    """Daily OHLCV bar for one ticker; unique per (ticker, date)."""

    __tablename__ = "ohlcv_records"
    __table_args__ = (
        # One bar per ticker per trading day.
        UniqueConstraint("ticker_id", "date", name="uq_ohlcv_ticker_date"),
        # Composite index to speed ticker + date-range lookups.
        Index("ix_ohlcv_ticker_date", "ticker_id", "date"),
    )

    id: Mapped[int] = mapped_column(primary_key=True)
    ticker_id: Mapped[int] = mapped_column(
        ForeignKey("tickers.id", ondelete="CASCADE"), nullable=False
    )
    date: Mapped[date] = mapped_column(Date, nullable=False)
    open: Mapped[float] = mapped_column(Float, nullable=False)
    high: Mapped[float] = mapped_column(Float, nullable=False)
    low: Mapped[float] = mapped_column(Float, nullable=False)
    close: Mapped[float] = mapped_column(Float, nullable=False)
    # BigInteger: daily volumes can exceed 32-bit range.
    volume: Mapped[int] = mapped_column(BigInteger, nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        # NOTE(review): datetime.utcnow yields a *naive* timestamp for a
        # timezone-aware column — consider datetime.now(timezone.utc).
        DateTime(timezone=True), default=datetime.utcnow, nullable=False
    )

    ticker = relationship("Ticker", back_populates="ohlcv_records")
|
||||
40
app/models/score.py
Normal file
40
app/models/score.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Float, ForeignKey, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class DimensionScore(Base):
    """Score for a single named dimension of a ticker."""

    __tablename__ = "dimension_scores"

    id: Mapped[int] = mapped_column(primary_key=True)
    ticker_id: Mapped[int] = mapped_column(
        ForeignKey("tickers.id", ondelete="CASCADE"), nullable=False
    )
    # Which scoring dimension this row belongs to.
    dimension: Mapped[str] = mapped_column(String(50), nullable=False)
    score: Mapped[float] = mapped_column(Float, nullable=False)
    # Marks scores whose underlying data is out of date.
    is_stale: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    computed_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )

    ticker = relationship("Ticker", back_populates="dimension_scores")
|
||||
|
||||
|
||||
class CompositeScore(Base):
    """Aggregate score for a ticker, with the weights that produced it."""

    __tablename__ = "composite_scores"

    id: Mapped[int] = mapped_column(primary_key=True)
    ticker_id: Mapped[int] = mapped_column(
        ForeignKey("tickers.id", ondelete="CASCADE"), nullable=False
    )
    score: Mapped[float] = mapped_column(Float, nullable=False)
    # Marks scores whose underlying data is out of date.
    is_stale: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    # JSON-encoded weighting used at computation time, kept for audit.
    weights_json: Mapped[str] = mapped_column(Text, nullable=False)
    computed_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )

    ticker = relationship("Ticker", back_populates="composite_scores")
|
||||
23
app/models/sentiment.py
Normal file
23
app/models/sentiment.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class SentimentScore(Base):
    """One sentiment reading for a ticker from an external source."""

    __tablename__ = "sentiment_scores"

    id: Mapped[int] = mapped_column(primary_key=True)
    ticker_id: Mapped[int] = mapped_column(
        ForeignKey("tickers.id", ondelete="CASCADE"), nullable=False
    )
    # e.g. a label such as positive/negative — NOTE(review): confirm the
    # value set against the sentiment provider.
    classification: Mapped[str] = mapped_column(String(20), nullable=False)
    confidence: Mapped[int] = mapped_column(Integer, nullable=False)
    # Where the reading came from.
    source: Mapped[str] = mapped_column(String(100), nullable=False)
    timestamp: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )

    ticker = relationship("Ticker", back_populates="sentiment_scores")
|
||||
35
app/models/settings.py
Normal file
35
app/models/settings.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from sqlalchemy import Date, DateTime, ForeignKey, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class SystemSetting(Base):
    """Key/value store for runtime-mutable system configuration."""

    __tablename__ = "system_settings"

    id: Mapped[int] = mapped_column(primary_key=True)
    key: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
    value: Mapped[str] = mapped_column(Text, nullable=False)
    updated_at: Mapped[datetime] = mapped_column(
        # NOTE(review): datetime.utcnow is naive for a tz-aware column —
        # consider datetime.now(timezone.utc).
        DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
    )
|
||||
|
||||
|
||||
class IngestionProgress(Base):
    """High-water mark of OHLCV ingestion per ticker (one row per ticker)."""

    __tablename__ = "ingestion_progress"
    __table_args__ = (
        UniqueConstraint("ticker_id", name="uq_ingestion_progress_ticker"),
    )

    id: Mapped[int] = mapped_column(primary_key=True)
    ticker_id: Mapped[int] = mapped_column(
        ForeignKey("tickers.id", ondelete="CASCADE"), nullable=False
    )
    # Most recent trading day for which data has been ingested.
    last_ingested_date: Mapped[date] = mapped_column(Date, nullable=False)
    updated_at: Mapped[datetime] = mapped_column(
        # NOTE(review): datetime.utcnow is naive for a tz-aware column —
        # consider datetime.now(timezone.utc).
        DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
    )

    ticker = relationship("Ticker", back_populates="ingestion_progress")
|
||||
24
app/models/sr_level.py
Normal file
24
app/models/sr_level.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, Float, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class SRLevel(Base):
    """Detected support/resistance price level for a ticker."""

    __tablename__ = "sr_levels"

    id: Mapped[int] = mapped_column(primary_key=True)
    ticker_id: Mapped[int] = mapped_column(
        ForeignKey("tickers.id", ondelete="CASCADE"), nullable=False
    )
    price_level: Mapped[float] = mapped_column(Float, nullable=False)
    # Level kind — NOTE(review): presumably "support"/"resistance"; confirm.
    type: Mapped[str] = mapped_column(String(20), nullable=False)
    strength: Mapped[int] = mapped_column(Integer, nullable=False)
    # Which algorithm produced this level.
    detection_method: Mapped[str] = mapped_column(String(50), nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        # NOTE(review): datetime.utcnow is naive for a tz-aware column —
        # consider datetime.now(timezone.utc).
        DateTime(timezone=True), default=datetime.utcnow, nullable=False
    )

    ticker = relationship("Ticker", back_populates="sr_levels")
|
||||
27
app/models/ticker.py
Normal file
27
app/models/ticker.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, DateTime
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Ticker(Base):
    """A tracked stock symbol; parent of all per-ticker data."""

    __tablename__ = "tickers"

    id: Mapped[int] = mapped_column(primary_key=True)
    symbol: Mapped[str] = mapped_column(String(10), unique=True, nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        # NOTE(review): datetime.utcnow is naive for a tz-aware column —
        # consider datetime.now(timezone.utc).
        DateTime(timezone=True), default=datetime.utcnow, nullable=False
    )

    # Relationships (cascade deletes): removing a ticker removes all of
    # its dependent rows via delete-orphan cascade.
    ohlcv_records = relationship("OHLCVRecord", back_populates="ticker", cascade="all, delete-orphan")
    sentiment_scores = relationship("SentimentScore", back_populates="ticker", cascade="all, delete-orphan")
    fundamental_data = relationship("FundamentalData", back_populates="ticker", cascade="all, delete-orphan")
    sr_levels = relationship("SRLevel", back_populates="ticker", cascade="all, delete-orphan")
    dimension_scores = relationship("DimensionScore", back_populates="ticker", cascade="all, delete-orphan")
    composite_scores = relationship("CompositeScore", back_populates="ticker", cascade="all, delete-orphan")
    trade_setups = relationship("TradeSetup", back_populates="ticker", cascade="all, delete-orphan")
    watchlist_entries = relationship("WatchlistEntry", back_populates="ticker", cascade="all, delete-orphan")
    # uselist=False: at most one progress row per ticker (unique constraint).
    ingestion_progress = relationship("IngestionProgress", back_populates="ticker", cascade="all, delete-orphan", uselist=False)
|
||||
26
app/models/trade_setup.py
Normal file
26
app/models/trade_setup.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, Float, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class TradeSetup(Base):
    """Detected trade opportunity with entry/stop/target and R:R ratio."""

    __tablename__ = "trade_setups"

    id: Mapped[int] = mapped_column(primary_key=True)
    ticker_id: Mapped[int] = mapped_column(
        ForeignKey("tickers.id", ondelete="CASCADE"), nullable=False
    )
    # Trade direction — NOTE(review): presumably "long"/"short"; confirm.
    direction: Mapped[str] = mapped_column(String(10), nullable=False)
    entry_price: Mapped[float] = mapped_column(Float, nullable=False)
    stop_loss: Mapped[float] = mapped_column(Float, nullable=False)
    target: Mapped[float] = mapped_column(Float, nullable=False)
    # Reward-to-risk ratio of the setup.
    rr_ratio: Mapped[float] = mapped_column(Float, nullable=False)
    # Composite score of the ticker at detection time.
    composite_score: Mapped[float] = mapped_column(Float, nullable=False)
    detected_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )

    ticker = relationship("Ticker", back_populates="trade_setups")
|
||||
24
app/models/user.py
Normal file
24
app/models/user.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class User(Base):
    """Application account with role-based and per-user access control."""

    __tablename__ = "users"

    id: Mapped[int] = mapped_column(primary_key=True)
    username: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
    # Password hash only — plaintext is never stored.
    password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
    # "admin" grants the require_admin guard; everyone else is "user".
    role: Mapped[str] = mapped_column(String(20), nullable=False, default="user")
    # Extra gate checked by the require_access guard.
    has_access: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    created_at: Mapped[datetime] = mapped_column(
        # NOTE(review): datetime.utcnow is naive for a tz-aware column —
        # consider datetime.now(timezone.utc).
        DateTime(timezone=True), default=datetime.utcnow, nullable=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
    )

    watchlist_entries = relationship("WatchlistEntry", back_populates="user", cascade="all, delete-orphan")
|
||||
28
app/models/watchlist.py
Normal file
28
app/models/watchlist.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class WatchlistEntry(Base):
    """Membership of a ticker on a user's watchlist (one row per pair)."""

    __tablename__ = "watchlist_entries"
    __table_args__ = (
        # A ticker can appear at most once per user.
        UniqueConstraint("user_id", "ticker_id", name="uq_watchlist_user_ticker"),
    )

    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[int] = mapped_column(
        ForeignKey("users.id", ondelete="CASCADE"), nullable=False
    )
    ticker_id: Mapped[int] = mapped_column(
        ForeignKey("tickers.id", ondelete="CASCADE"), nullable=False
    )
    # How the entry was added — NOTE(review): presumably manual vs. auto
    # (see default_watchlist_auto_size in config); confirm value set.
    entry_type: Mapped[str] = mapped_column(String(10), nullable=False)
    added_at: Mapped[datetime] = mapped_column(
        # NOTE(review): datetime.utcnow is naive for a tz-aware column —
        # consider datetime.now(timezone.utc).
        DateTime(timezone=True), default=datetime.utcnow, nullable=False
    )

    user = relationship("User", back_populates="watchlist_entries")
    ticker = relationship("Ticker", back_populates="watchlist_entries")
|
||||
1
app/providers/__init__.py
Normal file
1
app/providers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
63
app/providers/alpaca.py
Normal file
63
app/providers/alpaca.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Alpaca Markets OHLCV provider using the alpaca-py SDK."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import date
|
||||
|
||||
from alpaca.data.historical import StockHistoricalDataClient
|
||||
from alpaca.data.requests import StockBarsRequest
|
||||
from alpaca.data.timeframe import TimeFrame
|
||||
|
||||
from app.exceptions import ProviderError, RateLimitError
|
||||
from app.providers.protocol import OHLCVData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AlpacaOHLCVProvider:
    """Fetches daily OHLCV bars from Alpaca Markets Data API."""

    def __init__(self, api_key: str, api_secret: str) -> None:
        # Fail fast on missing credentials instead of at the first request.
        if not api_key or not api_secret:
            raise ProviderError("Alpaca API key and secret are required")
        self._client = StockHistoricalDataClient(api_key, api_secret)

    async def fetch_ohlcv(
        self, ticker: str, start_date: date, end_date: date
    ) -> list[OHLCVData]:
        """Fetch daily OHLCV bars for *ticker* between *start_date* and *end_date*.

        Raises
        ------
        RateLimitError
            When the SDK error message looks like a rate limit.
        ProviderError
            For any other SDK failure.
        """
        try:
            request = StockBarsRequest(
                symbol_or_symbols=ticker,
                timeframe=TimeFrame.Day,
                start=start_date,
                end=end_date,
            )

            # alpaca-py's client is synchronous — run in a thread
            # so the event loop is not blocked.
            bars = await asyncio.to_thread(self._client.get_stock_bars, request)

            results: list[OHLCVData] = []
            # Handle both a mapping-style response and one exposing a
            # ``.data`` dict — NOTE(review): presumably covers differing
            # alpaca-py versions; confirm against the pinned SDK release.
            bar_set = bars.get(ticker, []) if hasattr(bars, "get") else getattr(bars, "data", {}).get(ticker, [])
            for bar in bar_set:
                results.append(
                    OHLCVData(
                        ticker=ticker,
                        date=bar.timestamp.date(),
                        open=float(bar.open),
                        high=float(bar.high),
                        low=float(bar.low),
                        close=float(bar.close),
                        volume=int(bar.volume),
                    )
                )
            return results

        except Exception as exc:
            # Message-text heuristic to classify rate limiting —
            # NOTE(review): check whether the SDK raises a dedicated
            # rate-limit exception type that could be caught instead.
            msg = str(exc).lower()
            if "rate" in msg and "limit" in msg:
                raise RateLimitError(f"Alpaca rate limit hit for {ticker}") from exc
            logger.error("Alpaca provider error for %s: %s", ticker, exc)
            raise ProviderError(f"Alpaca provider error for {ticker}: {exc}") from exc
|
||||
94
app/providers/fmp.py
Normal file
94
app/providers/fmp.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Financial Modeling Prep (FMP) fundamentals provider using httpx."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import httpx
|
||||
|
||||
from app.exceptions import ProviderError, RateLimitError
|
||||
from app.providers.protocol import FundamentalData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_FMP_BASE_URL = "https://financialmodelingprep.com/api/v3"
|
||||
|
||||
|
||||
class FMPFundamentalProvider:
    """Fetches fundamental data from Financial Modeling Prep REST API.

    Uses two endpoints per call: ``/profile/{ticker}`` for P/E, revenue
    growth and market cap, and ``/earnings-surprises/{ticker}`` for the
    latest reported earnings figure.
    """

    def __init__(self, api_key: str) -> None:
        # Refuse to build a provider that could never authenticate.
        if not api_key:
            raise ProviderError("FMP API key is required")
        self._api_key = api_key

    async def fetch_fundamentals(self, ticker: str) -> FundamentalData:
        """Fetch P/E, revenue growth, earnings surprise, and market cap.

        Missing or unparsable values come back as ``None`` fields rather
        than errors; HTTP/transport failures raise :class:`ProviderError`
        (or :class:`RateLimitError` on HTTP 429).
        """
        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                profile = await self._fetch_profile(client, ticker)
                earnings = await self._fetch_earnings_surprise(client, ticker)

            return FundamentalData(
                ticker=ticker,
                pe_ratio=self._safe_float(profile.get("pe")),
                revenue_growth=self._safe_float(profile.get("revenueGrowth")),
                earnings_surprise=self._safe_float(earnings),
                market_cap=self._safe_float(profile.get("mktCap")),
                fetched_at=datetime.now(timezone.utc),
            )

        except (ProviderError, RateLimitError):
            # Already classified by _check_response — don't re-wrap.
            raise
        except Exception as exc:
            logger.error("FMP provider error for %s: %s", ticker, exc)
            raise ProviderError(f"FMP provider error for {ticker}: {exc}") from exc

    async def _fetch_profile(self, client: httpx.AsyncClient, ticker: str) -> dict:
        """Fetch company profile (P/E, revenue growth, market cap)."""
        resp = await client.get(
            f"{_FMP_BASE_URL}/profile/{ticker}", params={"apikey": self._api_key}
        )
        self._check_response(resp, ticker, "profile")
        payload = resp.json()
        # FMP wraps single-company profiles in a one-element list.
        if isinstance(payload, list) and payload:
            return payload[0]
        return payload if isinstance(payload, dict) else {}

    async def _fetch_earnings_surprise(
        self, client: httpx.AsyncClient, ticker: str
    ) -> float | None:
        """Return the most recent ``actualEarningResult`` figure, if any."""
        resp = await client.get(
            f"{_FMP_BASE_URL}/earnings-surprises/{ticker}",
            params={"apikey": self._api_key},
        )
        self._check_response(resp, ticker, "earnings-surprises")
        payload = resp.json()
        if isinstance(payload, list) and payload:
            return self._safe_float(payload[0].get("actualEarningResult"))
        return None

    def _check_response(
        self, resp: httpx.Response, ticker: str, endpoint: str
    ) -> None:
        """Raise appropriate errors for non-200 responses."""
        code = resp.status_code
        if code == 429:
            raise RateLimitError(f"FMP rate limit hit for {ticker} ({endpoint})")
        if code != 200:
            raise ProviderError(
                f"FMP {endpoint} error for {ticker}: HTTP {resp.status_code}"
            )

    @staticmethod
    def _safe_float(value: object) -> float | None:
        """Convert a value to float, returning None on failure."""
        try:
            return None if value is None else float(value)
        except (TypeError, ValueError):
            return None
|
||||
90
app/providers/gemini_sentiment.py
Normal file
90
app/providers/gemini_sentiment.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Gemini sentiment provider using google-genai with search grounding."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from app.exceptions import ProviderError, RateLimitError
|
||||
from app.providers.protocol import SentimentData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Prompt template sent to Gemini; ``str.format`` fills ``{ticker}``, and the
# doubled braces escape the literal JSON braces in the required answer shape.
# This text is part of runtime behavior — edit with care.
_SENTIMENT_PROMPT = """\
Analyze the current market sentiment for the stock ticker {ticker}.
Search the web for recent news articles, social media mentions, and analyst opinions.

Respond ONLY with a JSON object in this exact format (no markdown, no extra text):
{{"classification": "<bullish|bearish|neutral>", "confidence": <0-100>, "reasoning": "<brief explanation>"}}

Rules:
- classification must be exactly one of: bullish, bearish, neutral
- confidence must be an integer from 0 to 100
- reasoning should be a brief one-sentence explanation
"""

# Closed set of classifications accepted from the model; anything else is
# rejected as a provider error by GeminiSentimentProvider.
VALID_CLASSIFICATIONS = {"bullish", "bearish", "neutral"}
|
||||
|
||||
|
||||
class GeminiSentimentProvider:
    """Fetches sentiment analysis from Gemini with search grounding.

    Sends a fixed JSON-only prompt, validates the model's classification
    against ``VALID_CLASSIFICATIONS``, and clamps confidence to [0, 100].
    """

    def __init__(self, api_key: str, model: str = "gemini-2.0-flash") -> None:
        if not api_key:
            raise ProviderError("Gemini API key is required")
        self._client = genai.Client(api_key=api_key)
        self._model = model

    async def fetch_sentiment(self, ticker: str) -> SentimentData:
        """Send a structured prompt to Gemini and parse the JSON response.

        Raises :class:`ProviderError` on malformed or invalid output and
        :class:`RateLimitError` when the API error looks like a quota hit.
        """
        try:
            raw_text = await self._generate(ticker)
            logger.debug("Gemini raw response for %s: %s", ticker, raw_text)
            return self._build_result(ticker, json.loads(raw_text))

        except json.JSONDecodeError as exc:
            logger.error("Failed to parse Gemini JSON for %s: %s", ticker, exc)
            raise ProviderError(f"Invalid JSON from Gemini for {ticker}") from exc
        except ProviderError:
            # Validation errors are already meaningful — pass through.
            raise
        except Exception as exc:
            lowered = str(exc).lower()
            # The SDK has no dedicated rate-limit type; sniff the message.
            if "rate" in lowered or "quota" in lowered or "429" in lowered:
                raise RateLimitError(f"Gemini rate limit hit for {ticker}") from exc
            logger.error("Gemini provider error for %s: %s", ticker, exc)
            raise ProviderError(f"Gemini provider error for {ticker}: {exc}") from exc

    async def _generate(self, ticker: str) -> str:
        """Call the Gemini API with search grounding; return stripped raw text."""
        response = await self._client.aio.models.generate_content(
            model=self._model,
            contents=_SENTIMENT_PROMPT.format(ticker=ticker),
            config=types.GenerateContentConfig(
                tools=[types.Tool(google_search=types.GoogleSearch())],
                response_mime_type="application/json",
            ),
        )
        return response.text.strip()

    def _build_result(self, ticker: str, parsed: dict) -> SentimentData:
        """Validate the decoded payload and assemble a SentimentData DTO."""
        classification = parsed.get("classification", "").lower()
        if classification not in VALID_CLASSIFICATIONS:
            raise ProviderError(
                f"Invalid classification '{classification}' from Gemini for {ticker}"
            )

        # Clamp to the documented 0-100 range; default to 50 when absent.
        confidence = min(max(int(parsed.get("confidence", 50)), 0), 100)

        reasoning = parsed.get("reasoning", "")
        if reasoning:
            logger.info("Gemini sentiment for %s: %s (confidence=%d) — %s",
                        ticker, classification, confidence, reasoning)

        return SentimentData(
            ticker=ticker,
            classification=classification,
            confidence=confidence,
            source="gemini",
            timestamp=datetime.now(timezone.utc),
        )
|
||||
84
app/providers/protocol.py
Normal file
84
app/providers/protocol.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Provider protocols and lightweight data transfer objects.
|
||||
|
||||
Protocols define the interface for external data providers.
|
||||
DTOs are simple dataclasses — NOT SQLAlchemy models — used to
|
||||
transfer data between providers and the service layer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data Transfer Objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class OHLCVData:
    """Lightweight OHLCV record returned by market data providers.

    Immutable value object (``frozen`` + ``slots``) — safe to share across
    tasks and cheap to allocate in bulk.
    """

    ticker: str  # symbol the bar belongs to
    date: date  # trading day covered by this bar
    open: float  # opening price
    high: float  # intraday high
    low: float  # intraday low
    close: float  # closing price
    volume: int  # shares traded
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class SentimentData:
    """Sentiment analysis result returned by sentiment providers."""

    ticker: str
    classification: str  # "bullish" | "bearish" | "neutral"
    confidence: int  # 0-100
    source: str  # provider identifier, e.g. "gemini"
    timestamp: datetime  # when the analysis was produced (providers use UTC)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class FundamentalData:
    """Fundamental metrics returned by fundamental providers.

    Each metric may be ``None`` when the upstream source had no value or
    the value failed to parse.
    """

    ticker: str
    pe_ratio: float | None  # price/earnings ratio
    revenue_growth: float | None  # revenue growth figure as reported upstream
    earnings_surprise: float | None  # latest earnings figure/surprise
    market_cap: float | None  # market capitalization as reported upstream
    fetched_at: datetime  # fetch time (providers use timezone-aware UTC)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider Protocols
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MarketDataProvider(Protocol):
    """Protocol (structural interface) for OHLCV market data providers.

    Implementations in this package raise ``ProviderError`` /
    ``RateLimitError`` on failure.
    """

    async def fetch_ohlcv(
        self, ticker: str, start_date: date, end_date: date
    ) -> list[OHLCVData]:
        """Fetch OHLCV data for a ticker in a date range (inclusive bounds
        are implementation-defined)."""
        ...
|
||||
|
||||
|
||||
class SentimentProvider(Protocol):
    """Protocol (structural interface) for sentiment analysis providers.

    Implementations in this package raise ``ProviderError`` /
    ``RateLimitError`` on failure.
    """

    async def fetch_sentiment(self, ticker: str) -> SentimentData:
        """Fetch current sentiment analysis for a ticker."""
        ...
|
||||
|
||||
|
||||
class FundamentalProvider(Protocol):
    """Protocol (structural interface) for fundamental data providers.

    Implementations in this package raise ``ProviderError`` /
    ``RateLimitError`` on failure.
    """

    async def fetch_fundamentals(self, ticker: str) -> FundamentalData:
        """Fetch fundamental data for a ticker."""
        ...
|
||||
1
app/routers/__init__.py
Normal file
1
app/routers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
193
app/routers/admin.py
Normal file
193
app/routers/admin.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Admin router: user management, system settings, data cleanup, job control.
|
||||
|
||||
All endpoints require admin role.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, require_admin
|
||||
from app.models.user import User
|
||||
from app.schemas.admin import (
|
||||
CreateUserRequest,
|
||||
DataCleanupRequest,
|
||||
JobToggle,
|
||||
PasswordReset,
|
||||
RegistrationToggle,
|
||||
SystemSettingUpdate,
|
||||
UserManagement,
|
||||
)
|
||||
from app.schemas.common import APIEnvelope
|
||||
from app.services import admin_service
|
||||
|
||||
router = APIRouter(tags=["admin"])
|
||||
|
||||
|
||||
def _user_dict(user: User) -> dict:
|
||||
return {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"role": user.role,
|
||||
"has_access": user.has_access,
|
||||
"created_at": user.created_at.isoformat() if user.created_at else None,
|
||||
"updated_at": user.updated_at.isoformat() if user.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/admin/users", response_model=APIEnvelope)
async def list_users(
    _admin: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Return every registered user account (admin only)."""
    accounts = await admin_service.list_users(db)
    payload = [_user_dict(account) for account in accounts]
    return APIEnvelope(status="success", data=payload)
|
||||
|
||||
|
||||
@router.post("/admin/users", response_model=APIEnvelope, status_code=201)
async def create_user(
    body: CreateUserRequest,
    _admin: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Create a new user account (admin only)."""
    created = await admin_service.create_user(
        db, body.username, body.password, body.role, body.has_access
    )
    return APIEnvelope(status="success", data=_user_dict(created))
|
||||
|
||||
|
||||
@router.put("/admin/users/{user_id}/access", response_model=APIEnvelope)
async def set_user_access(
    user_id: int,
    body: UserManagement,
    _admin: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Grant or revoke API access for a user (admin only)."""
    updated = await admin_service.set_user_access(db, user_id, body.has_access)
    return APIEnvelope(status="success", data=_user_dict(updated))
|
||||
|
||||
|
||||
@router.put("/admin/users/{user_id}/password", response_model=APIEnvelope)
async def reset_password(
    user_id: int,
    body: PasswordReset,
    _admin: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Reset a user's password (admin only)."""
    updated = await admin_service.reset_password(db, user_id, body.new_password)
    return APIEnvelope(status="success", data=_user_dict(updated))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registration toggle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.put("/admin/settings/registration", response_model=APIEnvelope)
async def toggle_registration(
    body: RegistrationToggle,
    _admin: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Enable or disable public user registration (admin only)."""
    setting = await admin_service.toggle_registration(db, body.enabled)
    payload = {"key": setting.key, "value": setting.value}
    return APIEnvelope(status="success", data=payload)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System settings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/admin/settings", response_model=APIEnvelope)
async def list_settings(
    _admin: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """List all system settings (admin only)."""
    rows = await admin_service.list_settings(db)

    def _setting_dict(s):
        # updated_at may be NULL for never-touched settings.
        return {
            "key": s.key,
            "value": s.value,
            "updated_at": s.updated_at.isoformat() if s.updated_at else None,
        }

    return APIEnvelope(status="success", data=[_setting_dict(s) for s in rows])
|
||||
|
||||
|
||||
@router.put("/admin/settings/{key}", response_model=APIEnvelope)
async def update_setting(
    key: str,
    body: SystemSettingUpdate,
    _admin: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Create or update a system setting (admin only)."""
    setting = await admin_service.update_setting(db, key, body.value)
    touched_at = setting.updated_at.isoformat() if setting.updated_at else None
    return APIEnvelope(
        status="success",
        data={"key": setting.key, "value": setting.value, "updated_at": touched_at},
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/admin/data/cleanup", response_model=APIEnvelope)
async def cleanup_data(
    body: DataCleanupRequest,
    _admin: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Delete OHLCV, sentiment, and fundamental data older than N days.

    Returns the per-table deletion counts reported by the service layer.
    """
    deleted_counts = await admin_service.cleanup_data(db, body.older_than_days)
    return APIEnvelope(status="success", data=deleted_counts)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Job control
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/admin/jobs", response_model=APIEnvelope)
async def list_jobs(
    _admin: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """List all scheduled jobs with their current status (admin only)."""
    job_list = await admin_service.list_jobs(db)
    return APIEnvelope(status="success", data=job_list)
|
||||
|
||||
|
||||
@router.post("/admin/jobs/{job_name}/trigger", response_model=APIEnvelope)
async def trigger_job(
    job_name: str,
    _admin: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Trigger a manual job run (placeholder — delegates to the service)."""
    outcome = await admin_service.trigger_job(db, job_name)
    return APIEnvelope(status="success", data=outcome)
|
||||
|
||||
|
||||
@router.put("/admin/jobs/{job_name}/toggle", response_model=APIEnvelope)
async def toggle_job(
    job_name: str,
    body: JobToggle,
    _admin: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Enable or disable a scheduled job (placeholder — delegates to the service)."""
    setting = await admin_service.toggle_job(db, job_name, body.enabled)
    payload = {"key": setting.key, "value": setting.value}
    return APIEnvelope(status="success", data=payload)
|
||||
34
app/routers/auth.py
Normal file
34
app/routers/auth.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Auth router: registration and login endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db
|
||||
from app.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
|
||||
from app.schemas.common import APIEnvelope
|
||||
from app.services import auth_service
|
||||
|
||||
router = APIRouter(tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/auth/register", response_model=APIEnvelope)
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
    """Public endpoint: register a new user account."""
    user = await auth_service.register(db, body.username, body.password)
    profile = {
        "id": user.id,
        "username": user.username,
        "role": user.role,
        "has_access": user.has_access,
    }
    return APIEnvelope(status="success", data=profile)
|
||||
|
||||
|
||||
@router.post("/auth/login", response_model=APIEnvelope)
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
    """Public endpoint: authenticate and receive a JWT access token."""
    access_token = await auth_service.login(db, body.username, body.password)
    return APIEnvelope(
        status="success",
        data=TokenResponse(access_token=access_token).model_dump(),
    )
|
||||
35
app/routers/fundamentals.py
Normal file
35
app/routers/fundamentals.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Fundamentals router — fundamental data endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, require_access
|
||||
from app.schemas.common import APIEnvelope
|
||||
from app.schemas.fundamental import FundamentalResponse
|
||||
from app.services.fundamental_service import get_fundamental
|
||||
|
||||
router = APIRouter(tags=["fundamentals"])
|
||||
|
||||
|
||||
@router.get("/fundamentals/{symbol}", response_model=APIEnvelope)
async def read_fundamentals(
    symbol: str,
    _user=Depends(require_access),
    db: AsyncSession = Depends(get_db),
) -> APIEnvelope:
    """Get latest fundamental data for a symbol.

    When no record exists yet, an all-None response is returned rather
    than a 404, so clients can treat "no data" uniformly.
    """
    normalized = symbol.strip().upper()
    record = await get_fundamental(db, symbol)

    if record is None:
        data = FundamentalResponse(symbol=normalized)
    else:
        data = FundamentalResponse(
            symbol=normalized,
            pe_ratio=record.pe_ratio,
            revenue_growth=record.revenue_growth,
            earnings_surprise=record.earnings_surprise,
            market_cap=record.market_cap,
            fetched_at=record.fetched_at,
        )

    return APIEnvelope(status="success", data=data.model_dump())
|
||||
36
app/routers/health.py
Normal file
36
app/routers/health.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Health check endpoint — unauthenticated."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db
|
||||
from app.schemas.common import APIEnvelope
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["health"])
|
||||
|
||||
|
||||
@router.get("/health")
async def health_check(db: AsyncSession = Depends(get_db)) -> APIEnvelope:
    """Return service health including database connectivity.

    Executes a trivial ``SELECT 1``; any failure is reported as a 503 with
    a plain error envelope instead of raising.
    """
    # NOTE(review): the return annotation says APIEnvelope but the failure
    # path returns a JSONResponse. With no response_model on the decorator,
    # FastAPI may derive one from this annotation — confirm the 503 body is
    # not validated against APIEnvelope before changing anything here.
    try:
        await db.execute(text("SELECT 1"))
        return APIEnvelope(
            status="success",
            data={"status": "healthy", "database": "connected"},
        )
    except Exception:
        # Deliberately broad: any DB/driver error means "unhealthy", and a
        # health endpoint must never itself raise.
        logger.exception("Health check: database unreachable")
        return JSONResponse(
            status_code=503,
            content={
                "status": "error",
                "data": None,
                "error": "Database unreachable",
            },
        )
|
||||
64
app/routers/indicators.py
Normal file
64
app/routers/indicators.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Indicators router — technical analysis endpoints."""
|
||||
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, require_access
|
||||
from app.schemas.common import APIEnvelope
|
||||
from app.schemas.indicator import (
|
||||
EMACrossResponse,
|
||||
EMACrossResult,
|
||||
IndicatorResponse,
|
||||
IndicatorResult,
|
||||
)
|
||||
from app.services.indicator_service import get_ema_cross, get_indicator
|
||||
|
||||
router = APIRouter(tags=["indicators"])
|
||||
|
||||
|
||||
# NOTE: ema-cross must be registered BEFORE {indicator_type} to avoid
|
||||
# FastAPI matching "ema-cross" as an indicator_type path parameter.
|
||||
|
||||
|
||||
@router.get("/indicators/{symbol}/ema-cross", response_model=APIEnvelope)
async def read_ema_cross(
    symbol: str,
    start_date: date | None = Query(None),
    end_date: date | None = Query(None),
    short_period: int = Query(20),
    long_period: int = Query(50),
    _user=Depends(require_access),
    db: AsyncSession = Depends(get_db),
) -> APIEnvelope:
    """Compute the EMA cross signal for a symbol over an optional date range."""
    cross = await get_ema_cross(
        db, symbol, start_date, end_date, short_period, long_period
    )
    response = EMACrossResponse(
        symbol=symbol.upper(),
        ema_cross=EMACrossResult(**cross),
    )
    return APIEnvelope(status="success", data=response.model_dump())
|
||||
|
||||
|
||||
@router.get("/indicators/{symbol}/{indicator_type}", response_model=APIEnvelope)
async def read_indicator(
    symbol: str,
    indicator_type: str,
    start_date: date | None = Query(None),
    end_date: date | None = Query(None),
    period: int | None = Query(None),
    _user=Depends(require_access),
    db: AsyncSession = Depends(get_db),
) -> APIEnvelope:
    """Compute a technical indicator for a symbol over an optional date range."""
    computed = await get_indicator(
        db, symbol, indicator_type, start_date, end_date, period
    )
    response = IndicatorResponse(
        symbol=symbol.upper(),
        indicator=IndicatorResult(**computed),
    )
    return APIEnvelope(status="success", data=response.model_dump())
|
||||
127
app/routers/ingestion.py
Normal file
127
app/routers/ingestion.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Ingestion router: trigger data fetches from the market data provider.
|
||||
|
||||
Provides both a single-source OHLCV endpoint and a comprehensive
|
||||
fetch-all endpoint that collects OHLCV + sentiment + fundamentals
|
||||
in one call with per-source status reporting.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.dependencies import get_db, require_access
|
||||
from app.exceptions import ProviderError
|
||||
from app.models.user import User
|
||||
from app.providers.alpaca import AlpacaOHLCVProvider
|
||||
from app.providers.fmp import FMPFundamentalProvider
|
||||
from app.providers.gemini_sentiment import GeminiSentimentProvider
|
||||
from app.schemas.common import APIEnvelope
|
||||
from app.services import fundamental_service, ingestion_service, sentiment_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["ingestion"])
|
||||
|
||||
|
||||
def _get_provider() -> AlpacaOHLCVProvider:
    """Build the OHLCV provider from current settings, failing fast when
    credentials are absent."""
    key = settings.alpaca_api_key
    secret = settings.alpaca_api_secret
    if not (key and secret):
        raise ProviderError("Alpaca API credentials not configured")
    return AlpacaOHLCVProvider(key, secret)
|
||||
|
||||
|
||||
@router.post("/ingestion/fetch/{symbol}", response_model=APIEnvelope)
async def fetch_symbol(
    symbol: str,
    start_date: date | None = Query(None, description="Start date (YYYY-MM-DD)"),
    end_date: date | None = Query(None, description="End date (YYYY-MM-DD)"),
    _user: User = Depends(require_access),
    db: AsyncSession = Depends(get_db),
):
    """Fetch all data sources for a ticker: OHLCV, sentiment, and fundamentals.

    Each source is fetched independently — a failure in one never aborts the
    others. Returns a per-source breakdown so the frontend can show exactly
    what succeeded and what failed.
    """
    symbol_upper = symbol.strip().upper()

    # Sources are fetched sequentially, in the same order as before.
    sources: dict[str, dict] = {}
    sources["ohlcv"] = await _ingest_ohlcv(db, symbol_upper, start_date, end_date)
    sources["sentiment"] = await _ingest_sentiment(db, symbol_upper)
    sources["fundamentals"] = await _ingest_fundamentals(db, symbol_upper)

    # Always return success — per-source breakdown tells the full story
    return APIEnvelope(
        status="success",
        data={"symbol": symbol_upper, "sources": sources},
        error=None,
    )


async def _ingest_ohlcv(
    db: AsyncSession,
    symbol: str,
    start_date: date | None,
    end_date: date | None,
) -> dict:
    """Fetch and store OHLCV bars; errors are reported in the dict, never raised."""
    try:
        provider = _get_provider()
        result = await ingestion_service.fetch_and_ingest(
            db, provider, symbol, start_date, end_date
        )
        return {
            "status": "ok" if result.status in ("complete", "partial") else "error",
            "records": result.records_ingested,
            "message": result.message,
        }
    except Exception as exc:
        logger.error("OHLCV fetch failed for %s: %s", symbol, exc)
        return {"status": "error", "records": 0, "message": str(exc)}


async def _ingest_sentiment(db: AsyncSession, symbol: str) -> dict:
    """Fetch and store Gemini sentiment; skipped when no API key is configured."""
    if not settings.gemini_api_key:
        return {
            "status": "skipped",
            "message": "Gemini API key not configured",
        }
    try:
        sent_provider = GeminiSentimentProvider(
            settings.gemini_api_key, settings.gemini_model
        )
        data = await sent_provider.fetch_sentiment(symbol)
        await sentiment_service.store_sentiment(
            db,
            symbol=symbol,
            classification=data.classification,
            confidence=data.confidence,
            source=data.source,
            timestamp=data.timestamp,
        )
        return {
            "status": "ok",
            "classification": data.classification,
            "confidence": data.confidence,
            "message": None,
        }
    except Exception as exc:
        logger.error("Sentiment fetch failed for %s: %s", symbol, exc)
        return {"status": "error", "message": str(exc)}


async def _ingest_fundamentals(db: AsyncSession, symbol: str) -> dict:
    """Fetch and store FMP fundamentals; skipped when no API key is configured."""
    if not settings.fmp_api_key:
        return {
            "status": "skipped",
            "message": "FMP API key not configured",
        }
    try:
        fmp_provider = FMPFundamentalProvider(settings.fmp_api_key)
        fdata = await fmp_provider.fetch_fundamentals(symbol)
        await fundamental_service.store_fundamental(
            db,
            symbol=symbol,
            pe_ratio=fdata.pe_ratio,
            revenue_growth=fdata.revenue_growth,
            earnings_surprise=fdata.earnings_surprise,
            market_cap=fdata.market_cap,
        )
        return {"status": "ok", "message": None}
    except Exception as exc:
        logger.error("Fundamentals fetch failed for %s: %s", symbol, exc)
        return {"status": "error", "message": str(exc)}
|
||||
56
app/routers/ohlcv.py
Normal file
56
app/routers/ohlcv.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""OHLCV router: endpoints for storing and querying price data."""
|
||||
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, require_access
|
||||
from app.models.user import User
|
||||
from app.schemas.common import APIEnvelope
|
||||
from app.schemas.ohlcv import OHLCVCreate, OHLCVResponse
|
||||
from app.services import price_service
|
||||
|
||||
router = APIRouter(tags=["ohlcv"])
|
||||
|
||||
|
||||
@router.post("/ohlcv", response_model=APIEnvelope)
async def create_ohlcv(
    body: OHLCVCreate,
    _user: User = Depends(require_access),
    db: AsyncSession = Depends(get_db),
):
    """Upsert an OHLCV record for a ticker and date."""
    stored = await price_service.upsert_ohlcv(
        db,
        symbol=body.symbol,
        record_date=body.date,
        open_=body.open,
        high=body.high,
        low=body.low,
        close=body.close,
        volume=body.volume,
    )
    payload = OHLCVResponse.model_validate(stored).model_dump(mode="json")
    return APIEnvelope(status="success", data=payload)
|
||||
|
||||
|
||||
@router.get("/ohlcv/{symbol}", response_model=APIEnvelope)
async def get_ohlcv(
    symbol: str,
    start_date: date | None = Query(None, description="Start date (YYYY-MM-DD)"),
    end_date: date | None = Query(None, description="End date (YYYY-MM-DD)"),
    _user: User = Depends(require_access),
    db: AsyncSession = Depends(get_db),
):
    """Query OHLCV records for a ticker, optionally filtered by date range."""
    rows = await price_service.query_ohlcv(db, symbol, start_date, end_date)
    serialized = [
        OHLCVResponse.model_validate(row).model_dump(mode="json") for row in rows
    ]
    return APIEnvelope(status="success", data=serialized)
|
||||
75
app/routers/scores.py
Normal file
75
app/routers/scores.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Scores router — scoring engine endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, require_access
|
||||
from app.schemas.common import APIEnvelope
|
||||
from app.schemas.score import (
|
||||
DimensionScoreResponse,
|
||||
RankingEntry,
|
||||
RankingResponse,
|
||||
ScoreResponse,
|
||||
WeightUpdateRequest,
|
||||
)
|
||||
from app.services.scoring_service import get_rankings, get_score, update_weights
|
||||
|
||||
router = APIRouter(tags=["scores"])
|
||||
|
||||
|
||||
@router.get("/scores/{symbol}", response_model=APIEnvelope)
async def read_score(
    symbol: str,
    _user=Depends(require_access),
    db: AsyncSession = Depends(get_db),
) -> APIEnvelope:
    """Get composite + dimension scores for a symbol. Recomputes stale scores."""
    result = await get_score(db, symbol)

    # Build the per-dimension list first, then assemble the response model.
    dims = [DimensionScoreResponse(**d) for d in result["dimensions"]]
    payload = ScoreResponse(
        symbol=result["symbol"],
        composite_score=result["composite_score"],
        composite_stale=result["composite_stale"],
        weights=result["weights"],
        dimensions=dims,
        missing_dimensions=result["missing_dimensions"],
        computed_at=result["computed_at"],
    )
    return APIEnvelope(status="success", data=payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.get("/rankings", response_model=APIEnvelope)
async def read_rankings(
    _user=Depends(require_access),
    db: AsyncSession = Depends(get_db),
) -> APIEnvelope:
    """Get all tickers ranked by composite score descending."""
    result = await get_rankings(db)

    entries = []
    for row in result["rankings"]:
        entry = RankingEntry(
            symbol=row["symbol"],
            composite_score=row["composite_score"],
            dimensions=[DimensionScoreResponse(**d) for d in row["dimensions"]],
        )
        entries.append(entry)

    payload = RankingResponse(rankings=entries, weights=result["weights"])
    return APIEnvelope(status="success", data=payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.put("/scores/weights", response_model=APIEnvelope)
async def update_score_weights(
    body: WeightUpdateRequest,
    _user=Depends(require_access),
    db: AsyncSession = Depends(get_db),
) -> APIEnvelope:
    """Update dimension weights and recompute all composite scores."""
    updated = await update_weights(db, body.weights)
    return APIEnvelope(status="success", data={"weights": updated})
|
||||
46
app/routers/sentiment.py
Normal file
46
app/routers/sentiment.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Sentiment router — sentiment data endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, require_access
|
||||
from app.schemas.common import APIEnvelope
|
||||
from app.schemas.sentiment import SentimentResponse, SentimentScoreResult
|
||||
from app.services.sentiment_service import (
|
||||
compute_sentiment_dimension_score,
|
||||
get_sentiment_scores,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["sentiment"])
|
||||
|
||||
|
||||
@router.get("/sentiment/{symbol}", response_model=APIEnvelope)
async def read_sentiment(
    symbol: str,
    lookback_hours: float = Query(24, gt=0, description="Lookback window in hours"),
    _user=Depends(require_access),
    db: AsyncSession = Depends(get_db),
) -> APIEnvelope:
    """Get recent sentiment scores and computed dimension score for a symbol.

    Returns the individual sentiment score records within the lookback
    window plus the aggregated dimension score (rounded to 2 decimals,
    or None when no score could be computed).
    """
    scores = await get_sentiment_scores(db, symbol, lookback_hours)
    dimension_score = await compute_sentiment_dimension_score(
        db, symbol, lookback_hours
    )

    data = SentimentResponse(
        symbol=symbol.strip().upper(),
        scores=[
            SentimentScoreResult(
                id=s.id,
                classification=s.classification,
                confidence=s.confidence,
                source=s.source,
                timestamp=s.timestamp,
            )
            for s in scores
        ],
        count=len(scores),
        dimension_score=round(dimension_score, 2) if dimension_score is not None else None,
        lookback_hours=lookback_hours,
    )
    # mode="json" serializes the timestamp fields to ISO strings, keeping
    # the envelope payload JSON-native and consistent with the other
    # routers in this package (scores, ohlcv, tickers all use mode="json").
    return APIEnvelope(status="success", data=data.model_dump(mode="json"))
|
||||
38
app/routers/sr_levels.py
Normal file
38
app/routers/sr_levels.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""S/R Levels router — support/resistance detection endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, require_access
|
||||
from app.schemas.common import APIEnvelope
|
||||
from app.schemas.sr_level import SRLevelResponse, SRLevelResult
|
||||
from app.services.sr_service import get_sr_levels
|
||||
|
||||
router = APIRouter(tags=["sr-levels"])
|
||||
|
||||
|
||||
@router.get("/sr-levels/{symbol}", response_model=APIEnvelope)
async def read_sr_levels(
    symbol: str,
    tolerance: float = Query(0.005, ge=0, le=0.1, description="Merge tolerance (default 0.5%)"),
    _user=Depends(require_access),
    db: AsyncSession = Depends(get_db),
) -> APIEnvelope:
    """Get support/resistance levels for a symbol, sorted by strength descending.

    The *tolerance* parameter controls how close two detected levels must be
    (as a fraction of price) before the service merges them into one.
    """
    levels = await get_sr_levels(db, symbol, tolerance)
    data = SRLevelResponse(
        symbol=symbol.upper(),
        levels=[
            SRLevelResult(
                id=lvl.id,
                price_level=lvl.price_level,
                type=lvl.type,
                strength=lvl.strength,
                detection_method=lvl.detection_method,
                created_at=lvl.created_at,
            )
            for lvl in levels
        ],
        count=len(levels),
    )
    # mode="json" serializes created_at datetimes to ISO strings — consistent
    # with the other routers in this package which all dump with mode="json".
    return APIEnvelope(status="success", data=data.model_dump(mode="json"))
|
||||
53
app/routers/tickers.py
Normal file
53
app/routers/tickers.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Tickers router: CRUD endpoints for the Ticker Registry."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, require_access
|
||||
from app.models.user import User
|
||||
from app.schemas.common import APIEnvelope
|
||||
from app.schemas.ticker import TickerCreate, TickerResponse
|
||||
from app.services import ticker_service
|
||||
|
||||
router = APIRouter(tags=["tickers"])
|
||||
|
||||
|
||||
@router.post("/tickers", response_model=APIEnvelope)
async def create_ticker(
    body: TickerCreate,
    _user: User = Depends(require_access),
    db: AsyncSession = Depends(get_db),
):
    """Add a new ticker to the registry."""
    created = await ticker_service.add_ticker(db, body.symbol)
    payload = TickerResponse.model_validate(created).model_dump(mode="json")
    return APIEnvelope(status="success", data=payload)
|
||||
|
||||
|
||||
@router.get("/tickers", response_model=APIEnvelope)
async def list_tickers(
    _user: User = Depends(require_access),
    db: AsyncSession = Depends(get_db),
):
    """List all tracked tickers sorted alphabetically."""
    rows = await ticker_service.list_tickers(db)
    serialized = []
    for row in rows:
        serialized.append(TickerResponse.model_validate(row).model_dump(mode="json"))
    return APIEnvelope(status="success", data=serialized)
|
||||
|
||||
|
||||
@router.delete("/tickers/{symbol}", response_model=APIEnvelope)
async def delete_ticker(
    symbol: str,
    _user: User = Depends(require_access),
    db: AsyncSession = Depends(get_db),
):
    """Delete a ticker and all associated data (OHLCV, scores, etc.)."""
    await ticker_service.delete_ticker(db, symbol)
    # Nothing to return on delete; the envelope carries status only.
    return APIEnvelope(status="success", data=None)
|
||||
28
app/routers/trades.py
Normal file
28
app/routers/trades.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Trades router — R:R scanner trade setup endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, require_access
|
||||
from app.schemas.common import APIEnvelope
|
||||
from app.schemas.trade_setup import TradeSetupResponse
|
||||
from app.services.rr_scanner_service import get_trade_setups
|
||||
|
||||
router = APIRouter(tags=["trades"])
|
||||
|
||||
|
||||
@router.get("/trades", response_model=APIEnvelope)
async def list_trade_setups(
    direction: str | None = Query(
        None, description="Filter by direction: long or short"
    ),
    _user=Depends(require_access),
    db: AsyncSession = Depends(get_db),
) -> APIEnvelope:
    """Get all trade setups sorted by R:R desc, secondary composite desc.

    Optional direction filter (long/short).
    """
    rows = await get_trade_setups(db, direction=direction)
    serialized = []
    for row in rows:
        serialized.append(TradeSetupResponse(**row).model_dump(mode="json"))
    return APIEnvelope(status="success", data=serialized)
|
||||
59
app/routers/watchlist.py
Normal file
59
app/routers/watchlist.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Watchlist router — manage user's curated watchlist."""
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, require_access
|
||||
from app.models.user import User
|
||||
from app.schemas.common import APIEnvelope
|
||||
from app.schemas.watchlist import WatchlistEntryResponse
|
||||
from app.services.watchlist_service import (
|
||||
add_manual_entry,
|
||||
get_watchlist,
|
||||
remove_entry,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["watchlist"])
|
||||
|
||||
|
||||
@router.get("/watchlist", response_model=APIEnvelope)
async def list_watchlist(
    sort_by: str = Query(
        "composite",
        description=(
            "Sort by: composite, rr, or a dimension name "
            "(technical, sr_quality, sentiment, fundamental, momentum)"
        ),
    ),
    user: User = Depends(require_access),
    db: AsyncSession = Depends(get_db),
) -> APIEnvelope:
    """Get current user's watchlist with enriched data."""
    entries = await get_watchlist(db, user.id, sort_by=sort_by)
    serialized = [
        WatchlistEntryResponse(**entry).model_dump(mode="json")
        for entry in entries
    ]
    return APIEnvelope(status="success", data=serialized)
|
||||
|
||||
|
||||
@router.post("/watchlist/{symbol}", response_model=APIEnvelope)
async def add_to_watchlist(
    symbol: str,
    user: User = Depends(require_access),
    db: AsyncSession = Depends(get_db),
) -> APIEnvelope:
    """Add a manual entry to the watchlist."""
    normalized = symbol.strip().upper()
    entry = await add_manual_entry(db, user.id, symbol)
    payload = {"symbol": normalized, "entry_type": entry.entry_type}
    return APIEnvelope(status="success", data=payload)
|
||||
|
||||
|
||||
@router.delete("/watchlist/{symbol}", response_model=APIEnvelope)
async def remove_from_watchlist(
    symbol: str,
    user: User = Depends(require_access),
    db: AsyncSession = Depends(get_db),
) -> APIEnvelope:
    """Remove an entry from the current user's watchlist."""
    await remove_entry(db, user.id, symbol)
    # Deletions return no payload; callers check the status field.
    return APIEnvelope(status="success", data=None)
|
||||
437
app/scheduler.py
Normal file
437
app/scheduler.py
Normal file
@@ -0,0 +1,437 @@
|
||||
"""APScheduler job definitions and FastAPI lifespan integration.
|
||||
|
||||
Defines four scheduled jobs:
|
||||
- Data Collector (OHLCV fetch for all tickers)
|
||||
- Sentiment Collector (sentiment for all tickers)
|
||||
- Fundamental Collector (fundamentals for all tickers)
|
||||
- R:R Scanner (trade setup scan for all tickers)
|
||||
|
||||
Each job processes tickers independently, logs errors as structured JSON,
|
||||
handles rate limits by recording the last successful ticker, and checks
|
||||
SystemSetting for enabled/disabled state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import date, timedelta
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import async_session_factory
|
||||
from app.models.settings import SystemSetting
|
||||
from app.models.ticker import Ticker
|
||||
from app.providers.alpaca import AlpacaOHLCVProvider
|
||||
from app.providers.fmp import FMPFundamentalProvider
|
||||
from app.providers.gemini_sentiment import GeminiSentimentProvider
|
||||
from app.services import fundamental_service, ingestion_service, sentiment_service
|
||||
from app.services.rr_scanner_service import scan_all_tickers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level scheduler instance
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
# Track last successful ticker per job for rate-limit resume
|
||||
_last_successful: dict[str, str | None] = {
|
||||
"data_collector": None,
|
||||
"sentiment_collector": None,
|
||||
"fundamental_collector": None,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _log_job_error(job_name: str, ticker: str, error: Exception) -> None:
    """Log a per-ticker job error as one structured-JSON log line."""
    payload = {
        "event": "job_error",
        "job": job_name,
        "ticker": ticker,
        "error_type": type(error).__name__,
        "message": str(error),
    }
    logger.error(json.dumps(payload))
|
||||
|
||||
|
||||
async def _is_job_enabled(db: AsyncSession, job_name: str) -> bool:
    """Check SystemSetting for the job's enabled state.

    A missing setting row means the job was never explicitly disabled,
    so it defaults to enabled.
    """
    stmt = select(SystemSetting).where(
        SystemSetting.key == f"job_{job_name}_enabled"
    )
    setting = (await db.execute(stmt)).scalar_one_or_none()
    if setting is None:
        return True
    return setting.value.lower() == "true"
|
||||
|
||||
|
||||
async def _get_all_tickers(db: AsyncSession) -> list[str]:
    """Return every tracked ticker symbol, sorted alphabetically."""
    rows = await db.execute(select(Ticker.symbol).order_by(Ticker.symbol))
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
def _resume_tickers(symbols: list[str], job_name: str) -> list[str]:
    """Rotate *symbols* so processing resumes after the last successful ticker.

    If a previous run was rate-limited mid-list, the run restarts with the
    ticker following the last one that succeeded and wraps around, so every
    symbol is still visited exactly once per run. With no recorded
    checkpoint (or a checkpoint no longer in the list) the list is returned
    unchanged.
    """
    checkpoint = _last_successful.get(job_name)
    if checkpoint is None or checkpoint not in symbols:
        return symbols
    pivot = symbols.index(checkpoint) + 1
    # Tail first (tickers not yet processed), then wrap to the head.
    return symbols[pivot:] + symbols[:pivot]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Job: Data Collector (OHLCV)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def collect_ohlcv() -> None:
    """Fetch latest daily OHLCV for all tracked tickers.

    Uses AlpacaOHLCVProvider. Processes each ticker independently in its
    own DB session so one failure cannot poison subsequent tickers.
    On rate limit (a "partial" ingest result), records the last successful
    ticker so the next run resumes after it.
    """
    job_name = "data_collector"
    logger.info(json.dumps({"event": "job_start", "job": job_name}))

    # Short-lived session just for the enabled-check and ticker listing.
    async with async_session_factory() as db:
        if not await _is_job_enabled(db, job_name):
            logger.info(json.dumps({"event": "job_skipped", "job": job_name, "reason": "disabled"}))
            return

        symbols = await _get_all_tickers(db)
        if not symbols:
            logger.info(json.dumps({"event": "job_complete", "job": job_name, "tickers": 0}))
            return

    # Reorder for rate-limit resume
    symbols = _resume_tickers(symbols, job_name)

    # Build provider (skip if keys not configured)
    if not settings.alpaca_api_key or not settings.alpaca_api_secret:
        logger.warning(json.dumps({"event": "job_skipped", "job": job_name, "reason": "alpaca keys not configured"}))
        return

    try:
        provider = AlpacaOHLCVProvider(settings.alpaca_api_key, settings.alpaca_api_secret)
    except Exception as exc:
        logger.error(json.dumps({"event": "job_error", "job": job_name, "error_type": type(exc).__name__, "message": str(exc)}))
        return

    end_date = date.today()
    start_date = end_date - timedelta(days=5)  # Fetch last 5 days to catch up
    processed = 0

    for symbol in symbols:
        # Fresh session per ticker: errors in one ticker's transaction do
        # not affect the others.
        async with async_session_factory() as db:
            try:
                result = await ingestion_service.fetch_and_ingest(
                    db, provider, symbol, start_date=start_date, end_date=end_date,
                )
                # NOTE(review): the symbol is marked successful even when the
                # ingest was only "partial" (rate-limited), so the next run
                # resumes AFTER it and its remaining bars are not re-fetched
                # until a later catch-up window — confirm this is intended.
                _last_successful[job_name] = symbol
                processed += 1
                logger.info(json.dumps({
                    "event": "ticker_collected",
                    "job": job_name,
                    "ticker": symbol,
                    "status": result.status,
                    "records": result.records_ingested,
                }))
                if result.status == "partial":
                    # Rate limited — stop and resume next run
                    logger.warning(json.dumps({
                        "event": "rate_limited",
                        "job": job_name,
                        "ticker": symbol,
                        "processed": processed,
                    }))
                    return
            except Exception as exc:
                _log_job_error(job_name, symbol, exc)

    # Reset resume pointer on full completion
    _last_successful[job_name] = None
    logger.info(json.dumps({"event": "job_complete", "job": job_name, "tickers": processed}))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Job: Sentiment Collector
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def collect_sentiment() -> None:
    """Fetch sentiment for all tracked tickers via Gemini.

    Processes each ticker independently in its own DB session. On rate
    limit, records the last successful ticker so the next run resumes
    after it.
    """
    job_name = "sentiment_collector"
    logger.info(json.dumps({"event": "job_start", "job": job_name}))

    # Short-lived session just for the enabled-check and ticker listing.
    async with async_session_factory() as db:
        if not await _is_job_enabled(db, job_name):
            logger.info(json.dumps({"event": "job_skipped", "job": job_name, "reason": "disabled"}))
            return

        symbols = await _get_all_tickers(db)
        if not symbols:
            logger.info(json.dumps({"event": "job_complete", "job": job_name, "tickers": 0}))
            return

    # Resume after the last successful ticker if the prior run was cut short.
    symbols = _resume_tickers(symbols, job_name)

    if not settings.gemini_api_key:
        logger.warning(json.dumps({"event": "job_skipped", "job": job_name, "reason": "gemini key not configured"}))
        return

    try:
        provider = GeminiSentimentProvider(settings.gemini_api_key, settings.gemini_model)
    except Exception as exc:
        logger.error(json.dumps({"event": "job_error", "job": job_name, "error_type": type(exc).__name__, "message": str(exc)}))
        return

    processed = 0

    for symbol in symbols:
        # Fresh session per ticker so failures are isolated.
        async with async_session_factory() as db:
            try:
                data = await provider.fetch_sentiment(symbol)
                await sentiment_service.store_sentiment(
                    db,
                    symbol=symbol,
                    classification=data.classification,
                    confidence=data.confidence,
                    source=data.source,
                    timestamp=data.timestamp,
                )
                _last_successful[job_name] = symbol
                processed += 1
                logger.info(json.dumps({
                    "event": "ticker_collected",
                    "job": job_name,
                    "ticker": symbol,
                    "classification": data.classification,
                    "confidence": data.confidence,
                }))
            except Exception as exc:
                # Heuristic rate-limit detection by message substring —
                # the provider does not raise a dedicated exception type.
                msg = str(exc).lower()
                if "rate" in msg or "quota" in msg or "429" in msg:
                    logger.warning(json.dumps({
                        "event": "rate_limited",
                        "job": job_name,
                        "ticker": symbol,
                        "processed": processed,
                    }))
                    return
                _log_job_error(job_name, symbol, exc)

    # Full pass completed — clear the resume checkpoint.
    _last_successful[job_name] = None
    logger.info(json.dumps({"event": "job_complete", "job": job_name, "tickers": processed}))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Job: Fundamental Collector
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def collect_fundamentals() -> None:
    """Fetch fundamentals for all tracked tickers via FMP.

    Processes each ticker independently in its own DB session. On rate
    limit, records the last successful ticker so the next run resumes
    after it.
    """
    job_name = "fundamental_collector"
    logger.info(json.dumps({"event": "job_start", "job": job_name}))

    # Short-lived session just for the enabled-check and ticker listing.
    async with async_session_factory() as db:
        if not await _is_job_enabled(db, job_name):
            logger.info(json.dumps({"event": "job_skipped", "job": job_name, "reason": "disabled"}))
            return

        symbols = await _get_all_tickers(db)
        if not symbols:
            logger.info(json.dumps({"event": "job_complete", "job": job_name, "tickers": 0}))
            return

    # Resume after the last successful ticker if the prior run was cut short.
    symbols = _resume_tickers(symbols, job_name)

    if not settings.fmp_api_key:
        logger.warning(json.dumps({"event": "job_skipped", "job": job_name, "reason": "fmp key not configured"}))
        return

    try:
        provider = FMPFundamentalProvider(settings.fmp_api_key)
    except Exception as exc:
        logger.error(json.dumps({"event": "job_error", "job": job_name, "error_type": type(exc).__name__, "message": str(exc)}))
        return

    processed = 0

    for symbol in symbols:
        # Fresh session per ticker so failures are isolated.
        async with async_session_factory() as db:
            try:
                data = await provider.fetch_fundamentals(symbol)
                await fundamental_service.store_fundamental(
                    db,
                    symbol=symbol,
                    pe_ratio=data.pe_ratio,
                    revenue_growth=data.revenue_growth,
                    earnings_surprise=data.earnings_surprise,
                    market_cap=data.market_cap,
                )
                _last_successful[job_name] = symbol
                processed += 1
                logger.info(json.dumps({
                    "event": "ticker_collected",
                    "job": job_name,
                    "ticker": symbol,
                }))
            except Exception as exc:
                # Heuristic rate-limit detection by message substring
                # (note: narrower than the sentiment job — no "quota" check).
                msg = str(exc).lower()
                if "rate" in msg or "429" in msg:
                    logger.warning(json.dumps({
                        "event": "rate_limited",
                        "job": job_name,
                        "ticker": symbol,
                        "processed": processed,
                    }))
                    return
                _log_job_error(job_name, symbol, exc)

    # Full pass completed — clear the resume checkpoint.
    _last_successful[job_name] = None
    logger.info(json.dumps({"event": "job_complete", "job": job_name, "tickers": processed}))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Job: R:R Scanner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def scan_rr() -> None:
    """Scan all tickers for trade setups meeting the R:R threshold.

    Delegates to rr_scanner_service.scan_all_tickers, which already
    isolates per-ticker errors internally; this wrapper only handles the
    enabled-check and whole-job failure logging.
    """
    job_name = "rr_scanner"
    logger.info(json.dumps({"event": "job_start", "job": job_name}))

    async with async_session_factory() as db:
        if not await _is_job_enabled(db, job_name):
            logger.info(json.dumps({"event": "job_skipped", "job": job_name, "reason": "disabled"}))
            return

        try:
            setups = await scan_all_tickers(
                db, rr_threshold=settings.default_rr_threshold,
            )
            summary = {
                "event": "job_complete",
                "job": job_name,
                "setups_found": len(setups),
            }
            logger.info(json.dumps(summary))
        except Exception as exc:
            failure = {
                "event": "job_error",
                "job": job_name,
                "error_type": type(exc).__name__,
                "message": str(exc),
            }
            logger.error(json.dumps(failure))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frequency helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FREQUENCY_MAP: dict[str, dict[str, int]] = {
|
||||
"hourly": {"hours": 1},
|
||||
"daily": {"hours": 24},
|
||||
}
|
||||
|
||||
|
||||
def _parse_frequency(freq: str) -> dict[str, int]:
    """Convert a frequency string to APScheduler interval kwargs.

    Recognizes "hourly" and "daily" (case-insensitive); any other value
    falls back to daily. Returns a fresh dict per call — the previous
    implementation handed out the shared dicts stored in _FREQUENCY_MAP
    (and the same mutable default object for every unknown frequency), so
    a caller mutating its kwargs could corrupt every later lookup.
    """
    mapped = _FREQUENCY_MAP.get(freq.lower())
    if mapped is None:
        return {"hours": 24}
    return dict(mapped)  # copy so callers can't mutate the shared map entry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scheduler setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def configure_scheduler() -> None:
    """Add all jobs to the scheduler with configured intervals.

    Call this once before scheduler.start(). Removes any existing jobs first
    to ensure idempotency. Intervals come from the application settings;
    data/fundamental/rr frequencies are parsed via _parse_frequency, while
    the sentiment interval is a raw minutes value.
    """
    # Idempotency: drop any jobs from a previous configure call.
    scheduler.remove_all_jobs()

    # Data Collector — configurable frequency (default: hourly)
    ohlcv_interval = _parse_frequency(settings.data_collector_frequency)
    scheduler.add_job(
        collect_ohlcv,
        "interval",
        **ohlcv_interval,
        id="data_collector",
        name="Data Collector (OHLCV)",
        replace_existing=True,
    )

    # Sentiment Collector — default 30 min
    scheduler.add_job(
        collect_sentiment,
        "interval",
        minutes=settings.sentiment_poll_interval_minutes,
        id="sentiment_collector",
        name="Sentiment Collector",
        replace_existing=True,
    )

    # Fundamental Collector — configurable frequency (default: daily)
    fund_interval = _parse_frequency(settings.fundamental_fetch_frequency)
    scheduler.add_job(
        collect_fundamentals,
        "interval",
        **fund_interval,
        id="fundamental_collector",
        name="Fundamental Collector",
        replace_existing=True,
    )

    # R:R Scanner — configurable frequency (default: hourly)
    rr_interval = _parse_frequency(settings.rr_scan_frequency)
    scheduler.add_job(
        scan_rr,
        "interval",
        **rr_interval,
        id="rr_scanner",
        name="R:R Scanner",
        replace_existing=True,
    )

    # One structured log line summarizing every registered interval.
    logger.info(
        json.dumps({
            "event": "scheduler_configured",
            "jobs": {
                "data_collector": ohlcv_interval,
                "sentiment_collector": {"minutes": settings.sentiment_poll_interval_minutes},
                "fundamental_collector": fund_interval,
                "rr_scanner": rr_interval,
            },
        })
    )
|
||||
1
app/schemas/__init__.py
Normal file
1
app/schemas/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
41
app/schemas/admin.py
Normal file
41
app/schemas/admin.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Admin request/response schemas."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class UserManagement(BaseModel):
    """Schema for user access management."""
    # Whether the target user is granted access to gated endpoints.
    has_access: bool
|
||||
|
||||
|
||||
class PasswordReset(BaseModel):
    """Schema for resetting a user's password."""
    # Replacement password; same 6-char minimum as registration.
    new_password: str = Field(..., min_length=6)
|
||||
|
||||
|
||||
class CreateUserRequest(BaseModel):
    """Schema for admin-created user accounts."""
    # Login name; must be non-empty.
    username: str = Field(..., min_length=1)
    # Initial password; minimum 6 characters.
    password: str = Field(..., min_length=6)
    # Account role; the regex restricts it to "user" or "admin".
    role: str = Field(default="user", pattern=r"^(user|admin)$")
    # Defaults to no access; an admin can grant it separately.
    has_access: bool = False
|
||||
|
||||
|
||||
class RegistrationToggle(BaseModel):
    """Schema for toggling registration on/off."""
    # True to allow self-registration, False to disable it.
    enabled: bool
|
||||
|
||||
|
||||
class SystemSettingUpdate(BaseModel):
    """Schema for updating a system setting."""
    # New setting value; stored as a non-empty string.
    value: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
class DataCleanupRequest(BaseModel):
    """Schema for data cleanup — delete records older than N days."""
    # Retention cutoff in days; must be strictly positive.
    older_than_days: int = Field(..., gt=0)
|
||||
|
||||
|
||||
class JobToggle(BaseModel):
    """Schema for enabling/disabling a scheduled job."""
    # True to enable the job, False to disable it.
    enabled: bool
|
||||
18
app/schemas/auth.py
Normal file
18
app/schemas/auth.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Auth request/response schemas."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
    """Request body for user self-registration."""
    # Desired login name; must be non-empty.
    username: str = Field(..., min_length=1)
    # Chosen password; minimum 6 characters.
    password: str = Field(..., min_length=6)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
    """Request body for login; no length constraints on lookup credentials."""
    username: str
    password: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
    """Response body carrying an issued access token."""
    # The bearer token string.
    access_token: str
    # Fixed scheme identifier expected by OAuth2-style clients.
    token_type: str = "bearer"
|
||||
13
app/schemas/common.py
Normal file
13
app/schemas/common.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Shared API schemas used across all endpoints."""
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class APIEnvelope(BaseModel):
    """Standard JSON envelope for all API responses."""

    # Exactly "success" or "error"; no other values are accepted.
    status: Literal["success", "error"]
    # Endpoint-specific payload; None when there is nothing to return.
    data: Any | None = None
    # Error message; None on success.
    error: str | None = None
|
||||
18
app/schemas/fundamental.py
Normal file
18
app/schemas/fundamental.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Pydantic schemas for fundamental data endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class FundamentalResponse(BaseModel):
    """Envelope-ready fundamental data response.

    Every metric is optional — a field is None when the upstream provider
    did not supply it.
    """

    symbol: str
    pe_ratio: float | None = None
    revenue_growth: float | None = None
    earnings_surprise: float | None = None
    market_cap: float | None = None
    # When the data was last fetched; None if never fetched.
    fetched_at: datetime | None = None
|
||||
49
app/schemas/indicator.py
Normal file
49
app/schemas/indicator.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Pydantic schemas for technical indicator endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class IndicatorRequest(BaseModel):
    """Query parameters for indicator computation.

    All fields are optional; omitted values fall back to service defaults.
    """

    start_date: date | None = None
    end_date: date | None = None
    # Indicator lookback period (e.g. number of bars); None uses the default.
    period: int | None = None
|
||||
|
||||
|
||||
class IndicatorResult(BaseModel):
    """Raw indicator values plus normalized score."""

    indicator_type: str
    # Raw indicator output keyed by value name; shape depends on the indicator.
    values: dict[str, Any]
    # Normalized score, constrained to the 0–100 range.
    score: float = Field(ge=0, le=100)
    # Number of OHLCV bars the computation consumed.
    bars_used: int
|
||||
|
||||
|
||||
class IndicatorResponse(BaseModel):
    """Envelope-ready indicator response for a single symbol."""

    symbol: str
    indicator: IndicatorResult
|
||||
|
||||
|
||||
class EMACrossResult(BaseModel):
    """EMA cross signal details."""

    short_ema: float
    long_ema: float
    short_period: int
    long_period: int
    # bullish: short EMA above long; bearish: below; neutral: within tolerance.
    signal: Literal["bullish", "bearish", "neutral"]
|
||||
|
||||
|
||||
class EMACrossResponse(BaseModel):
    """Envelope-ready EMA cross response."""

    symbol: str
    ema_cross: EMACrossResult
|
||||
31
app/schemas/ohlcv.py
Normal file
31
app/schemas/ohlcv.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""OHLCV request/response schemas."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OHLCVCreate(BaseModel):
    """Payload for inserting a single OHLCV bar.

    All prices and volume are constrained non-negative via Field(ge=0).
    """

    symbol: str = Field(..., description="Ticker symbol (e.g. AAPL)")
    date: _dt.date = Field(..., description="Trading date (YYYY-MM-DD)")
    open: float = Field(..., ge=0, description="Opening price")
    high: float = Field(..., ge=0, description="High price")
    low: float = Field(..., ge=0, description="Low price")
    close: float = Field(..., ge=0, description="Closing price")
    volume: int = Field(..., ge=0, description="Trading volume")
|
||||
|
||||
|
||||
class OHLCVResponse(BaseModel):
    """A stored OHLCV bar, serialisable directly from an ORM row."""

    id: int
    ticker_id: int
    date: _dt.date
    open: float
    high: float
    low: float
    close: float
    volume: int
    created_at: _dt.datetime

    # Allow construction straight from SQLAlchemy ORM objects.
    model_config = {"from_attributes": True}
|
||||
52
app/schemas/score.py
Normal file
52
app/schemas/score.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Pydantic schemas for scoring endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DimensionScoreResponse(BaseModel):
    """A single dimension score."""

    dimension: str
    score: float
    # True when newer underlying data has arrived since the score was computed
    # (services set this flag on ingest — see store_fundamental).
    is_stale: bool
    computed_at: datetime | None = None
|
||||
|
||||
|
||||
class ScoreResponse(BaseModel):
    """Full score response for a ticker: composite + all dimensions."""

    symbol: str
    # Weighted composite over available dimensions; None if not computed yet.
    composite_score: float | None = None
    # Staleness flag for the composite score.
    composite_stale: bool = False
    # Dimension name -> weight used to build the composite.
    weights: dict[str, float] = {}
    dimensions: list[DimensionScoreResponse] = []
    # Dimensions for which this ticker has no stored score.
    missing_dimensions: list[str] = []
    computed_at: datetime | None = None
|
||||
|
||||
|
||||
class WeightUpdateRequest(BaseModel):
    """Request to update dimension weights."""

    weights: dict[str, float] = Field(
        ...,
        description="Dimension name → weight mapping. All weights must be positive.",
    )
|
||||
|
||||
|
||||
class RankingEntry(BaseModel):
    """A single entry in the rankings list."""

    symbol: str
    composite_score: float
    # Per-dimension breakdown backing the composite.
    dimensions: list[DimensionScoreResponse] = []
|
||||
|
||||
|
||||
class RankingResponse(BaseModel):
    """Rankings response: tickers sorted by composite score descending."""

    rankings: list[RankingEntry] = []
    # Weights in effect when the ranking was produced.
    weights: dict[str, float] = {}
|
||||
30
app/schemas/sentiment.py
Normal file
30
app/schemas/sentiment.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Pydantic schemas for sentiment endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SentimentScoreResult(BaseModel):
    """A single sentiment score record."""

    id: int
    classification: Literal["bullish", "bearish", "neutral"]
    # Classification confidence, 0-100.
    confidence: int = Field(ge=0, le=100)
    # Originating data source identifier.
    source: str
    timestamp: datetime
|
||||
|
||||
|
||||
class SentimentResponse(BaseModel):
    """Envelope-ready sentiment response."""

    symbol: str
    scores: list[SentimentScoreResult]
    # Number of score records returned.
    count: int
    dimension_score: float | None = Field(
        None, ge=0, le=100, description="Time-decay weighted sentiment dimension score"
    )
    # Size of the lookback window the scores were drawn from, in hours.
    lookback_hours: float
|
||||
27
app/schemas/sr_level.py
Normal file
27
app/schemas/sr_level.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Pydantic schemas for S/R level endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SRLevelResult(BaseModel):
    """A single support/resistance level."""

    id: int
    price_level: float
    type: Literal["support", "resistance"]
    # Level strength, 0-100.
    strength: int = Field(ge=0, le=100)
    # Which detection algorithm produced the level.
    detection_method: Literal["volume_profile", "pivot_point", "merged"]
    created_at: datetime
|
||||
|
||||
|
||||
class SRLevelResponse(BaseModel):
    """Envelope-ready S/R levels response."""

    symbol: str
    levels: list[SRLevelResult]
    # Number of levels returned.
    count: int
|
||||
17
app/schemas/ticker.py
Normal file
17
app/schemas/ticker.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Ticker request/response schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TickerCreate(BaseModel):
    """Payload for registering a new ticker."""

    symbol: str = Field(..., description="NASDAQ ticker symbol (e.g. AAPL)")
|
||||
|
||||
|
||||
class TickerResponse(BaseModel):
    """A stored ticker, serialisable directly from an ORM row."""

    id: int
    symbol: str
    created_at: datetime

    # Allow construction straight from SQLAlchemy ORM objects.
    model_config = {"from_attributes": True}
|
||||
21
app/schemas/trade_setup.py
Normal file
21
app/schemas/trade_setup.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Pydantic schemas for trade setup endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TradeSetupResponse(BaseModel):
    """A single trade setup detected by the R:R scanner."""

    id: int
    symbol: str
    # Trade direction — presumably "long"/"short"; confirm against the scanner.
    direction: str
    entry_price: float
    stop_loss: float
    target: float
    # Reward-to-risk ratio of the setup.
    rr_ratio: float
    composite_score: float
    detected_at: datetime
|
||||
36
app/schemas/watchlist.py
Normal file
36
app/schemas/watchlist.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Pydantic schemas for watchlist endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SRLevelSummary(BaseModel):
    """Compact SR level for watchlist entry."""

    price_level: float
    type: Literal["support", "resistance"]
    # Level strength, 0-100.
    strength: int = Field(ge=0, le=100)
|
||||
|
||||
|
||||
class DimensionScoreSummary(BaseModel):
    """Compact dimension score for watchlist entry."""

    dimension: str
    score: float
|
||||
|
||||
|
||||
class WatchlistEntryResponse(BaseModel):
    """A single watchlist entry with enriched data."""

    symbol: str
    # Whether the entry was added automatically or manually.
    entry_type: Literal["auto", "manual"]
    composite_score: float | None = None
    dimensions: list[DimensionScoreSummary] = []
    # Risk:reward ratio and its trade direction, when a setup exists.
    rr_ratio: float | None = None
    rr_direction: str | None = None
    sr_levels: list[SRLevelSummary] = []
    added_at: datetime
|
||||
1
app/services/__init__.py
Normal file
1
app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
238
app/services/admin_service.py
Normal file
238
app/services/admin_service.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Admin service: user management, system settings, data cleanup, job control."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from passlib.hash import bcrypt
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import DuplicateError, NotFoundError, ValidationError
|
||||
from app.models.fundamental import FundamentalData
|
||||
from app.models.ohlcv import OHLCVRecord
|
||||
from app.models.sentiment import SentimentScore
|
||||
from app.models.settings import SystemSetting
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def list_users(db: AsyncSession) -> list[User]:
    """Return every user account, ordered by primary key."""
    rows = await db.execute(select(User).order_by(User.id))
    return [account for account in rows.scalars()]
|
||||
|
||||
|
||||
async def create_user(
    db: AsyncSession,
    username: str,
    password: str,
    role: str = "user",
    has_access: bool = False,
) -> User:
    """Create a new user account (admin action).

    Raises
    ------
    DuplicateError
        If *username* is already taken.
    """
    existing = (
        await db.execute(select(User).where(User.username == username))
    ).scalar_one_or_none()
    if existing is not None:
        raise DuplicateError(f"Username already exists: {username}")

    new_user = User(
        username=username,
        password_hash=bcrypt.hash(password),  # never store the plaintext
        role=role,
        has_access=has_access,
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)
    return new_user
|
||||
|
||||
|
||||
async def set_user_access(db: AsyncSession, user_id: int, has_access: bool) -> User:
    """Grant or revoke API access for a user.

    Raises NotFoundError if *user_id* does not exist.
    """
    target = (
        await db.execute(select(User).where(User.id == user_id))
    ).scalar_one_or_none()
    if target is None:
        raise NotFoundError(f"User not found: {user_id}")

    target.has_access = has_access
    await db.commit()
    await db.refresh(target)
    return target
|
||||
|
||||
|
||||
async def reset_password(db: AsyncSession, user_id: int, new_password: str) -> User:
    """Reset a user's password; the new password is stored bcrypt-hashed.

    Raises NotFoundError if *user_id* does not exist.
    """
    target = (
        await db.execute(select(User).where(User.id == user_id))
    ).scalar_one_or_none()
    if target is None:
        raise NotFoundError(f"User not found: {user_id}")

    target.password_hash = bcrypt.hash(new_password)
    await db.commit()
    await db.refresh(target)
    return target
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registration toggle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def toggle_registration(db: AsyncSession, enabled: bool) -> SystemSetting:
    """Enable or disable user registration via SystemSetting.

    Thin wrapper over :func:`update_setting` so the select/insert-or-update
    logic for settings lives in exactly one place (the previous version
    duplicated that upsert inline).

    Parameters
    ----------
    enabled:
        Stored as the lowercase string "true"/"false".
    """
    return await update_setting(db, "registration_enabled", str(enabled).lower())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System settings CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def list_settings(db: AsyncSession) -> list[SystemSetting]:
    """Return all system settings, sorted by key."""
    rows = await db.execute(select(SystemSetting).order_by(SystemSetting.key))
    return [entry for entry in rows.scalars()]
|
||||
|
||||
|
||||
async def update_setting(db: AsyncSession, key: str, value: str) -> SystemSetting:
    """Create or update a system setting (upsert by key)."""
    row = await db.execute(select(SystemSetting).where(SystemSetting.key == key))
    setting = row.scalar_one_or_none()

    if setting is not None:
        setting.value = value
    else:
        setting = SystemSetting(key=key, value=value)
        db.add(setting)

    await db.commit()
    await db.refresh(setting)
    return setting
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def cleanup_data(db: AsyncSession, older_than_days: int) -> dict[str, int]:
    """Delete OHLCV, sentiment, and fundamental records older than N days.

    Preserves tickers, users, and latest scores. Returns a dict with counts
    of deleted records per table.
    """
    cutoff = datetime.now(timezone.utc) - timedelta(days=older_than_days)

    # (label, statement) pairs. OHLCV stores a plain date, the other two
    # store datetimes. NOTE(review): assumes the datetime columns are
    # timezone-aware, since cutoff is aware — confirm against the models.
    statements = [
        ("ohlcv", delete(OHLCVRecord).where(OHLCVRecord.date < cutoff.date())),
        ("sentiment", delete(SentimentScore).where(SentimentScore.timestamp < cutoff)),
        ("fundamentals", delete(FundamentalData).where(FundamentalData.fetched_at < cutoff)),
    ]

    deleted: dict[str, int] = {}
    for label, stmt in statements:
        result = await db.execute(stmt)
        deleted[label] = result.rowcount  # type: ignore[assignment]

    await db.commit()
    return deleted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Job control (placeholder — scheduler is Task 12.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Job identifiers accepted by trigger_job / toggle_job / list_jobs.
VALID_JOB_NAMES = {"data_collector", "sentiment_collector", "fundamental_collector", "rr_scanner"}

# Human-readable display names, keyed by job identifier.
JOB_LABELS = {
    "data_collector": "Data Collector (OHLCV)",
    "sentiment_collector": "Sentiment Collector",
    "fundamental_collector": "Fundamental Collector",
    "rr_scanner": "R:R Scanner",
}
|
||||
|
||||
|
||||
async def list_jobs(db: AsyncSession) -> list[dict]:
    """Return status of all scheduled jobs."""
    from app.scheduler import scheduler

    statuses: list[dict] = []
    for job_name in sorted(VALID_JOB_NAMES):
        # The per-job enable flag lives in SystemSetting; a missing row
        # means the job is enabled by default.
        row = await db.execute(
            select(SystemSetting).where(SystemSetting.key == f"job_{job_name}_enabled")
        )
        flag = row.scalar_one_or_none()
        is_enabled = True if flag is None else flag.value == "true"

        registered_job = scheduler.get_job(job_name)
        next_run_at = (
            registered_job.next_run_time.isoformat()
            if registered_job and registered_job.next_run_time
            else None
        )

        statuses.append(
            {
                "name": job_name,
                "label": JOB_LABELS.get(job_name, job_name),
                "enabled": is_enabled,
                "next_run_at": next_run_at,
                "registered": registered_job is not None,
            }
        )

    return statuses
|
||||
|
||||
|
||||
async def trigger_job(db: AsyncSession, job_name: str) -> dict[str, str]:
    """Trigger a manual job run via the scheduler.

    Runs the job immediately (in addition to its regular schedule).

    Raises
    ------
    ValidationError
        If *job_name* is not one of ``VALID_JOB_NAMES``.
    """
    if job_name not in VALID_JOB_NAMES:
        raise ValidationError(f"Unknown job: {job_name}. Valid jobs: {', '.join(sorted(VALID_JOB_NAMES))}")

    from app.scheduler import scheduler

    job = scheduler.get_job(job_name)
    if job is None:
        return {"job": job_name, "status": "not_found", "message": f"Job '{job_name}' is not registered in the scheduler"}

    # One modify() is sufficient: setting next_run_time to "now" makes the
    # scheduler fire the job on its next wakeup. The previous code first
    # called modify(next_run_time=None) — which *pauses* the job — and then
    # re-imported datetime/timezone, shadowing the module-level imports;
    # both steps were redundant.
    job.modify(next_run_time=datetime.now(timezone.utc))

    return {"job": job_name, "status": "triggered", "message": f"Job '{job_name}' triggered for immediate execution"}
|
||||
|
||||
|
||||
async def toggle_job(db: AsyncSession, job_name: str, enabled: bool) -> SystemSetting:
    """Enable or disable a scheduled job by storing state in SystemSetting.

    Actual scheduler integration happens in Task 12.1.

    Raises ValidationError for a job name outside VALID_JOB_NAMES.
    """
    if job_name not in VALID_JOB_NAMES:
        raise ValidationError(f"Unknown job: {job_name}. Valid jobs: {', '.join(sorted(VALID_JOB_NAMES))}")

    return await update_setting(db, f"job_{job_name}_enabled", str(enabled).lower())
|
||||
66
app/services/auth_service.py
Normal file
66
app/services/auth_service.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Auth service: registration, login, and JWT token generation."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from jose import jwt
|
||||
from passlib.hash import bcrypt
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.dependencies import JWT_ALGORITHM
|
||||
from app.exceptions import AuthenticationError, AuthorizationError, DuplicateError
|
||||
from app.models.settings import SystemSetting
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
async def register(db: AsyncSession, username: str, password: str) -> User:
    """Register a new user.

    Checks if registration is enabled via SystemSetting, rejects duplicates,
    and creates a user with role='user' and has_access=False.

    Raises
    ------
    AuthorizationError
        If registration has been disabled.
    DuplicateError
        If *username* is already taken.
    """
    # Registration toggle: an absent setting means registration is open.
    toggle = (
        await db.execute(
            select(SystemSetting).where(SystemSetting.key == "registration_enabled")
        )
    ).scalar_one_or_none()
    if toggle is not None and toggle.value.lower() == "false":
        raise AuthorizationError("Registration is closed")

    # Reject duplicate usernames.
    existing = (
        await db.execute(select(User).where(User.username == username))
    ).scalar_one_or_none()
    if existing is not None:
        raise DuplicateError(f"Username already exists: {username}")

    new_user = User(
        username=username,
        password_hash=bcrypt.hash(password),
        role="user",
        has_access=False,  # access must be granted by an admin
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)
    return new_user
|
||||
|
||||
|
||||
async def login(db: AsyncSession, username: str, password: str) -> str:
    """Authenticate user and return a JWT access token.

    The same error message is used for an unknown username and a wrong
    password so the response does not leak which field is incorrect.
    """
    row = await db.execute(select(User).where(User.username == username))
    user = row.scalar_one_or_none()

    # Single combined check keeps the failure indistinguishable to callers.
    if user is None or not bcrypt.verify(password, user.password_hash):
        raise AuthenticationError("Invalid credentials")

    expires_at = datetime.now(timezone.utc) + timedelta(minutes=settings.jwt_expiry_minutes)
    claims = {
        "sub": str(user.id),
        "role": user.role,
        "exp": expires_at,
    }
    return jwt.encode(claims, settings.jwt_secret, algorithm=JWT_ALGORITHM)
|
||||
101
app/services/fundamental_service.py
Normal file
101
app/services/fundamental_service.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Fundamental data service.
|
||||
|
||||
Stores fundamental data (P/E, revenue growth, earnings surprise, market cap)
|
||||
and marks the fundamental dimension score as stale on new data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError
|
||||
from app.models.fundamental import FundamentalData
|
||||
from app.models.score import DimensionScore
|
||||
from app.models.ticker import Ticker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Look up a ticker by symbol (trimmed and upper-cased first).

    Raises NotFoundError if the symbol is unknown.
    """
    sym = symbol.strip().upper()
    row = await db.execute(select(Ticker).where(Ticker.symbol == sym))
    found = row.scalar_one_or_none()
    if found is None:
        raise NotFoundError(f"Ticker not found: {sym}")
    return found
|
||||
|
||||
|
||||
async def store_fundamental(
    db: AsyncSession,
    symbol: str,
    pe_ratio: float | None = None,
    revenue_growth: float | None = None,
    earnings_surprise: float | None = None,
    market_cap: float | None = None,
) -> FundamentalData:
    """Store or update fundamental data for a ticker.

    Keeps a single latest snapshot per ticker. On new data, marks the
    fundamental dimension score as stale (if one exists).

    Raises NotFoundError (via _get_ticker) for an unknown symbol.
    """
    ticker = await _get_ticker(db, symbol)
    now = datetime.now(timezone.utc)

    # Upsert: one snapshot per ticker, overwritten on each refresh.
    row = await db.execute(
        select(FundamentalData).where(FundamentalData.ticker_id == ticker.id)
    )
    record = row.scalar_one_or_none()

    if record is None:
        record = FundamentalData(
            ticker_id=ticker.id,
            pe_ratio=pe_ratio,
            revenue_growth=revenue_growth,
            earnings_surprise=earnings_surprise,
            market_cap=market_cap,
            fetched_at=now,
        )
        db.add(record)
    else:
        record.pe_ratio = pe_ratio
        record.revenue_growth = revenue_growth
        record.earnings_surprise = earnings_surprise
        record.market_cap = market_cap
        record.fetched_at = now

    # Fresh fundamentals invalidate the cached fundamental dimension score.
    # TODO: Use DimensionScore service when built
    dim_row = await db.execute(
        select(DimensionScore).where(
            DimensionScore.ticker_id == ticker.id,
            DimensionScore.dimension == "fundamental",
        )
    )
    stale_target = dim_row.scalar_one_or_none()
    if stale_target is not None:
        stale_target.is_stale = True

    await db.commit()
    await db.refresh(record)
    return record
|
||||
|
||||
|
||||
async def get_fundamental(
    db: AsyncSession,
    symbol: str,
) -> FundamentalData | None:
    """Return the latest fundamental snapshot for *symbol*, or None.

    Raises NotFoundError (via _get_ticker) for an unknown symbol.
    """
    ticker = await _get_ticker(db, symbol)
    row = await db.execute(
        select(FundamentalData).where(FundamentalData.ticker_id == ticker.id)
    )
    return row.scalar_one_or_none()
|
||||
509
app/services/indicator_service.py
Normal file
509
app/services/indicator_service.py
Normal file
@@ -0,0 +1,509 @@
|
||||
"""Technical Analysis service.
|
||||
|
||||
Computes indicators from OHLCV data. Each indicator function is a pure
|
||||
function that takes a list of OHLCV-like records and returns raw values
|
||||
plus a normalized 0-100 score. The service layer handles DB fetching,
|
||||
caching, and minimum-data validation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.cache import indicator_cache
|
||||
from app.exceptions import ValidationError
|
||||
from app.services.price_service import query_ohlcv
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimum data requirements per indicator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Minimum number of OHLCV bars each indicator needs before it can be computed.
MIN_BARS: dict[str, int] = {
    "adx": 28,  # 2 * default period (14) — see compute_adx
    "ema": 0,  # dynamic: period + 1
    "rsi": 15,
    "atr": 15,
    "volume_profile": 20,
    "pivot_points": 5,
}

# Default lookback period per indicator when the caller supplies none.
DEFAULT_PERIODS: dict[str, int] = {
    "adx": 14,
    "ema": 20,
    "rsi": 14,
    "atr": 14,
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure computation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ema(values: list[float], period: int) -> list[float]:
|
||||
"""Compute EMA series. Returns list same length as *values*."""
|
||||
if len(values) < period:
|
||||
return []
|
||||
k = 2.0 / (period + 1)
|
||||
ema_vals: list[float] = [sum(values[:period]) / period]
|
||||
for v in values[period:]:
|
||||
ema_vals.append(v * k + ema_vals[-1] * (1 - k))
|
||||
return ema_vals
|
||||
|
||||
|
||||
def compute_adx(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    period: int = 14,
) -> dict[str, Any]:
    """Compute ADX (Average Directional Index) from high/low/close arrays.

    Wilder's method: running-sum smoothing of TR/+DM/-DM, DI and DX per
    bar, then ADX as the Wilder *average* of DX.

    Parameters
    ----------
    highs, lows, closes:
        Equal-length price arrays, oldest bar first.
    period:
        Smoothing period (default 14); requires at least ``2 * period`` bars.

    Returns dict with ``adx``, ``plus_di``, ``minus_di``, ``score``
    (score = ADX clamped to 0-100).

    Raises
    ------
    ValidationError
        If fewer than ``2 * period`` bars are supplied.
    """
    n = len(closes)
    if n < 2 * period:
        raise ValidationError(
            f"ADX requires at least {2 * period} bars, got {n}"
        )

    # True Range, +DM, -DM per bar transition.
    tr_list: list[float] = []
    plus_dm: list[float] = []
    minus_dm: list[float] = []
    for i in range(1, n):
        h, l, pc = highs[i], lows[i], closes[i - 1]
        tr_list.append(max(h - l, abs(h - pc), abs(l - pc)))
        up = highs[i] - highs[i - 1]
        down = lows[i - 1] - lows[i]
        # Only the dominant directional move counts, and only if positive.
        plus_dm.append(up if up > down and up > 0 else 0.0)
        minus_dm.append(down if down > up and down > 0 else 0.0)

    # Wilder running-sum smoothing for TR/+DM/-DM; the DI ratio cancels
    # the sum-vs-average scale factor, so sums are fine here.
    def _smooth(vals: list[float], p: int) -> list[float]:
        s = [sum(vals[:p])]
        for v in vals[p:]:
            s.append(s[-1] - s[-1] / p + v)
        return s

    s_tr = _smooth(tr_list, period)
    s_plus = _smooth(plus_dm, period)
    s_minus = _smooth(minus_dm, period)

    # +DI, -DI, DX per smoothed bar; tiny epsilons guard divide-by-zero.
    dx_list: list[float] = []
    plus_di_last = 0.0
    minus_di_last = 0.0
    for i in range(len(s_tr)):
        tr_v = s_tr[i] if s_tr[i] != 0 else 1e-10
        pdi = 100.0 * s_plus[i] / tr_v
        mdi = 100.0 * s_minus[i] / tr_v
        denom = pdi + mdi if (pdi + mdi) != 0 else 1e-10
        dx_list.append(100.0 * abs(pdi - mdi) / denom)
        plus_di_last = pdi
        minus_di_last = mdi

    # ADX = Wilder AVERAGE of DX. BUGFIX: the previous code reused the
    # running-sum _smooth here, which inflated ADX by roughly a factor of
    # `period` (a sum of ~period DX values instead of their mean), pinning
    # the score at 100 for almost any trending input.
    if len(dx_list) < period:
        adx_val = sum(dx_list) / len(dx_list) if dx_list else 0.0
    else:
        adx_val = sum(dx_list[:period]) / period
        for dx in dx_list[period:]:
            adx_val = (adx_val * (period - 1) + dx) / period

    # DX is 0-100 by construction, so this clamp is belt-and-braces.
    score = max(0.0, min(100.0, adx_val))

    return {
        "adx": round(adx_val, 4),
        "plus_di": round(plus_di_last, 4),
        "minus_di": round(minus_di_last, 4),
        "score": round(score, 4),
    }
|
||||
|
||||
|
||||
def compute_ema(
    closes: list[float],
    period: int = 20,
) -> dict[str, Any]:
    """Compute EMA for *closes* with given *period*.

    Score: normalized position of the latest close relative to the EMA —
    50 means at the EMA, 100 means 5%+ above, 0 means 5%+ below.

    Raises ValidationError when fewer than period + 1 bars are supplied.
    """
    required = period + 1
    if len(closes) < required:
        raise ValidationError(
            f"EMA({period}) requires at least {required} bars, got {len(closes)}"
        )

    ema_series = _ema(closes, period)
    current_ema = ema_series[-1]
    current_close = closes[-1]

    # Each 1% deviation from the EMA moves the score by 10 points from 50.
    if current_ema == 0:
        pct_above = 0.0
    else:
        pct_above = (current_close - current_ema) / current_ema * 100.0
    score = max(0.0, min(100.0, 50.0 + pct_above * 10.0))

    return {
        "ema": round(current_ema, 4),
        "period": period,
        "latest_close": round(current_close, 4),
        "score": round(score, 4),
    }
|
||||
|
||||
|
||||
def compute_rsi(
    closes: list[float],
    period: int = 14,
) -> dict[str, Any]:
    """Compute Wilder's RSI. Score = RSI value (already 0-100).

    Raises ValidationError when fewer than period + 1 bars are supplied.
    """
    count = len(closes)
    if count < period + 1:
        raise ValidationError(
            f"RSI requires at least {period + 1} bars, got {count}"
        )

    # Split each bar-to-bar change into its gain / loss component.
    gains: list[float] = []
    losses: list[float] = []
    for prev, curr in zip(closes, closes[1:]):
        change = curr - prev
        gains.append(change if change > 0 else 0.0)
        losses.append(-change if change < 0 else 0.0)

    # Seed with simple averages, then apply Wilder smoothing.
    avg_gain = sum(gains[:period]) / period
    avg_loss = sum(losses[:period]) / period
    for gain, loss in zip(gains[period:], losses[period:]):
        avg_gain = (avg_gain * (period - 1) + gain) / period
        avg_loss = (avg_loss * (period - 1) + loss) / period

    if avg_loss == 0:
        # No losses in the window: RSI is pinned at 100 by convention.
        rsi = 100.0
    else:
        rs = avg_gain / avg_loss
        rsi = 100.0 - 100.0 / (1.0 + rs)

    return {
        "rsi": round(rsi, 4),
        "period": period,
        "score": round(max(0.0, min(100.0, rsi)), 4),
    }
|
||||
|
||||
|
||||
def compute_atr(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    period: int = 14,
) -> dict[str, Any]:
    """Compute Wilder-smoothed ATR.

    Score = normalized inverse of ATR as a percentage of price: 0% ATR
    scores 100, 10%+ ATR scores 0 (lower volatility = higher score).

    Raises ValidationError when fewer than period + 1 bars are supplied.
    """
    count = len(closes)
    if count < period + 1:
        raise ValidationError(
            f"ATR requires at least {period + 1} bars, got {count}"
        )

    # True range per bar: widest of high-low, |high-prev close|, |low-prev close|.
    true_ranges = [
        max(
            highs[i] - lows[i],
            abs(highs[i] - closes[i - 1]),
            abs(lows[i] - closes[i - 1]),
        )
        for i in range(1, count)
    ]

    # Wilder smoothing: SMA seed, then recursive blend.
    atr = sum(true_ranges[:period]) / period
    for tr in true_ranges[period:]:
        atr = (atr * (period - 1) + tr) / period

    latest_close = closes[-1]
    atr_pct = 0.0 if latest_close == 0 else atr / latest_close * 100.0
    score = max(0.0, min(100.0, 100.0 - atr_pct * 10.0))

    return {
        "atr": round(atr, 4),
        "period": period,
        "atr_percent": round(atr_pct, 4),
        "score": round(score, 4),
    }
|
||||
|
||||
|
||||
def compute_volume_profile(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    volumes: list[int],
    num_bins: int = 20,
) -> dict[str, Any]:
    """Compute Volume Profile: POC, Value Area, HVN, LVN.

    The full price range [min(lows), max(highs)] is split into *num_bins*
    equal-width bins and each bar's volume is credited to every bin its
    high-low range overlaps.

    Score: proximity of latest close to POC (closer = higher).

    Raises ValidationError when fewer than 20 bars are supplied.
    """
    n = len(closes)
    if n < 20:
        raise ValidationError(
            f"Volume Profile requires at least 20 bars, got {n}"
        )

    price_min = min(lows)
    price_max = max(highs)
    if price_max == price_min:
        price_max = price_min + 1.0  # avoid zero-width range

    bin_width = (price_max - price_min) / num_bins
    bins: list[float] = [0.0] * num_bins
    # Bin midpoints; these are the price levels reported for POC/HVN/LVN.
    bin_prices: list[float] = [
        price_min + (i + 0.5) * bin_width for i in range(num_bins)
    ]

    for i in range(n):
        # Distribute volume across bins the bar spans
        bar_low, bar_high = lows[i], highs[i]
        for b in range(num_bins):
            bl = price_min + b * bin_width
            bh = bl + bin_width
            if bar_high >= bl and bar_low <= bh:
                # NOTE: the bar's FULL volume is added to every overlapped
                # bin (not apportioned), so wide-range bars are over-weighted
                # and total binned volume can exceed the raw volume sum.
                bins[b] += volumes[i]

    total_vol = sum(bins)
    if total_vol == 0:
        total_vol = 1.0  # guard divisions below

    # POC = bin with highest volume
    poc_idx = bins.index(max(bins))
    poc = round(bin_prices[poc_idx], 4)

    # Value Area: 70% of total volume around POC
    # (greedy: take bins in descending-volume order until 70% is reached,
    # then report the min..max span of the selected bins).
    sorted_bins = sorted(range(num_bins), key=lambda i: bins[i], reverse=True)
    va_vol = 0.0
    va_indices: list[int] = []
    for idx in sorted_bins:
        va_vol += bins[idx]
        va_indices.append(idx)
        if va_vol >= total_vol * 0.7:
            break
    va_low = round(price_min + min(va_indices) * bin_width, 4)
    va_high = round(price_min + (max(va_indices) + 1) * bin_width, 4)

    # HVN / LVN: bins above/below average volume
    avg_vol = total_vol / num_bins
    hvn = [round(bin_prices[i], 4) for i in range(num_bins) if bins[i] > avg_vol]
    lvn = [round(bin_prices[i], 4) for i in range(num_bins) if bins[i] < avg_vol]

    # Score: proximity of latest close to POC
    latest = closes[-1]
    price_range = price_max - price_min
    if price_range == 0:
        score = 100.0
    else:
        dist_pct = abs(latest - poc) / price_range
        score = max(0.0, min(100.0, 100.0 * (1.0 - dist_pct)))

    return {
        "poc": poc,
        "value_area_low": va_low,
        "value_area_high": va_high,
        "hvn": hvn,
        "lvn": lvn,
        "score": round(score, 4),
    }
|
||||
|
||||
|
||||
def compute_pivot_points(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    window: int = 2,
) -> dict[str, Any]:
    """Detect swing highs/lows as pivot points.

    A swing high at index *i* means highs[i] >= all highs in
    [i-window, i+window]; swing lows are symmetric with <=.

    Parameters
    ----------
    highs, lows, closes:
        Parallel OHLC arrays, oldest bar first.
    window:
        Number of bars on each side that a pivot must dominate.

    Returns
    -------
    dict with ``swing_highs``, ``swing_lows``, ``pivot_count`` and a
    0-100 ``score`` based on the fraction of pivots within 2% of the
    latest close.

    Raises
    ------
    ValidationError
        If fewer than ``2 * window + 1`` bars are supplied — the minimum
        needed to evaluate a single pivot candidate.
    """
    n = len(closes)
    # Bug fix: the old check hard-coded 5 bars, which is only correct for
    # the default window=2. A pivot candidate needs `window` bars on each
    # side, so the true minimum is 2*window + 1 (unchanged for window=2).
    min_bars = 2 * window + 1
    if n < min_bars:
        raise ValidationError(
            f"Pivot Points requires at least {min_bars} bars, got {n}"
        )

    swing_highs: list[float] = []
    swing_lows: list[float] = []

    for i in range(window, n - window):
        neighbourhood = range(i - window, i + window + 1)
        # Swing high: bar dominates every high in its neighbourhood.
        if all(highs[i] >= highs[j] for j in neighbourhood):
            swing_highs.append(round(highs[i], 4))
        # Swing low: bar undercuts every low in its neighbourhood.
        if all(lows[i] <= lows[j] for j in neighbourhood):
            swing_lows.append(round(lows[i], 4))

    all_pivots = swing_highs + swing_lows
    latest = closes[-1]

    # Score: fraction of pivots within 2% of current price, scaled 0-100.
    if not all_pivots or latest == 0:
        score = 0.0
    else:
        near = sum(1 for p in all_pivots if abs(p - latest) / latest <= 0.02)
        score = min(100.0, (near / len(all_pivots)) * 100.0)

    return {
        "swing_highs": swing_highs,
        "swing_lows": swing_lows,
        "pivot_count": len(all_pivots),
        "score": round(score, 4),
    }
|
||||
|
||||
|
||||
def compute_ema_cross(
    closes: list[float],
    short_period: int = 20,
    long_period: int = 50,
    tolerance: float = 1e-6,
) -> dict[str, Any]:
    """Classify the relationship between a short and a long EMA.

    The signal is "bullish" when the short EMA sits above the long EMA,
    "bearish" when below, and "neutral" when the two are within
    *tolerance* of each other.

    Raises ValidationError when fewer than ``long_period + 1`` closes
    are supplied.
    """
    min_bars = long_period + 1
    if len(closes) < min_bars:
        raise ValidationError(
            f"EMA Cross requires at least {min_bars} bars, got {len(closes)}"
        )

    # Only the most recent value of each EMA series matters here.
    short_ema = _ema(closes, short_period)[-1]
    long_ema = _ema(closes, long_period)[-1]

    spread = short_ema - long_ema
    if abs(spread) <= tolerance:
        signal = "neutral"
    else:
        signal = "bullish" if spread > 0 else "bearish"

    return {
        "short_ema": round(short_ema, 4),
        "long_ema": round(long_ema, 4),
        "short_period": short_period,
        "long_period": long_period,
        "signal": signal,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Supported indicator types
# ---------------------------------------------------------------------------

# Indicator names accepted by get_indicator(); membership is checked after
# lower-casing the caller-supplied type. EMA-cross is served by its own
# endpoint (get_ema_cross) and is deliberately not listed here.
INDICATOR_TYPES = {"adx", "ema", "rsi", "atr", "volume_profile", "pivot_points"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service-layer functions (DB + cache + validation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _extract_ohlcv(records: list) -> tuple[
|
||||
list[float], list[float], list[float], list[float], list[int]
|
||||
]:
|
||||
"""Extract parallel arrays from OHLCVRecord list."""
|
||||
opens = [float(r.open) for r in records]
|
||||
highs = [float(r.high) for r in records]
|
||||
lows = [float(r.low) for r in records]
|
||||
closes = [float(r.close) for r in records]
|
||||
volumes = [int(r.volume) for r in records]
|
||||
return opens, highs, lows, closes, volumes
|
||||
|
||||
|
||||
async def get_indicator(
    db: AsyncSession,
    symbol: str,
    indicator_type: str,
    start_date: date | None = None,
    end_date: date | None = None,
    period: int | None = None,
) -> dict[str, Any]:
    """Compute a single indicator for *symbol*.

    Checks cache first; stores result after computing.

    Parameters
    ----------
    db:
        Open async database session used to load OHLCV rows.
    symbol:
        Ticker symbol; upper-cased for the cache key.
    indicator_type:
        One of INDICATOR_TYPES (case-insensitive).
    start_date, end_date:
        Optional date range passed through to query_ohlcv.
    period:
        Optional look-back override; ignored for volume_profile and
        pivot_points, which take no period here.

    Returns
    -------
    dict with ``indicator_type``, ``values`` (every result key except
    "score"), ``score``, and ``bars_used``.

    Raises
    ------
    ValidationError
        If indicator_type is unknown, or the compute_* helper rejects
        the data (e.g. too few bars).
    """
    indicator_type = indicator_type.lower()
    if indicator_type not in INDICATOR_TYPES:
        raise ValidationError(
            f"Unknown indicator type: {indicator_type}. "
            f"Supported: {', '.join(sorted(INDICATOR_TYPES))}"
        )

    # Cache key includes the date range so different windows don't collide.
    # NOTE(review): `period` is NOT part of the key, so two calls with the
    # same range but different periods share one cache slot — confirm that
    # is intended.
    cache_key = (symbol.upper(), str(start_date), str(end_date), indicator_type)
    cached = indicator_cache.get(cache_key)
    if cached is not None:
        return cached

    records = await query_ohlcv(db, symbol, start_date, end_date)
    _, highs, lows, closes, volumes = _extract_ohlcv(records)
    n = len(records)

    # Dispatch to the matching compute_* helper. `period or DEFAULT...`
    # falls back to the default when period is None (or 0).
    if indicator_type == "adx":
        p = period or DEFAULT_PERIODS["adx"]
        result = compute_adx(highs, lows, closes, period=p)
    elif indicator_type == "ema":
        p = period or DEFAULT_PERIODS["ema"]
        result = compute_ema(closes, period=p)
    elif indicator_type == "rsi":
        p = period or DEFAULT_PERIODS["rsi"]
        result = compute_rsi(closes, period=p)
    elif indicator_type == "atr":
        p = period or DEFAULT_PERIODS["atr"]
        result = compute_atr(highs, lows, closes, period=p)
    elif indicator_type == "volume_profile":
        result = compute_volume_profile(highs, lows, closes, volumes)
    elif indicator_type == "pivot_points":
        result = compute_pivot_points(highs, lows, closes)
    else:
        # Unreachable: membership was validated above. Kept as a defensive
        # guard in case INDICATOR_TYPES gains an entry without a branch.
        raise ValidationError(f"Unknown indicator type: {indicator_type}")

    # Separate the per-indicator payload from its 0-100 score.
    response = {
        "indicator_type": indicator_type,
        "values": {k: v for k, v in result.items() if k != "score"},
        "score": result["score"],
        "bars_used": n,
    }

    indicator_cache.set(cache_key, response)
    return response
|
||||
|
||||
|
||||
async def get_ema_cross(
    db: AsyncSession,
    symbol: str,
    start_date: date | None = None,
    end_date: date | None = None,
    short_period: int = 20,
    long_period: int = 50,
) -> dict[str, Any]:
    """Compute EMA cross signal for *symbol*.

    Loads closes for the (optional) date range, delegates to
    compute_ema_cross, and caches the result. The cache key encodes both
    EMA periods, so different period pairs are cached separately (unlike
    get_indicator, whose key omits the period).

    Raises NotFoundError for an unknown ticker and ValidationError when
    fewer than ``long_period + 1`` bars are available.
    """
    cache_key = (
        symbol.upper(),
        str(start_date),
        str(end_date),
        f"ema_cross_{short_period}_{long_period}",
    )
    cached = indicator_cache.get(cache_key)
    if cached is not None:
        return cached

    records = await query_ohlcv(db, symbol, start_date, end_date)
    # Only closing prices are needed for the EMA comparison.
    _, _, _, closes, _ = _extract_ohlcv(records)

    result = compute_ema_cross(closes, short_period, long_period)

    indicator_cache.set(cache_key, result)
    return result
|
||||
172
app/services/ingestion_service.py
Normal file
172
app/services/ingestion_service.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""Ingestion Pipeline service: fetch from provider, validate, upsert into Price Store.
|
||||
|
||||
Handles rate-limit resume via IngestionProgress and provider error isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, timedelta
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError, ProviderError, RateLimitError
|
||||
from app.models.settings import IngestionProgress
|
||||
from app.models.ticker import Ticker
|
||||
from app.providers.protocol import MarketDataProvider
|
||||
from app.services import price_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class IngestionResult:
    """Result of an ingestion run.

    Summarises how many rows were written and where the run stopped so
    callers can report progress or schedule a resume.
    """

    symbol: str  # normalised ticker symbol that was ingested
    records_ingested: int  # number of rows successfully upserted
    last_date: date | None  # date of the last upserted row, if any
    status: str  # "complete" | "partial" | "error"
    message: str | None = None  # human-readable detail for the status
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Resolve *symbol* to its Ticker row; raise NotFoundError if absent."""
    sym = symbol.strip().upper()
    row = await db.execute(select(Ticker).where(Ticker.symbol == sym))
    found = row.scalar_one_or_none()
    if found is None:
        raise NotFoundError(f"Ticker not found: {sym}")
    return found
|
||||
|
||||
|
||||
async def _get_progress(db: AsyncSession, ticker_id: int) -> IngestionProgress | None:
    """Fetch the ticker's IngestionProgress row, or None when absent."""
    query = select(IngestionProgress).where(
        IngestionProgress.ticker_id == ticker_id
    )
    return (await db.execute(query)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def _update_progress(
    db: AsyncSession, ticker_id: int, last_date: date
) -> None:
    """Create or update the IngestionProgress record for a ticker.

    Commits the session so the resume point survives a crash.
    """
    existing = await _get_progress(db, ticker_id)
    if existing is not None:
        existing.last_ingested_date = last_date
    else:
        db.add(
            IngestionProgress(ticker_id=ticker_id, last_ingested_date=last_date)
        )
    await db.commit()
|
||||
|
||||
|
||||
async def fetch_and_ingest(
    db: AsyncSession,
    provider: MarketDataProvider,
    symbol: str,
    start_date: date | None = None,
    end_date: date | None = None,
) -> IngestionResult:
    """Fetch OHLCV data from provider and upsert into Price Store.

    - Resolves start_date from IngestionProgress if not provided (resume).
    - Defaults end_date to today.
    - Tracks last_ingested_date after each successful upsert.
    - On RateLimitError from provider: returns partial progress.
    - On ProviderError: returns error, no data modification.

    Raises NotFoundError if *symbol* is not a tracked ticker.
    """
    ticker = await _get_ticker(db, symbol)

    # Resolve end_date
    if end_date is None:
        end_date = date.today()

    # Resolve start_date: use progress resume or default to 1 year ago
    if start_date is None:
        progress = await _get_progress(db, ticker.id)
        if progress is not None:
            # Resume the day after the last successfully ingested row.
            start_date = progress.last_ingested_date + timedelta(days=1)
        else:
            start_date = end_date - timedelta(days=365)

    # If start > end, nothing to fetch
    if start_date > end_date:
        return IngestionResult(
            symbol=ticker.symbol,
            records_ingested=0,
            last_date=None,
            status="complete",
            message="Already up to date",
        )

    # Fetch from provider
    try:
        records = await provider.fetch_ohlcv(ticker.symbol, start_date, end_date)
    except RateLimitError:
        # No data fetched at all — return partial with 0 records
        return IngestionResult(
            symbol=ticker.symbol,
            records_ingested=0,
            last_date=None,
            status="partial",
            message="Rate limited before any records fetched. Resume available.",
        )
    except ProviderError as exc:
        logger.error("Provider error for %s: %s", ticker.symbol, exc)
        return IngestionResult(
            symbol=ticker.symbol,
            records_ingested=0,
            last_date=None,
            status="error",
            message=str(exc),
        )

    # Sort records by date to ensure ordered ingestion
    records.sort(key=lambda r: r.date)

    ingested_count = 0
    last_ingested: date | None = None

    for record in records:
        try:
            await price_service.upsert_ohlcv(
                db,
                symbol=ticker.symbol,
                record_date=record.date,
                open_=record.open,
                high=record.high,
                low=record.low,
                close=record.close,
                volume=record.volume,
            )
            ingested_count += 1
            last_ingested = record.date

            # Update progress after each successful upsert
            # NOTE(review): this commits once per record — O(n) commits for
            # n rows. Presumably a deliberate trade-off for fine-grained
            # resume; confirm before batching.
            await _update_progress(db, ticker.id, record.date)

        except RateLimitError:
            # Mid-ingestion rate limit — return partial progress
            logger.warning(
                "Rate limited during ingestion for %s after %d records",
                ticker.symbol,
                ingested_count,
            )
            return IngestionResult(
                symbol=ticker.symbol,
                records_ingested=ingested_count,
                last_date=last_ingested,
                status="partial",
                message=f"Rate limited. Ingested {ingested_count} records. Resume available.",
            )

    return IngestionResult(
        symbol=ticker.symbol,
        records_ingested=ingested_count,
        last_date=last_ingested,
        status="complete",
        message=f"Successfully ingested {ingested_count} records",
    )
|
||||
110
app/services/price_service.py
Normal file
110
app/services/price_service.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Price Store service: upsert and query OHLCV records."""
|
||||
|
||||
from datetime import date, datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError, ValidationError
|
||||
from app.models.ohlcv import OHLCVRecord
|
||||
from app.models.ticker import Ticker
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Look up a ticker by symbol. Raises NotFoundError if missing."""
    wanted = symbol.strip().upper()
    row = (
        await db.execute(select(Ticker).where(Ticker.symbol == wanted))
    ).scalar_one_or_none()
    if row is None:
        raise NotFoundError(f"Ticker not found: {wanted}")
    return row
|
||||
|
||||
|
||||
def _validate_ohlcv(
|
||||
high: float, low: float, open_: float, close: float, volume: int, record_date: date
|
||||
) -> None:
|
||||
"""Business-rule validation for an OHLCV record."""
|
||||
if high < low:
|
||||
raise ValidationError("Validation error: high must be >= low")
|
||||
if any(p < 0 for p in (open_, high, low, close)):
|
||||
raise ValidationError("Validation error: prices must be >= 0")
|
||||
if volume < 0:
|
||||
raise ValidationError("Validation error: volume must be >= 0")
|
||||
if record_date > date.today():
|
||||
raise ValidationError("Validation error: date must not be in the future")
|
||||
|
||||
|
||||
async def upsert_ohlcv(
    db: AsyncSession,
    symbol: str,
    record_date: date,
    open_: float,
    high: float,
    low: float,
    close: float,
    volume: int,
) -> OHLCVRecord:
    """Insert or update an OHLCV record for (ticker, date).

    Validates business rules, resolves ticker, then uses
    ON CONFLICT DO UPDATE on the (ticker_id, date) unique constraint.

    Raises ValidationError for bad values and NotFoundError for an
    unknown ticker. Commits the session before returning the row.
    """
    _validate_ohlcv(high, low, open_, close, volume, record_date)
    ticker = await _get_ticker(db, symbol)

    # PostgreSQL-specific INSERT ... ON CONFLICT upsert.
    # NOTE(review): datetime.utcnow() yields a naive timestamp (deprecated
    # in 3.12); other services here use datetime.now(timezone.utc) —
    # confirm whether the column expects naive UTC before aligning.
    stmt = pg_insert(OHLCVRecord).values(
        ticker_id=ticker.id,
        date=record_date,
        open=open_,
        high=high,
        low=low,
        close=close,
        volume=volume,
        created_at=datetime.utcnow(),
    )
    # On a duplicate (ticker_id, date), overwrite every field with the
    # newly supplied values (EXCLUDED refers to the attempted insert row).
    stmt = stmt.on_conflict_do_update(
        constraint="uq_ohlcv_ticker_date",
        set_={
            "open": stmt.excluded.open,
            "high": stmt.excluded.high,
            "low": stmt.excluded.low,
            "close": stmt.excluded.close,
            "volume": stmt.excluded.volume,
            "created_at": stmt.excluded.created_at,
        },
    )
    # RETURNING fetches the inserted/updated row in the same round trip.
    stmt = stmt.returning(OHLCVRecord)
    result = await db.execute(stmt)
    await db.commit()

    record = result.scalar_one()

    # TODO: Invalidate LRU cache entries for this ticker (Task 7.1)
    # TODO: Mark composite score as stale for this ticker (Task 10.1)

    return record
|
||||
|
||||
|
||||
async def query_ohlcv(
    db: AsyncSession,
    symbol: str,
    start_date: date | None = None,
    end_date: date | None = None,
) -> list[OHLCVRecord]:
    """Return a ticker's OHLCV rows, date-ascending, optionally ranged.

    Raises NotFoundError if the ticker does not exist.
    """
    ticker = await _get_ticker(db, symbol)

    query = select(OHLCVRecord).where(OHLCVRecord.ticker_id == ticker.id)
    # Apply only the bounds the caller actually supplied.
    if start_date is not None:
        query = query.where(OHLCVRecord.date >= start_date)
    if end_date is not None:
        query = query.where(OHLCVRecord.date <= end_date)

    rows = await db.execute(query.order_by(OHLCVRecord.date.asc()))
    return list(rows.scalars().all())
|
||||
241
app/services/rr_scanner_service.py
Normal file
241
app/services/rr_scanner_service.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""R:R Scanner service.
|
||||
|
||||
Scans tracked tickers for asymmetric risk-reward trade setups.
|
||||
Long: target = nearest SR above, stop = entry - ATR × multiplier.
|
||||
Short: target = nearest SR below, stop = entry + ATR × multiplier.
|
||||
Filters by configurable R:R threshold (default 3:1).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError
|
||||
from app.models.score import CompositeScore
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.ticker import Ticker
|
||||
from app.models.trade_setup import TradeSetup
|
||||
from app.services.indicator_service import _extract_ohlcv, compute_atr
|
||||
from app.services.price_service import query_ohlcv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Resolve *symbol* (case-insensitively) to a Ticker or raise NotFoundError."""
    cleaned = symbol.strip().upper()
    lookup = await db.execute(select(Ticker).where(Ticker.symbol == cleaned))
    match = lookup.scalar_one_or_none()
    if match is None:
        raise NotFoundError(f"Ticker not found: {cleaned}")
    return match
|
||||
|
||||
|
||||
async def scan_ticker(
    db: AsyncSession,
    symbol: str,
    rr_threshold: float = 3.0,
    atr_multiplier: float = 1.5,
) -> list[TradeSetup]:
    """Scan a single ticker for trade setups meeting the R:R threshold.

    1. Fetch OHLCV data and compute ATR.
    2. Fetch SR levels.
    3. Compute long and short setups.
    4. Filter by R:R threshold.
    5. Delete old setups for this ticker and persist new ones.

    Returns list of persisted TradeSetup models.

    Raises NotFoundError for an unknown ticker. Early exits (too little
    data, no usable ATR, no SR levels) clear stale setups and return [].
    """
    ticker = await _get_ticker(db, symbol)

    # Fetch OHLCV
    records = await query_ohlcv(db, symbol)
    if not records or len(records) < 15:
        logger.info(
            "Skipping %s: insufficient OHLCV data (%d bars, need 15+)",
            symbol, len(records),
        )
        # Clear any stale setups
        # NOTE(review): the DELETEs on these early-return paths are never
        # followed by a commit — confirm the caller/session flushes them.
        await db.execute(
            delete(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
        )
        return []

    _, highs, lows, closes, _ = _extract_ohlcv(records)
    entry_price = closes[-1]  # latest close acts as the hypothetical entry

    # Compute ATR
    try:
        atr_result = compute_atr(highs, lows, closes)
        atr_value = atr_result["atr"]
    except Exception:
        # Broad catch: any compute failure means "skip this ticker".
        logger.info("Skipping %s: cannot compute ATR", symbol)
        await db.execute(
            delete(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
        )
        return []

    if atr_value <= 0:
        logger.info("Skipping %s: ATR is zero or negative", symbol)
        await db.execute(
            delete(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
        )
        return []

    # Fetch SR levels from DB (already computed by sr_service)
    sr_result = await db.execute(
        select(SRLevel).where(SRLevel.ticker_id == ticker.id)
    )
    sr_levels = list(sr_result.scalars().all())

    if not sr_levels:
        logger.info("Skipping %s: no SR levels available", symbol)
        await db.execute(
            delete(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
        )
        return []

    # Nearest resistance first (ascending) / nearest support first (descending).
    levels_above = sorted(
        [lv for lv in sr_levels if lv.price_level > entry_price],
        key=lambda lv: lv.price_level,
    )
    levels_below = sorted(
        [lv for lv in sr_levels if lv.price_level < entry_price],
        key=lambda lv: lv.price_level,
        reverse=True,
    )

    # Get composite score for this ticker
    comp_result = await db.execute(
        select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
    )
    comp = comp_result.scalar_one_or_none()
    composite_score = comp.score if comp else 0.0

    now = datetime.now(timezone.utc)
    setups: list[TradeSetup] = []

    # Long setup: target = nearest SR above, stop = entry - ATR × multiplier
    if levels_above:
        target = levels_above[0].price_level
        stop = entry_price - (atr_value * atr_multiplier)
        reward = target - entry_price
        risk = entry_price - stop
        if risk > 0 and reward > 0:
            rr = reward / risk
            if rr >= rr_threshold:
                setups.append(TradeSetup(
                    ticker_id=ticker.id,
                    direction="long",
                    entry_price=round(entry_price, 4),
                    stop_loss=round(stop, 4),
                    target=round(target, 4),
                    rr_ratio=round(rr, 4),
                    composite_score=round(composite_score, 4),
                    detected_at=now,
                ))

    # Short setup: target = nearest SR below, stop = entry + ATR × multiplier
    if levels_below:
        target = levels_below[0].price_level
        stop = entry_price + (atr_value * atr_multiplier)
        reward = entry_price - target
        risk = stop - entry_price
        if risk > 0 and reward > 0:
            rr = reward / risk
            if rr >= rr_threshold:
                setups.append(TradeSetup(
                    ticker_id=ticker.id,
                    direction="short",
                    entry_price=round(entry_price, 4),
                    stop_loss=round(stop, 4),
                    target=round(target, 4),
                    rr_ratio=round(rr, 4),
                    composite_score=round(composite_score, 4),
                    detected_at=now,
                ))

    # Delete old setups for this ticker, persist new ones
    await db.execute(
        delete(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
    )
    for setup in setups:
        db.add(setup)

    await db.commit()

    # Refresh to get IDs
    for s in setups:
        await db.refresh(s)

    return setups
|
||||
|
||||
|
||||
async def scan_all_tickers(
    db: AsyncSession,
    rr_threshold: float = 3.0,
    atr_multiplier: float = 1.5,
) -> list[TradeSetup]:
    """Scan every tracked ticker for trade setups.

    Tickers are processed independently; a failure on one is logged and
    does not interrupt the rest. Returns the combined setups found.
    """
    ticker_rows = await db.execute(select(Ticker).order_by(Ticker.symbol))

    found: list[TradeSetup] = []
    for ticker in ticker_rows.scalars().all():
        try:
            found.extend(
                await scan_ticker(db, ticker.symbol, rr_threshold, atr_multiplier)
            )
        except Exception:
            # Isolate per-ticker failures so the sweep completes.
            logger.exception("Error scanning ticker %s", ticker.symbol)

    return found
|
||||
|
||||
|
||||
async def get_trade_setups(
    db: AsyncSession,
    direction: str | None = None,
) -> list[dict]:
    """Return all stored trade setups as dicts, best R:R first.

    Optionally filters by *direction* (lower-cased before matching).
    Sorted by rr_ratio desc, then composite_score desc; each dict carries
    the ticker symbol alongside the setup fields.
    """
    stmt = select(TradeSetup, Ticker.symbol).join(
        Ticker, TradeSetup.ticker_id == Ticker.id
    )
    if direction is not None:
        stmt = stmt.where(TradeSetup.direction == direction.lower())
    stmt = stmt.order_by(
        TradeSetup.rr_ratio.desc(),
        TradeSetup.composite_score.desc(),
    )

    result = await db.execute(stmt)

    payload: list[dict] = []
    for setup, symbol in result.all():
        payload.append(
            {
                "id": setup.id,
                "symbol": symbol,
                "direction": setup.direction,
                "entry_price": setup.entry_price,
                "stop_loss": setup.stop_loss,
                "target": setup.target,
                "rr_ratio": setup.rr_ratio,
                "composite_score": setup.composite_score,
                "detected_at": setup.detected_at,
            }
        )
    return payload
|
||||
584
app/services/scoring_service.py
Normal file
584
app/services/scoring_service.py
Normal file
@@ -0,0 +1,584 @@
|
||||
"""Scoring Engine service.
|
||||
|
||||
Computes dimension scores (technical, sr_quality, sentiment, fundamental,
|
||||
momentum) each 0-100, composite score as weighted average of available
|
||||
dimensions with re-normalized weights, staleness marking/recomputation
|
||||
on demand, and weight update triggers full recomputation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError, ValidationError
|
||||
from app.models.score import CompositeScore, DimensionScore
|
||||
from app.models.settings import SystemSetting
|
||||
from app.models.ticker import Ticker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Dimension names in canonical order; each dimension is scored 0-100.
DIMENSIONS = ["technical", "sr_quality", "sentiment", "fundamental", "momentum"]

# Default composite weights (sum to 1.0). Used when no override is stored
# in SystemSetting under SCORING_WEIGHTS_KEY; missing dimensions have their
# weights re-normalised at composite time.
DEFAULT_WEIGHTS: dict[str, float] = {
    "technical": 0.25,
    "sr_quality": 0.20,
    "sentiment": 0.15,
    "fundamental": 0.20,
    "momentum": 0.20,
}

# SystemSetting.key under which JSON-encoded weight overrides are stored.
SCORING_WEIGHTS_KEY = "scoring_weights"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Return the Ticker row for *symbol*, raising NotFoundError if missing."""
    key = symbol.strip().upper()
    fetched = await db.execute(select(Ticker).where(Ticker.symbol == key))
    ticker = fetched.scalar_one_or_none()
    if ticker is None:
        raise NotFoundError(f"Ticker not found: {key}")
    return ticker
|
||||
|
||||
|
||||
async def _get_weights(db: AsyncSession) -> dict[str, float]:
    """Load scoring weights from SystemSetting, falling back to defaults."""
    row = await db.execute(
        select(SystemSetting).where(SystemSetting.key == SCORING_WEIGHTS_KEY)
    )
    stored = row.scalar_one_or_none()
    if stored is not None:
        try:
            return json.loads(stored.value)
        except (json.JSONDecodeError, TypeError):
            # Corrupt override: warn and fall through to the defaults.
            logger.warning("Invalid scoring weights in DB, using defaults")
    # Copy so callers can mutate without touching the module constant.
    return dict(DEFAULT_WEIGHTS)
|
||||
|
||||
|
||||
async def _save_weights(db: AsyncSession, weights: dict[str, float]) -> None:
    """Persist scoring weights to SystemSetting (insert or update)."""
    row = await db.execute(
        select(SystemSetting).where(SystemSetting.key == SCORING_WEIGHTS_KEY)
    )
    setting = row.scalar_one_or_none()
    now = datetime.now(timezone.utc)
    if setting is None:
        db.add(
            SystemSetting(
                key=SCORING_WEIGHTS_KEY,
                value=json.dumps(weights),
                updated_at=now,
            )
        )
    else:
        setting.value = json.dumps(weights)
        setting.updated_at = now
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dimension score computation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _compute_technical_score(db: AsyncSession, symbol: str) -> float | None:
    """Compute technical dimension score from ADX, EMA, RSI.

    Each sub-indicator contributes a fixed weight (ADX 0.4, EMA 0.3,
    RSI 0.3). Indicators that fail to compute (typically from too few
    bars) are skipped and the remaining weights are re-normalised.
    Returns None when no price data exists or no indicator succeeded.
    """
    # Imported lazily to avoid a circular import at module load time.
    from app.services.indicator_service import (
        compute_adx,
        compute_ema,
        compute_rsi,
        _extract_ohlcv,
    )
    from app.services.price_service import query_ohlcv

    records = await query_ohlcv(db, symbol)
    if not records:
        return None

    _, highs, lows, closes, _ = _extract_ohlcv(records)

    scores: list[tuple[float, float]] = []  # (weight, score)

    # ADX (weight 0.4) — needs 28+ bars
    # NOTE(review): `except Exception` swallows more than the expected
    # insufficient-data error — consider narrowing to ValidationError.
    try:
        adx_result = compute_adx(highs, lows, closes)
        scores.append((0.4, adx_result["score"]))
    except Exception:
        pass

    # EMA (weight 0.3) — needs period+1 bars
    try:
        ema_result = compute_ema(closes)
        scores.append((0.3, ema_result["score"]))
    except Exception:
        pass

    # RSI (weight 0.3) — needs 15+ bars
    try:
        rsi_result = compute_rsi(closes)
        scores.append((0.3, rsi_result["score"]))
    except Exception:
        pass

    if not scores:
        return None

    # Weighted average over whichever indicators succeeded.
    total_weight = sum(w for w, _ in scores)
    if total_weight == 0:
        return None
    weighted = sum(w * s for w, s in scores) / total_weight
    return max(0.0, min(100.0, weighted))
|
||||
|
||||
|
||||
async def _compute_sr_quality_score(db: AsyncSession, symbol: str) -> float | None:
    """Compute S/R quality dimension score.

    Based on number of strong levels, proximity to current price, avg strength.
    The three factors contribute up to 40 + 30 + 30 points respectively.
    Returns None when prices or levels are unavailable.
    """
    # Lazy imports avoid circular dependencies between services.
    from app.services.price_service import query_ohlcv
    from app.services.sr_service import get_sr_levels

    records = await query_ohlcv(db, symbol)
    if not records:
        return None

    current_price = float(records[-1].close)
    if current_price <= 0:
        return None

    try:
        levels = await get_sr_levels(db, symbol)
    except Exception:
        # Any SR-service failure is treated as "no data" for this dimension.
        return None

    if not levels:
        return None

    # Factor 1: Number of strong levels (strength >= 50) — max 40 pts
    # (assumes SRLevel.strength is on a 0-100 scale — TODO confirm)
    strong_count = sum(1 for lv in levels if lv.strength >= 50)
    count_score = min(40.0, strong_count * 10.0)

    # Factor 2: Proximity of nearest level to current price — max 30 pts
    distances = [
        abs(lv.price_level - current_price) / current_price for lv in levels
    ]
    nearest_dist = min(distances) if distances else 1.0
    # Closer = higher score. 0% distance = 30, 5%+ = 0
    proximity_score = max(0.0, min(30.0, 30.0 * (1.0 - nearest_dist / 0.05)))

    # Factor 3: Average strength — max 30 pts
    avg_strength = sum(lv.strength for lv in levels) / len(levels)
    strength_score = min(30.0, avg_strength * 0.3)

    total = count_score + proximity_score + strength_score
    return max(0.0, min(100.0, total))
|
||||
|
||||
|
||||
async def _compute_sentiment_score(db: AsyncSession, symbol: str) -> float | None:
    """Compute sentiment dimension score via sentiment service."""
    from app.services.sentiment_service import compute_sentiment_dimension_score

    try:
        score = await compute_sentiment_dimension_score(db, symbol)
    except Exception:
        # Sentiment data is optional — any failure means "not available".
        return None
    return score
|
||||
|
||||
|
||||
async def _compute_fundamental_score(db: AsyncSession, symbol: str) -> float | None:
    """Compute fundamental dimension score.

    Normalized composite of P/E (lower is better), revenue growth
    (higher is better), earnings surprise (higher is better). Missing
    metrics are skipped; the result is the plain average of the metric
    scores that could be computed, or None when none were available.
    """
    from app.services.fundamental_service import get_fundamental

    fund = await get_fundamental(db, symbol)
    if fund is None:
        return None

    scores: list[float] = []

    # P/E: lower is better. 0-15 = 100, 15-30 = 50-100, 30+ = 0-50
    if fund.pe_ratio is not None and fund.pe_ratio > 0:
        pe_score = max(0.0, min(100.0, 100.0 - (fund.pe_ratio - 15.0) * (100.0 / 30.0)))
        scores.append(pe_score)

    # Revenue growth: higher is better. 0% = 50, 20%+ = 100, -20% = 0
    # (assumes revenue_growth is expressed in percent, e.g. 20.0 — confirm)
    if fund.revenue_growth is not None:
        rg_score = max(0.0, min(100.0, 50.0 + fund.revenue_growth * 2.5))
        scores.append(rg_score)

    # Earnings surprise: higher is better. 0% = 50, 10%+ = 100, -10% = 0
    if fund.earnings_surprise is not None:
        es_score = max(0.0, min(100.0, 50.0 + fund.earnings_surprise * 5.0))
        scores.append(es_score)

    if not scores:
        return None

    return sum(scores) / len(scores)
|
||||
|
||||
|
||||
async def _compute_momentum_score(db: AsyncSession, symbol: str) -> float | None:
    """Compute momentum dimension score.

    Rate of change of price over 5-day and 20-day lookback periods,
    each mapped to 0-100 (-10% -> 0, 0% -> 50, +10% -> 100) and blended
    as an equally weighted average of whichever lookbacks have data.
    """
    from app.services.price_service import query_ohlcv

    records = await query_ohlcv(db, symbol)
    if not records or len(records) < 6:
        # Need at least 6 bars for the shortest (5-day) ROC.
        return None

    closes = [float(r.close) for r in records]
    latest = closes[-1]

    parts: list[tuple[float, float]] = []  # (weight, score)

    # Offsets are bar counts back from the latest close: 6 -> 5-day ROC,
    # 21 -> 20-day ROC. Both carry weight 0.5.
    for offset, weight in ((6, 0.5), (21, 0.5)):
        if len(closes) >= offset and closes[-offset] > 0:
            base = closes[-offset]
            roc = (latest - base) / base * 100.0
            parts.append((weight, max(0.0, min(100.0, 50.0 + roc * 5.0))))

    if not parts:
        return None

    weight_sum = sum(w for w, _ in parts)
    if weight_sum == 0:
        return None
    blended = sum(w * s for w, s in parts) / weight_sum
    return max(0.0, min(100.0, blended))
|
||||
|
||||
|
||||
# Dispatch table mapping each dimension name to its async computer.
# Keys must stay in sync with DIMENSIONS; each value is awaited as
# fn(db, symbol) and returns a float in [0, 100] or None when the
# dimension cannot be computed from the available data.
_DIMENSION_COMPUTERS = {
    "technical": _compute_technical_score,
    "sr_quality": _compute_sr_quality_score,
    "sentiment": _compute_sentiment_score,
    "fundamental": _compute_fundamental_score,
    "momentum": _compute_momentum_score,
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def compute_dimension_score(
    db: AsyncSession, symbol: str, dimension: str
) -> float | None:
    """Compute a single dimension score for a ticker.

    Returns the score (0-100) or None if insufficient data.
    Persists the result to the DimensionScore table as an upsert; when
    the score cannot be computed, an existing row is marked stale rather
    than deleted so the last known value stays visible.

    Raises ValidationError for an unknown *dimension* and NotFoundError
    (via _get_ticker) for an untracked symbol.

    NOTE: does not commit — the caller owns the transaction.
    """
    if dimension not in _DIMENSION_COMPUTERS:
        raise ValidationError(
            f"Unknown dimension: {dimension}. Valid: {', '.join(DIMENSIONS)}"
        )

    ticker = await _get_ticker(db, symbol)
    # Dispatch to the per-dimension computer; may return None.
    score_val = await _DIMENSION_COMPUTERS[dimension](db, symbol)

    now = datetime.now(timezone.utc)

    # Upsert dimension score
    result = await db.execute(
        select(DimensionScore).where(
            DimensionScore.ticker_id == ticker.id,
            DimensionScore.dimension == dimension,
        )
    )
    existing = result.scalar_one_or_none()

    # Clamp to the canonical 0-100 range before persisting.
    if score_val is not None:
        score_val = max(0.0, min(100.0, score_val))

    if existing is not None:
        if score_val is not None:
            existing.score = score_val
            existing.is_stale = False
            existing.computed_at = now
        else:
            # Can't compute — mark stale (score/computed_at keep their
            # previous values)
            existing.is_stale = True
    elif score_val is not None:
        dim = DimensionScore(
            ticker_id=ticker.id,
            dimension=dimension,
            score=score_val,
            is_stale=False,
            computed_at=now,
        )
        db.add(dim)
    # No existing row and no score: nothing is persisted.

    return score_val
|
||||
|
||||
|
||||
async def compute_all_dimensions(
    db: AsyncSession, symbol: str
) -> dict[str, float | None]:
    """Compute every dimension score for *symbol*.

    Returns a mapping of dimension name -> score (None when a dimension
    could not be computed).
    """
    return {
        dim: await compute_dimension_score(db, symbol, dim) for dim in DIMENSIONS
    }
|
||||
|
||||
|
||||
async def compute_composite_score(
    db: AsyncSession,
    symbol: str,
    weights: dict[str, float] | None = None,
) -> tuple[float | None, list[str]]:
    """Compute composite score from available dimension scores.

    Returns (composite_score, missing_dimensions).
    Missing dimensions are excluded and weights re-normalized over the
    remaining ones. Dimensions with weight <= 0 are skipped entirely —
    neither used nor reported as missing.

    The result is upserted into the CompositeScore table along with the
    weights used; the caller owns the commit.
    """
    ticker = await _get_ticker(db, symbol)

    if weights is None:
        weights = await _get_weights(db)

    # Get current dimension scores
    result = await db.execute(
        select(DimensionScore).where(DimensionScore.ticker_id == ticker.id)
    )
    dim_scores = {ds.dimension: ds for ds in result.scalars().all()}

    available: list[tuple[str, float, float]] = []  # (dim, weight, score)
    missing: list[str] = []

    for dim in DIMENSIONS:
        w = weights.get(dim, 0.0)
        if w <= 0:
            # Zero/negative weight: the dimension is ignored outright.
            continue
        ds = dim_scores.get(dim)
        # Only fresh, computed scores count towards the composite.
        if ds is not None and not ds.is_stale and ds.score is not None:
            available.append((dim, w, ds.score))
        else:
            missing.append(dim)

    if not available:
        return None, missing

    # Re-normalize weights over the dimensions actually present
    total_weight = sum(w for _, w, _ in available)
    if total_weight == 0:
        return None, missing

    composite = sum(w * s for _, w, s in available) / total_weight
    composite = max(0.0, min(100.0, composite))

    # Persist composite score (upsert), remembering the weights used
    now = datetime.now(timezone.utc)
    comp_result = await db.execute(
        select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
    )
    existing = comp_result.scalar_one_or_none()

    if existing is not None:
        existing.score = composite
        existing.is_stale = False
        existing.weights_json = json.dumps(weights)
        existing.computed_at = now
    else:
        comp = CompositeScore(
            ticker_id=ticker.id,
            score=composite,
            is_stale=False,
            weights_json=json.dumps(weights),
            computed_at=now,
        )
        db.add(comp)

    return composite, missing
|
||||
|
||||
|
||||
async def get_score(
    db: AsyncSession, symbol: str
) -> dict:
    """Get composite + all dimension scores for a ticker.

    Recomputes stale dimensions on demand, then recomputes composite.
    Returns a dict suitable for ScoreResponse.

    Raises NotFoundError (via _get_ticker) for an untracked symbol.
    """
    ticker = await _get_ticker(db, symbol)
    weights = await _get_weights(db)

    # Check for stale dimension scores and recompute them
    result = await db.execute(
        select(DimensionScore).where(DimensionScore.ticker_id == ticker.id)
    )
    dim_scores = {ds.dimension: ds for ds in result.scalars().all()}

    for dim in DIMENSIONS:
        ds = dim_scores.get(dim)
        if ds is None or ds.is_stale:
            await compute_dimension_score(db, symbol, dim)

    # Check composite staleness
    comp_result = await db.execute(
        select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
    )
    comp = comp_result.scalar_one_or_none()

    if comp is None or comp.is_stale:
        await compute_composite_score(db, symbol, weights)

    # Single commit covers all recomputations above.
    await db.commit()

    # Re-fetch everything fresh so the response reflects committed state
    result = await db.execute(
        select(DimensionScore).where(DimensionScore.ticker_id == ticker.id)
    )
    dim_scores_list = list(result.scalars().all())

    comp_result = await db.execute(
        select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
    )
    comp = comp_result.scalar_one_or_none()

    dimensions = []
    missing = []
    for dim in DIMENSIONS:
        # A dimension with no row at all (never computable) is reported
        # under missing_dimensions rather than as a stale entry.
        found = next((ds for ds in dim_scores_list if ds.dimension == dim), None)
        if found is not None:
            dimensions.append({
                "dimension": found.dimension,
                "score": found.score,
                "is_stale": found.is_stale,
                "computed_at": found.computed_at,
            })
        else:
            missing.append(dim)

    return {
        "symbol": ticker.symbol,
        "composite_score": comp.score if comp else None,
        "composite_stale": comp.is_stale if comp else False,
        "weights": weights,
        "dimensions": dimensions,
        "missing_dimensions": missing,
        "computed_at": comp.computed_at if comp else None,
    }
|
||||
|
||||
|
||||
async def get_rankings(db: AsyncSession) -> dict:
    """Get all tickers ranked by composite score descending.

    Tickers with a missing or stale composite get their dimensions and
    composite recomputed inline; tickers that still have no composite
    afterwards are omitted from the rankings.

    Returns dict suitable for RankingResponse.

    NOTE(review): this issues several queries per ticker (N+1 pattern);
    fine for a small universe, revisit if the ticker list grows.
    """
    weights = await _get_weights(db)

    # Get all tickers
    result = await db.execute(select(Ticker).order_by(Ticker.symbol))
    tickers = list(result.scalars().all())

    rankings: list[dict] = []
    for ticker in tickers:
        # Get composite score
        comp_result = await db.execute(
            select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
        )
        comp = comp_result.scalar_one_or_none()

        # If no composite or stale, recompute
        if comp is None or comp.is_stale:
            # Recompute stale dimensions first
            dim_result = await db.execute(
                select(DimensionScore).where(
                    DimensionScore.ticker_id == ticker.id
                )
            )
            dim_scores = {ds.dimension: ds for ds in dim_result.scalars().all()}
            for dim in DIMENSIONS:
                ds = dim_scores.get(dim)
                if ds is None or ds.is_stale:
                    await compute_dimension_score(db, ticker.symbol, dim)

            await compute_composite_score(db, ticker.symbol, weights)

            # Commit per ticker so earlier recomputations survive a
            # failure later in the loop.
            await db.commit()

            # Re-fetch
            comp_result = await db.execute(
                select(CompositeScore).where(CompositeScore.ticker_id == ticker.id)
            )
            comp = comp_result.scalar_one_or_none()
            if comp is None:
                # Still nothing computable for this ticker — skip it.
                continue

        dim_result = await db.execute(
            select(DimensionScore).where(
                DimensionScore.ticker_id == ticker.id
            )
        )
        dims = [
            {
                "dimension": ds.dimension,
                "score": ds.score,
                "is_stale": ds.is_stale,
                "computed_at": ds.computed_at,
            }
            for ds in dim_result.scalars().all()
        ]

        rankings.append({
            "symbol": ticker.symbol,
            "composite_score": comp.score,
            "dimensions": dims,
        })

    # Sort by composite score descending
    rankings.sort(key=lambda r: r["composite_score"], reverse=True)

    return {
        "rankings": rankings,
        "weights": weights,
    }
|
||||
|
||||
|
||||
async def update_weights(
    db: AsyncSession, weights: dict[str, float]
) -> dict[str, float]:
    """Update scoring weights and recompute all composite scores.

    Validates that every weight is non-negative and refers to a known
    dimension. Dimensions omitted from *weights* default to 0.
    Returns the full weight map that was persisted.
    """
    # Validate everything before touching any persisted state.
    for dim, w in weights.items():
        if dim not in DIMENSIONS:
            raise ValidationError(
                f"Unknown dimension: {dim}. Valid: {', '.join(DIMENSIONS)}"
            )
        if w < 0:
            raise ValidationError(f"Weight for {dim} must be non-negative, got {w}")

    # Every dimension gets an explicit weight (default 0 for unspecified)
    full_weights = {dim: weights.get(dim, 0.0) for dim in DIMENSIONS}

    await _save_weights(db, full_weights)

    # Recompute composites for every tracked ticker under the new weights
    ticker_rows = await db.execute(select(Ticker))
    for ticker in ticker_rows.scalars().all():
        await compute_composite_score(db, ticker.symbol, full_weights)

    await db.commit()
    return full_weights
|
||||
131
app/services/sentiment_service.py
Normal file
131
app/services/sentiment_service.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Sentiment service.
|
||||
|
||||
Stores sentiment records and computes the sentiment dimension score
|
||||
using a time-decay weighted average over a configurable lookback window.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError
|
||||
from app.models.sentiment import SentimentScore
|
||||
from app.models.ticker import Ticker
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Resolve *symbol* (normalised to upper case) to a Ticker row.

    Raises NotFoundError when the symbol is not tracked.
    """
    normalised = symbol.strip().upper()
    lookup = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
    found = lookup.scalar_one_or_none()
    if found is None:
        raise NotFoundError(f"Ticker not found: {normalised}")
    return found
|
||||
|
||||
|
||||
async def store_sentiment(
    db: AsyncSession,
    symbol: str,
    classification: str,
    confidence: int,
    source: str,
    timestamp: datetime | None = None,
) -> SentimentScore:
    """Persist one sentiment observation for *symbol* and return it.

    *timestamp* defaults to the current UTC time when not supplied.
    Raises NotFoundError (via _get_ticker) for an untracked symbol.
    """
    ticker = await _get_ticker(db, symbol)

    when = timestamp if timestamp is not None else datetime.now(timezone.utc)

    record = SentimentScore(
        ticker_id=ticker.id,
        classification=classification,
        confidence=confidence,
        source=source,
        timestamp=when,
    )
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return record
|
||||
|
||||
|
||||
async def get_sentiment_scores(
    db: AsyncSession,
    symbol: str,
    lookback_hours: float = 24,
) -> list[SentimentScore]:
    """Return sentiment rows for *symbol* newer than the lookback cutoff.

    Rows are ordered most-recent first.
    """
    ticker = await _get_ticker(db, symbol)
    cutoff = datetime.now(timezone.utc) - timedelta(hours=lookback_hours)

    stmt = (
        select(SentimentScore)
        .where(
            SentimentScore.ticker_id == ticker.id,
            SentimentScore.timestamp >= cutoff,
        )
        .order_by(SentimentScore.timestamp.desc())
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
def _classification_to_base_score(classification: str, confidence: int) -> float:
|
||||
"""Map classification + confidence to a base score (0-100).
|
||||
|
||||
bullish → confidence (high confidence = high score)
|
||||
bearish → 100 - confidence (high confidence bearish = low score)
|
||||
neutral → 50
|
||||
"""
|
||||
cl = classification.lower()
|
||||
if cl == "bullish":
|
||||
return float(confidence)
|
||||
elif cl == "bearish":
|
||||
return float(100 - confidence)
|
||||
else:
|
||||
return 50.0
|
||||
|
||||
|
||||
async def compute_sentiment_dimension_score(
    db: AsyncSession,
    symbol: str,
    lookback_hours: float = 24,
    decay_rate: float = 0.1,
) -> float | None:
    """Compute the sentiment dimension score via time-decay weighted average.

    Each record in the lookback window contributes its
    classification/confidence base score weighted by
    exp(-decay_rate * age_in_hours), so fresh sentiment dominates.

    Returns a score in [0, 100], or None when no records exist in the
    window (or every weight collapses to zero).
    """
    records = await get_sentiment_scores(db, symbol, lookback_hours)
    if not records:
        return None

    now = datetime.now(timezone.utc)
    numerator = 0.0
    denominator = 0.0

    for rec in records:
        when = rec.timestamp
        # Timestamps may come back naive from the DB; treat them as UTC.
        if when.tzinfo is None:
            when = when.replace(tzinfo=timezone.utc)
        age_hours = (now - when).total_seconds() / 3600.0
        decay = math.exp(-decay_rate * age_hours)
        numerator += decay * _classification_to_base_score(
            rec.classification, rec.confidence
        )
        denominator += decay

    if denominator == 0:
        return None

    return max(0.0, min(100.0, numerator / denominator))
|
||||
274
app/services/sr_service.py
Normal file
274
app/services/sr_service.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""S/R Detector service.
|
||||
|
||||
Detects support/resistance levels from Volume Profile (HVN/LVN) and
|
||||
Pivot Points (swing highs/lows), assigns strength scores, merges nearby
|
||||
levels, tags as support/resistance, and persists to DB.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError, ValidationError
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.ticker import Ticker
|
||||
from app.services.indicator_service import (
|
||||
_extract_ohlcv,
|
||||
compute_pivot_points,
|
||||
compute_volume_profile,
|
||||
)
|
||||
from app.services.price_service import query_ohlcv
|
||||
|
||||
DEFAULT_TOLERANCE = 0.005 # 0.5%
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Look up a ticker by symbol.

    The symbol is normalised (stripped, upper-cased) before lookup.
    Raises NotFoundError when the ticker is not tracked.
    """
    normalised = symbol.strip().upper()
    result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
    ticker = result.scalar_one_or_none()
    if ticker is None:
        raise NotFoundError(f"Ticker not found: {normalised}")
    return ticker
|
||||
|
||||
|
||||
def _count_price_touches(
    price_level: float,
    highs: list[float],
    lows: list[float],
    closes: list[float],
    tolerance: float = DEFAULT_TOLERANCE,
) -> int:
    """Count how many bars touched/respected a price level within tolerance.

    A bar "touches" the level when the level falls inside the bar's
    [low, high] range widened by the tolerance band on both sides.

    Parameters
    ----------
    price_level: candidate S/R level.
    highs / lows: per-bar extremes; expected to be parallel lists
        (they come from _extract_ohlcv, which produces equal lengths).
    closes: retained for interface compatibility — the original
        implementation only used it for its length; the touch test
        itself needs only each bar's high/low.
    tolerance: relative band as a fraction of price_level; an absolute
        band is used when price_level is 0 to avoid a zero-width band.
    """
    tol = price_level * tolerance if price_level != 0 else tolerance
    # Iterate the bars pairwise instead of indexing via range(len(...)).
    return sum(
        1 for hi, lo in zip(highs, lows) if lo - tol <= price_level <= hi + tol
    )
|
||||
|
||||
|
||||
def _strength_from_touches(touches: int, total_bars: int) -> int:
|
||||
"""Convert touch count to a 0-100 strength score.
|
||||
|
||||
More touches relative to total bars = higher strength.
|
||||
Cap at 100.
|
||||
"""
|
||||
if total_bars == 0:
|
||||
return 0
|
||||
# Scale: each touch contributes proportionally, with a multiplier
|
||||
# so that a level touched ~20% of bars gets score ~100
|
||||
raw = (touches / total_bars) * 500.0
|
||||
return max(0, min(100, int(round(raw))))
|
||||
|
||||
|
||||
def _extract_candidate_levels(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    volumes: list[int],
) -> list[tuple[float, str]]:
    """Extract candidate S/R levels from Volume Profile and Pivot Points.

    Returns (price_level, detection_method) tuples. Either detector is
    skipped silently when it raises ValidationError (not enough data).
    """
    candidates: list[tuple[float, str]] = []

    # Volume Profile: HVN and LVN as candidate levels
    try:
        profile = compute_volume_profile(highs, lows, closes, volumes)
    except ValidationError:
        pass  # Not enough data for volume profile
    else:
        candidates.extend(
            (price, "volume_profile")
            for key in ("hvn", "lvn")
            for price in profile.get(key, [])
        )

    # Pivot Points: swing highs and lows
    try:
        pivots = compute_pivot_points(highs, lows, closes)
    except ValidationError:
        pass  # Not enough data for pivot points
    else:
        candidates.extend(
            (price, "pivot_point")
            for key in ("swing_highs", "swing_lows")
            for price in pivots.get(key, [])
        )

    return candidates
|
||||
|
||||
|
||||
def _merge_levels(
    levels: list[dict],
    tolerance: float = DEFAULT_TOLERANCE,
) -> list[dict]:
    """Merge levels within tolerance into consolidated levels.

    Walks the levels in ascending price order; each level is either
    folded into the previously accepted level (averaged price, summed
    strength capped at 100) or starts a new consolidated level. A merge
    across different detection methods is labelled "merged"; merging
    two levels from the same method keeps that method.
    """
    if not levels:
        return []

    consolidated: list[dict] = []
    for candidate in sorted(levels, key=lambda lv: lv["price_level"]):
        if not consolidated:
            consolidated.append(dict(candidate))
            continue

        anchor = consolidated[-1]
        anchor_price = anchor["price_level"]
        # Relative band, with an absolute fallback at price 0.
        band = anchor_price * tolerance if anchor_price != 0 else tolerance

        if abs(candidate["price_level"] - anchor_price) > band:
            consolidated.append(dict(candidate))
            continue

        # Fold the candidate into the anchor level.
        if anchor["detection_method"] != candidate["detection_method"]:
            anchor["detection_method"] = "merged"
        anchor["price_level"] = round(
            (anchor_price + candidate["price_level"]) / 2.0, 4
        )
        anchor["strength"] = min(100, anchor["strength"] + candidate["strength"])

    return consolidated
|
||||
|
||||
|
||||
def _tag_levels(
|
||||
levels: list[dict],
|
||||
current_price: float,
|
||||
) -> list[dict]:
|
||||
"""Tag each level as 'support' or 'resistance' relative to current price."""
|
||||
for level in levels:
|
||||
if level["price_level"] < current_price:
|
||||
level["type"] = "support"
|
||||
else:
|
||||
level["type"] = "resistance"
|
||||
return levels
|
||||
|
||||
|
||||
def detect_sr_levels(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    volumes: list[int],
    tolerance: float = DEFAULT_TOLERANCE,
) -> list[dict]:
    """Detect, score, merge, and tag S/R levels from OHLCV data.

    Pipeline: candidate extraction (volume profile + pivot points) →
    touch-count strength scoring → merging of nearby levels → tagging
    as support/resistance relative to the latest close.

    Returns list of dicts with keys: price_level, type, strength,
    detection_method — sorted by strength descending. Empty list when
    there is no data or no candidates.
    """
    if not closes:
        return []

    candidates = _extract_candidate_levels(highs, lows, closes, volumes)
    if not candidates:
        return []

    total_bars = len(closes)
    # Latest close is the reference for support-vs-resistance tagging.
    current_price = closes[-1]

    # Build level dicts with strength scores
    raw_levels: list[dict] = []
    for price, method in candidates:
        touches = _count_price_touches(price, highs, lows, closes, tolerance)
        strength = _strength_from_touches(touches, total_bars)
        raw_levels.append({
            "price_level": price,
            "strength": strength,
            "detection_method": method,
            "type": "",  # will be tagged after merge
        })

    # Merge nearby levels
    merged = _merge_levels(raw_levels, tolerance)

    # Tag as support/resistance
    tagged = _tag_levels(merged, current_price)

    # Sort by strength descending
    tagged.sort(key=lambda x: x["strength"], reverse=True)

    return tagged
|
||||
|
||||
|
||||
async def recalculate_sr_levels(
    db: AsyncSession,
    symbol: str,
    tolerance: float = DEFAULT_TOLERANCE,
) -> list[SRLevel]:
    """Recalculate S/R levels for a ticker and persist to DB.

    1. Fetch OHLCV data
    2. Detect levels
    3. Delete old levels for ticker
    4. Insert new levels
    5. Return new levels sorted by strength desc

    Raises NotFoundError (via _get_ticker) for an untracked symbol.
    When no OHLCV data exists, any previously stored levels are cleared
    and an empty list is returned.
    """
    # Local import: this module only imports `datetime` at the top;
    # timezone is needed for aware timestamps (see note below).
    from datetime import timezone

    ticker = await _get_ticker(db, symbol)

    records = await query_ohlcv(db, symbol)
    if not records:
        # No OHLCV data — clear any existing levels
        await db.execute(
            delete(SRLevel).where(SRLevel.ticker_id == ticker.id)
        )
        await db.commit()
        return []

    _, highs, lows, closes, volumes = _extract_ohlcv(records)

    levels = detect_sr_levels(highs, lows, closes, volumes, tolerance)

    # Delete old levels — they are replaced wholesale by this run
    await db.execute(
        delete(SRLevel).where(SRLevel.ticker_id == ticker.id)
    )

    # Insert new levels. Use an aware UTC timestamp: datetime.utcnow()
    # is deprecated and returns a naive datetime, whereas the other
    # services store timezone-aware datetime.now(timezone.utc).
    now = datetime.now(timezone.utc)
    new_models: list[SRLevel] = []
    for lvl in levels:
        model = SRLevel(
            ticker_id=ticker.id,
            price_level=lvl["price_level"],
            type=lvl["type"],
            strength=lvl["strength"],
            detection_method=lvl["detection_method"],
            created_at=now,
        )
        db.add(model)
        new_models.append(model)

    await db.commit()

    # Refresh to get DB-generated IDs
    for m in new_models:
        await db.refresh(m)

    return new_models
|
||||
|
||||
|
||||
async def get_sr_levels(
    db: AsyncSession,
    symbol: str,
    tolerance: float = DEFAULT_TOLERANCE,
) -> list[SRLevel]:
    """Get S/R levels for a ticker, recalculating on every request (MVP).

    Returns levels sorted by strength descending.

    NOTE(review): each read also rewrites the SRLevel rows
    (delete + insert + commit) — acceptable for the MVP; revisit if
    read traffic grows.
    """
    return await recalculate_sr_levels(db, symbol, tolerance)
|
||||
57
app/services/ticker_service.py
Normal file
57
app/services/ticker_service.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Ticker Registry service: add, delete, and list tracked tickers."""
|
||||
|
||||
import re
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import DuplicateError, NotFoundError, ValidationError
|
||||
from app.models.ticker import Ticker
|
||||
|
||||
|
||||
async def add_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Add a new ticker after validation.

    Validates: non-empty, uppercase alphanumeric. Auto-uppercases input.
    Raises ValidationError on invalid input and DuplicateError if the
    symbol is already tracked.
    """
    stripped = symbol.strip()
    if not stripped:
        raise ValidationError("Ticker symbol must not be empty or whitespace-only")

    normalised = stripped.upper()
    if re.fullmatch(r"[A-Z0-9]+", normalised) is None:
        raise ValidationError(
            f"Ticker symbol must be alphanumeric: {normalised}"
        )

    duplicate_check = await db.execute(
        select(Ticker).where(Ticker.symbol == normalised)
    )
    if duplicate_check.scalar_one_or_none() is not None:
        raise DuplicateError(f"Ticker already exists: {normalised}")

    ticker = Ticker(symbol=normalised)
    db.add(ticker)
    await db.commit()
    await db.refresh(ticker)
    return ticker
|
||||
|
||||
|
||||
async def delete_ticker(db: AsyncSession, symbol: str) -> None:
    """Delete a ticker and cascade all associated data.

    Raises NotFoundError if the symbol is not tracked.
    """
    normalised = symbol.strip().upper()
    lookup = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
    ticker = lookup.scalar_one_or_none()
    if ticker is None:
        raise NotFoundError(f"Ticker not found: {normalised}")

    await db.delete(ticker)
    await db.commit()
|
||||
|
||||
|
||||
async def list_tickers(db: AsyncSession) -> list[Ticker]:
    """Return all tracked tickers sorted alphabetically by symbol."""
    stmt = select(Ticker).order_by(Ticker.symbol.asc())
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
|
||||
288
app/services/watchlist_service.py
Normal file
288
app/services/watchlist_service.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""Watchlist service.
|
||||
|
||||
Auto-populates top-X tickers by composite score (default 10), supports
|
||||
manual add/remove (tagged, not subject to auto-population), enforces
|
||||
cap (auto + 10 manual, default max 20), and updates auto entries on
|
||||
score recomputation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import DuplicateError, NotFoundError, ValidationError
|
||||
from app.models.score import CompositeScore, DimensionScore
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.ticker import Ticker
|
||||
from app.models.trade_setup import TradeSetup
|
||||
from app.models.watchlist import WatchlistEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_AUTO_SIZE = 10
|
||||
MAX_MANUAL = 10
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Look up a ticker by its (stripped, upper-cased) symbol.

    Raises NotFoundError when the symbol is not tracked.
    """
    normalised = symbol.strip().upper()
    result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
    ticker = result.scalar_one_or_none()
    if ticker is None:
        raise NotFoundError(f"Ticker not found: {normalised}")
    return ticker
|
||||
|
||||
|
||||
async def auto_populate(
    db: AsyncSession,
    user_id: int,
    top_x: int = DEFAULT_AUTO_SIZE,
) -> None:
    """Auto-populate watchlist with top-X tickers by composite score.

    Replaces existing auto entries. Manual entries are untouched; a
    ticker already held as a manual entry is not duplicated as an auto
    entry.

    Flushes but does not commit — the caller owns the transaction.
    """
    # Get top-X tickers by composite score (non-stale, descending)
    stmt = (
        select(CompositeScore)
        .where(CompositeScore.is_stale == False)  # noqa: E712
        .order_by(CompositeScore.score.desc())
        .limit(top_x)
    )
    result = await db.execute(stmt)
    top_scores = list(result.scalars().all())
    top_ticker_ids = {cs.ticker_id for cs in top_scores}

    # Delete existing auto entries for this user
    await db.execute(
        delete(WatchlistEntry).where(
            WatchlistEntry.user_id == user_id,
            WatchlistEntry.entry_type == "auto",
        )
    )

    # Get manual ticker_ids so we don't duplicate
    manual_result = await db.execute(
        select(WatchlistEntry.ticker_id).where(
            WatchlistEntry.user_id == user_id,
            WatchlistEntry.entry_type == "manual",
        )
    )
    manual_ticker_ids = {row[0] for row in manual_result.all()}

    # All new auto entries share one timestamp for this population run.
    now = datetime.now(timezone.utc)
    for ticker_id in top_ticker_ids:
        if ticker_id in manual_ticker_ids:
            continue  # Already on watchlist as manual
        entry = WatchlistEntry(
            user_id=user_id,
            ticker_id=ticker_id,
            entry_type="auto",
            added_at=now,
        )
        db.add(entry)

    await db.flush()
|
||||
|
||||
|
||||
async def add_manual_entry(
    db: AsyncSession,
    user_id: int,
    symbol: str,
) -> WatchlistEntry:
    """Add a manual watchlist entry.

    Raises DuplicateError if already on watchlist.
    Raises ValidationError if manual cap exceeded.
    """
    ticker = await _get_ticker(db, symbol)

    # Reject if the ticker is already present (manual OR auto).
    dup_check = await db.execute(
        select(WatchlistEntry).where(
            WatchlistEntry.user_id == user_id,
            WatchlistEntry.ticker_id == ticker.id,
        )
    )
    if dup_check.scalar_one_or_none() is not None:
        raise DuplicateError(f"Ticker already on watchlist: {ticker.symbol}")

    # Enforce the per-user cap on manual entries.
    manual_count = (
        await db.execute(
            select(func.count()).select_from(WatchlistEntry).where(
                WatchlistEntry.user_id == user_id,
                WatchlistEntry.entry_type == "manual",
            )
        )
    ).scalar() or 0
    if manual_count >= MAX_MANUAL:
        raise ValidationError(
            f"Manual watchlist cap reached ({MAX_MANUAL}). "
            "Remove an entry before adding a new one."
        )

    # Enforce the overall cap (auto + manual combined).
    total_count = (
        await db.execute(
            select(func.count()).select_from(WatchlistEntry).where(
                WatchlistEntry.user_id == user_id,
            )
        )
    ).scalar() or 0
    max_total = DEFAULT_AUTO_SIZE + MAX_MANUAL
    if total_count >= max_total:
        raise ValidationError(
            f"Watchlist cap reached ({max_total}). "
            "Remove an entry before adding a new one."
        )

    new_entry = WatchlistEntry(
        user_id=user_id,
        ticker_id=ticker.id,
        entry_type="manual",
        added_at=datetime.now(timezone.utc),
    )
    db.add(new_entry)
    await db.commit()
    await db.refresh(new_entry)
    return new_entry
||||
|
||||
|
||||
async def remove_entry(
    db: AsyncSession,
    user_id: int,
    symbol: str,
) -> None:
    """Remove a watchlist entry (manual or auto)."""
    ticker = await _get_ticker(db, symbol)

    lookup = await db.execute(
        select(WatchlistEntry).where(
            WatchlistEntry.user_id == user_id,
            WatchlistEntry.ticker_id == ticker.id,
        )
    )
    target = lookup.scalar_one_or_none()
    if target is None:
        raise NotFoundError(f"Ticker not on watchlist: {ticker.symbol}")

    await db.delete(target)
    await db.commit()
||||
|
||||
|
||||
async def _enrich_entry(
    db: AsyncSession,
    entry: WatchlistEntry,
    symbol: str,
) -> dict:
    """Build enriched watchlist entry dict with scores, R:R, and SR levels.

    NOTE(review): issues four queries per entry; callers that enrich a
    whole list pay N+1 round-trips — consider batching if this shows up
    in profiles.
    """
    tid = entry.ticker_id

    # Composite score (may be absent).
    comp = (
        await db.execute(
            select(CompositeScore).where(CompositeScore.ticker_id == tid)
        )
    ).scalar_one_or_none()

    # All per-dimension scores for this ticker.
    dim_rows = await db.execute(
        select(DimensionScore).where(DimensionScore.ticker_id == tid)
    )
    dims = [
        {"dimension": d.dimension, "score": d.score}
        for d in dim_rows.scalars().all()
    ]

    # Best trade setup = highest risk:reward ratio.
    best_setup = (
        await db.execute(
            select(TradeSetup)
            .where(TradeSetup.ticker_id == tid)
            .order_by(TradeSetup.rr_ratio.desc())
            .limit(1)
        )
    ).scalar_one_or_none()

    # SR levels, strongest first.
    sr_rows = await db.execute(
        select(SRLevel)
        .where(SRLevel.ticker_id == tid)
        .order_by(SRLevel.strength.desc())
    )
    levels = [
        {"price_level": lv.price_level, "type": lv.type, "strength": lv.strength}
        for lv in sr_rows.scalars().all()
    ]

    return {
        "symbol": symbol,
        "entry_type": entry.entry_type,
        "composite_score": comp.score if comp else None,
        "dimensions": dims,
        "rr_ratio": best_setup.rr_ratio if best_setup else None,
        "rr_direction": best_setup.direction if best_setup else None,
        "sr_levels": levels,
        "added_at": entry.added_at,
    }
|
||||
|
||||
|
||||
async def get_watchlist(
    db: AsyncSession,
    user_id: int,
    sort_by: str = "composite",
) -> list[dict]:
    """Get user's watchlist with enriched data.

    Runs auto_populate first to ensure auto entries are current,
    then enriches each entry with scores, R:R, and SR levels.

    sort_by: "composite", "rr", or a dimension name
    (e.g. "technical", "sr_quality", "sentiment", "fundamental", "momentum").
    """
    # Auto-populate to refresh auto entries
    await auto_populate(db, user_id)
    await db.commit()

    # Fetch all entries with ticker symbol
    stmt = (
        select(WatchlistEntry, Ticker.symbol)
        .join(Ticker, WatchlistEntry.ticker_id == Ticker.id)
        .where(WatchlistEntry.user_id == user_id)
    )
    result = await db.execute(stmt)
    rows = result.all()

    entries: list[dict] = []
    for entry, symbol in rows:
        entries.append(await _enrich_entry(db, entry, symbol))

    # Single sort-key function for all modes. A missing value maps to -1.0
    # so unscored entries sink to the bottom. FIX: the dimension branch
    # previously returned d["score"] unguarded, so a None dimension score
    # raised TypeError during sorting; now None is treated as missing,
    # consistent with the composite/rr branches.
    def _sort_key(e: dict) -> float:
        if sort_by == "composite":
            value = e["composite_score"]
        elif sort_by == "rr":
            value = e["rr_ratio"]
        else:
            value = next(
                (d["score"] for d in e["dimensions"] if d["dimension"] == sort_by),
                None,
            )
        return value if value is not None else -1.0

    entries.sort(key=_sort_key, reverse=True)
    return entries
|
||||
Reference in New Issue
Block a user