111 lines
3.5 KiB
Python
111 lines
3.5 KiB
Python
"""Price Store service: upsert and query OHLCV records."""
|
|
|
|
import math
from datetime import date, datetime

from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession

from app.exceptions import NotFoundError, ValidationError
from app.models.ohlcv import OHLCVRecord
from app.models.ticker import Ticker
|
|
|
|
|
|
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
    """Resolve *symbol* to its Ticker row.

    The symbol is trimmed of surrounding whitespace and upper-cased before
    the lookup, so " aapl " and "AAPL" resolve to the same ticker.

    Args:
        db: Async session used to run the lookup.
        symbol: Ticker symbol as supplied by the caller.

    Returns:
        The matching Ticker.

    Raises:
        NotFoundError: If no Ticker has the normalised symbol.
    """
    wanted = symbol.strip().upper()
    query = select(Ticker).where(Ticker.symbol == wanted)
    found = (await db.execute(query)).scalar_one_or_none()
    if found is not None:
        return found
    raise NotFoundError(f"Ticker not found: {wanted}")
|
|
|
|
|
|
def _validate_ohlcv(
|
|
high: float, low: float, open_: float, close: float, volume: int, record_date: date
|
|
) -> None:
|
|
"""Business-rule validation for an OHLCV record."""
|
|
if high < low:
|
|
raise ValidationError("Validation error: high must be >= low")
|
|
if any(p < 0 for p in (open_, high, low, close)):
|
|
raise ValidationError("Validation error: prices must be >= 0")
|
|
if volume < 0:
|
|
raise ValidationError("Validation error: volume must be >= 0")
|
|
if record_date > date.today():
|
|
raise ValidationError("Validation error: date must not be in the future")
|
|
|
|
|
|
async def upsert_ohlcv(
    db: AsyncSession,
    symbol: str,
    record_date: date,
    open_: float,
    high: float,
    low: float,
    close: float,
    volume: int,
) -> OHLCVRecord:
    """Insert or update the OHLCV record for (ticker, date) and commit.

    Validates business rules, resolves the ticker, then issues a PostgreSQL
    ``INSERT ... ON CONFLICT DO UPDATE`` on the ``uq_ohlcv_ticker_date``
    unique constraint, returning the resulting row as an ORM object.

    Args:
        db: Async session; this function commits it.
        symbol: Ticker symbol (normalised by the ticker lookup).
        record_date: Date of the record.
        open_: Opening price.
        high: Highest price.
        low: Lowest price.
        close: Closing price.
        volume: Traded volume.

    Returns:
        The inserted or updated OHLCVRecord.

    Raises:
        ValidationError: If the values fail business-rule validation.
        NotFoundError: If the ticker symbol does not exist.
    """
    _validate_ohlcv(high, low, open_, close, volume, record_date)
    ticker = await _get_ticker(db, symbol)

    insert_stmt = pg_insert(OHLCVRecord).values(
        ticker_id=ticker.id,
        date=record_date,
        open=open_,
        high=high,
        low=low,
        close=close,
        volume=volume,
        # NOTE(review): naive UTC timestamp. datetime.utcnow() is deprecated
        # since Python 3.12; switching to an aware value depends on the
        # column type, so confirm before changing.
        created_at=datetime.utcnow(),
    )
    stmt = insert_stmt.on_conflict_do_update(
        constraint="uq_ohlcv_ticker_date",
        set_={
            "open": insert_stmt.excluded.open,
            "high": insert_stmt.excluded.high,
            "low": insert_stmt.excluded.low,
            "close": insert_stmt.excluded.close,
            "volume": insert_stmt.excluded.volume,
            # NOTE(review): created_at is overwritten on conflict, so it
            # reflects the latest write rather than the first insert.
            # Confirm this is intended.
            "created_at": insert_stmt.excluded.created_at,
        },
    ).returning(OHLCVRecord)

    # populate_existing makes the ORM refresh any OHLCVRecord instance for
    # this row already held in the session's identity map. Without it, an
    # unexpired instance (e.g. one loaded earlier in this session) would be
    # returned with its stale pre-update values.
    result = await db.execute(
        stmt, execution_options={"populate_existing": True}
    )
    await db.commit()

    # The ORM object is materialised after the commit, so commit-time
    # expiry does not apply to it.
    record = result.scalar_one()

    # TODO: Invalidate LRU cache entries for this ticker (Task 7.1)
    # TODO: Mark composite score as stale for this ticker (Task 10.1)

    return record
|
|
|
|
|
|
async def query_ohlcv(
    db: AsyncSession,
    symbol: str,
    start_date: date | None = None,
    end_date: date | None = None,
) -> list[OHLCVRecord]:
    """Fetch a ticker's OHLCV records, oldest first.

    Args:
        db: Async session used to run the queries.
        symbol: Ticker symbol (trimmed and upper-cased by the ticker lookup).
        start_date: Inclusive lower bound on the record date, if given.
        end_date: Inclusive upper bound on the record date, if given.

    Returns:
        Matching records ordered by ascending date. The list is empty when
        nothing falls in the range.

    Raises:
        NotFoundError: If the ticker does not exist.
    """
    ticker = await _get_ticker(db, symbol)

    # Gather every filter first; .where(*conditions) ANDs them together.
    conditions = [OHLCVRecord.ticker_id == ticker.id]
    if start_date is not None:
        conditions.append(OHLCVRecord.date >= start_date)
    if end_date is not None:
        conditions.append(OHLCVRecord.date <= end_date)

    query = (
        select(OHLCVRecord)
        .where(*conditions)
        .order_by(OHLCVRecord.date.asc())
    )
    rows = await db.execute(query)
    return list(rows.scalars())
|