124 lines
4.1 KiB
Python
124 lines
4.1 KiB
Python
"""Unit tests for ticker_universe_service: universe bootstrap and symbol-fetch fallbacks."""
|
|
|
|
from __future__ import annotations

import json
from collections.abc import AsyncGenerator

import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool

from app.database import Base
from app.exceptions import ProviderError
from app.models.settings import SystemSetting
from app.models.ticker import Ticker
from app.services import ticker_universe_service
|
|
|
|
# One engine for the whole module: the autouse table fixture and the session
# fixture must talk to the same database. StaticPool pins a single shared
# connection, which an in-memory SQLite database needs -- every new connection
# to "sqlite+aiosqlite://" would otherwise open its own, empty database.
_engine = create_async_engine(
    "sqlite+aiosqlite://",
    echo=False,
    poolclass=StaticPool,
)

# expire_on_commit=False keeps ORM attributes loaded after commit(), so tests can
# read objects they just committed without triggering an implicit async refresh.
_session_factory = async_sessionmaker(_engine, class_=AsyncSession, expire_on_commit=False)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
async def _setup_tables() -> AsyncGenerator[None, None]:
    """Give every test a freshly created schema and tear it down afterwards.

    All tables registered on ``Base.metadata`` are created before the test body
    runs and dropped once it finishes, so no rows leak between tests.
    """
    # NOTE(review): plain ``@pytest.fixture`` on an async generator relies on
    # pytest-asyncio running in "auto" mode -- confirm against the pytest config.
    schema = Base.metadata
    async with _engine.begin() as conn:
        await conn.run_sync(schema.create_all)
    yield
    async with _engine.begin() as conn:
        await conn.run_sync(schema.drop_all)
|
|
|
|
|
|
@pytest.fixture
async def session() -> AsyncGenerator[AsyncSession, None]:
    """Provide an ``AsyncSession`` from the module's shared session factory.

    The session is closed when the ``async with`` block exits at teardown.
    """
    db_session = _session_factory()
    async with db_session:
        yield db_session
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_bootstrap_universe_adds_missing_symbols(session: AsyncSession, monkeypatch: pytest.MonkeyPatch):
    """Fetched symbols that are not tracked yet get inserted; tracked ones are only counted."""
    # Pre-existing tracked ticker that also appears in the fetched universe.
    session.add(Ticker(symbol="AAPL"))
    await session.commit()

    async def _fake_fetch(_db: AsyncSession, _universe: str) -> list[str]:
        # Stand-in for the live universe lookup; ignores its arguments.
        return ["AAPL", "MSFT", "NVDA"]

    monkeypatch.setattr(ticker_universe_service, "fetch_universe_symbols", _fake_fetch)

    summary = await ticker_universe_service.bootstrap_universe(session, "sp500")

    counters = {name: summary[name] for name in ("added", "already_tracked", "deleted")}
    assert counters == {"added": 2, "already_tracked": 1, "deleted": 0}

    query = select(Ticker.symbol).order_by(Ticker.symbol.asc())
    tracked = (await session.execute(query)).scalars().all()
    assert list(tracked) == ["AAPL", "MSFT", "NVDA"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_bootstrap_universe_prunes_missing_symbols(session: AsyncSession, monkeypatch: pytest.MonkeyPatch):
    """With prune_missing=True, tracked symbols absent from the fetched universe are deleted."""
    # TSLA is tracked but will not be part of the stubbed universe.
    session.add_all([Ticker(symbol=sym) for sym in ("AAPL", "MSFT", "TSLA")])
    await session.commit()

    async def _fake_fetch(_db: AsyncSession, _universe: str) -> list[str]:
        return ["AAPL", "MSFT"]

    monkeypatch.setattr(ticker_universe_service, "fetch_universe_symbols", _fake_fetch)

    outcome = await ticker_universe_service.bootstrap_universe(session, "sp500", prune_missing=True)

    assert (outcome["added"], outcome["already_tracked"], outcome["deleted"]) == (0, 2, 1)

    remaining = await session.execute(select(Ticker.symbol).order_by(Ticker.symbol.asc()))
    assert list(remaining.scalars().all()) == ["AAPL", "MSFT"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fetch_universe_symbols_uses_cached_snapshot_when_live_sources_fail(
    session: AsyncSession,
    monkeypatch: pytest.MonkeyPatch,
):
    """A stored snapshot is returned when the public source is empty and FMP raises."""
    cached_payload = {"symbols": ["AAPL", "MSFT"], "source": "test"}
    session.add(SystemSetting(key="ticker_universe_cache_sp500", value=json.dumps(cached_payload)))
    await session.commit()

    async def _public_returns_nothing(_universe: str):
        # No symbols plus one error message; the 3-tuple layout is assumed to
        # mirror the real helper's return value -- confirm against the service.
        return [], ["public failed"], None

    async def _fmp_raises(_universe: str):
        raise ProviderError("fmp failed")

    replacements = (
        ("_fetch_universe_symbols_from_public", _public_returns_nothing),
        ("_fetch_universe_symbols_from_fmp", _fmp_raises),
    )
    for attr_name, replacement in replacements:
        monkeypatch.setattr(ticker_universe_service, attr_name, replacement)

    fetched = await ticker_universe_service.fetch_universe_symbols(session, "sp500")
    assert fetched == cached_payload["symbols"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fetch_universe_symbols_uses_seed_when_live_and_cache_fail(
    session: AsyncSession,
    monkeypatch: pytest.MonkeyPatch,
):
    """Both live sources fail and the fresh database holds no snapshot.

    The fallback (presumably a bundled seed list, per the test name) must still
    yield a non-trivial symbol list that includes AAPL.
    """

    async def _no_public_symbols(_universe: str):
        return [], ["public failed"], None

    async def _fmp_unavailable(_universe: str):
        raise ProviderError("fmp failed")

    stubs = {
        "_fetch_universe_symbols_from_public": _no_public_symbols,
        "_fetch_universe_symbols_from_fmp": _fmp_unavailable,
    }
    for attr_name, stub in stubs.items():
        monkeypatch.setattr(ticker_universe_service, attr_name, stub)

    seed_symbols = await ticker_universe_service.fetch_universe_symbols(session, "sp500")
    assert "AAPL" in seed_symbols
    assert len(seed_symbols) > 10
|