first commit
Some checks failed
Deploy / lint (push) Failing after 7s
Deploy / test (push) Has been skipped
Deploy / deploy (push) Has been skipped

This commit is contained in:
Dennis Thiessen
2026-02-20 17:31:01 +01:00
commit 61ab24490d
160 changed files with 17034 additions and 0 deletions

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@

260
tests/conftest.py Normal file
View File

@@ -0,0 +1,260 @@
"""Shared test fixtures and hypothesis strategies for the stock-data-backend test suite."""
from __future__ import annotations
import string
from datetime import date, datetime, timedelta, timezone
from typing import Any
import pytest
from httpx import ASGITransport, AsyncClient
from hypothesis import strategies as st
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from app.database import Base
from app.providers.protocol import OHLCVData
# ---------------------------------------------------------------------------
# Test database (SQLite in-memory, async via aiosqlite)
# ---------------------------------------------------------------------------
TEST_DATABASE_URL = "sqlite+aiosqlite://"
_test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)
_test_session_factory = async_sessionmaker(
_test_engine,
class_=AsyncSession,
expire_on_commit=False,
)
@pytest.fixture(autouse=True)
async def _setup_db():
"""Create all tables before each test and drop them after."""
async with _test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with _test_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def db_session() -> AsyncSession:
"""Provide a transactional DB session that rolls back after the test."""
async with _test_session_factory() as session:
async with session.begin():
yield session
await session.rollback()
# ---------------------------------------------------------------------------
# FastAPI test client
# ---------------------------------------------------------------------------
@pytest.fixture
async def client(db_session: AsyncSession) -> AsyncClient:
"""Async HTTP test client wired to the FastAPI app with the test DB session."""
from app.dependencies import get_db
from app.main import app
async def _override_get_db():
yield db_session
app.dependency_overrides[get_db] = _override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
app.dependency_overrides.clear()
# ---------------------------------------------------------------------------
# Mock MarketDataProvider
# ---------------------------------------------------------------------------
class MockMarketDataProvider:
"""Configurable mock that satisfies the MarketDataProvider protocol."""
def __init__(
self,
ohlcv_data: list[OHLCVData] | None = None,
error: Exception | None = None,
) -> None:
self.ohlcv_data = ohlcv_data or []
self.error = error
self.calls: list[dict[str, Any]] = []
async def fetch_ohlcv(
self, ticker: str, start_date: date, end_date: date
) -> list[OHLCVData]:
self.calls.append(
{"ticker": ticker, "start_date": start_date, "end_date": end_date}
)
if self.error is not None:
raise self.error
return [r for r in self.ohlcv_data if r.ticker == ticker]
@pytest.fixture
def mock_provider() -> MockMarketDataProvider:
"""Return a fresh MockMarketDataProvider instance."""
return MockMarketDataProvider()
# ---------------------------------------------------------------------------
# Hypothesis custom strategies
# ---------------------------------------------------------------------------
_TICKER_ALPHABET = string.ascii_uppercase + string.digits
@st.composite
def valid_ticker_symbols(draw: st.DrawFn) -> str:
"""Generate uppercase alphanumeric ticker symbols (1-10 chars)."""
return draw(
st.text(alphabet=_TICKER_ALPHABET, min_size=1, max_size=10)
)
@st.composite
def whitespace_strings(draw: st.DrawFn) -> str:
"""Generate strings composed entirely of whitespace (including empty)."""
return draw(
st.text(alphabet=" \t\n\r\x0b\x0c", min_size=0, max_size=20)
)
@st.composite
def valid_ohlcv_records(draw: st.DrawFn) -> OHLCVData:
"""Generate valid OHLCV records (high >= low, prices >= 0, volume >= 0, date <= today)."""
ticker = draw(valid_ticker_symbols())
low = draw(st.floats(min_value=0.01, max_value=10000.0, allow_nan=False, allow_infinity=False))
high = draw(st.floats(min_value=low, max_value=10000.0, allow_nan=False, allow_infinity=False))
open_ = draw(st.floats(min_value=low, max_value=high, allow_nan=False, allow_infinity=False))
close = draw(st.floats(min_value=low, max_value=high, allow_nan=False, allow_infinity=False))
volume = draw(st.integers(min_value=0, max_value=10**12))
record_date = draw(
st.dates(min_value=date(2000, 1, 1), max_value=date.today())
)
return OHLCVData(
ticker=ticker,
date=record_date,
open=open_,
high=high,
low=low,
close=close,
volume=volume,
)
@st.composite
def invalid_ohlcv_records(draw: st.DrawFn) -> OHLCVData:
"""Generate OHLCV records that violate at least one constraint."""
ticker = draw(valid_ticker_symbols())
violation = draw(st.sampled_from(["high_lt_low", "negative_price", "negative_volume", "future_date"]))
if violation == "high_lt_low":
high = draw(st.floats(min_value=0.01, max_value=100.0, allow_nan=False, allow_infinity=False))
low = draw(st.floats(min_value=high + 0.01, max_value=200.0, allow_nan=False, allow_infinity=False))
return OHLCVData(
ticker=ticker, date=date.today(),
open=high, high=high, low=low, close=high, volume=100,
)
elif violation == "negative_price":
neg = draw(st.floats(min_value=-10000.0, max_value=-0.01, allow_nan=False, allow_infinity=False))
return OHLCVData(
ticker=ticker, date=date.today(),
open=neg, high=abs(neg), low=abs(neg), close=abs(neg), volume=100,
)
elif violation == "negative_volume":
price = draw(st.floats(min_value=0.01, max_value=100.0, allow_nan=False, allow_infinity=False))
neg_vol = draw(st.integers(min_value=-10**9, max_value=-1))
return OHLCVData(
ticker=ticker, date=date.today(),
open=price, high=price, low=price, close=price, volume=neg_vol,
)
else: # future_date
future = date.today() + timedelta(days=draw(st.integers(min_value=1, max_value=365)))
price = draw(st.floats(min_value=0.01, max_value=100.0, allow_nan=False, allow_infinity=False))
return OHLCVData(
ticker=ticker, date=future,
open=price, high=price, low=price, close=price, volume=100,
)
_DIMENSIONS = ["technical", "sr_quality", "sentiment", "fundamental", "momentum"]
@st.composite
def dimension_scores(draw: st.DrawFn) -> float:
"""Generate float values in [0, 100] for dimension scores."""
return draw(st.floats(min_value=0.0, max_value=100.0, allow_nan=False, allow_infinity=False))
@st.composite
def weight_configs(draw: st.DrawFn) -> dict[str, float]:
"""Generate dicts of dimension → positive float weight."""
dims = draw(st.lists(st.sampled_from(_DIMENSIONS), min_size=1, max_size=5, unique=True))
weights: dict[str, float] = {}
for dim in dims:
weights[dim] = draw(st.floats(min_value=0.01, max_value=10.0, allow_nan=False, allow_infinity=False))
return weights
@st.composite
def sr_levels(draw: st.DrawFn) -> dict[str, Any]:
"""Generate SR level data (price, type, strength, detection_method)."""
return {
"price_level": draw(st.floats(min_value=0.01, max_value=10000.0, allow_nan=False, allow_infinity=False)),
"type": draw(st.sampled_from(["support", "resistance"])),
"strength": draw(st.integers(min_value=0, max_value=100)),
"detection_method": draw(st.sampled_from(["volume_profile", "pivot_point", "merged"])),
}
@st.composite
def sentiment_scores(draw: st.DrawFn) -> dict[str, Any]:
"""Generate sentiment data (classification, confidence, source, timestamp)."""
naive_dt = draw(
st.datetimes(
min_value=datetime(2020, 1, 1),
max_value=datetime.now(),
)
)
return {
"classification": draw(st.sampled_from(["bullish", "bearish", "neutral"])),
"confidence": draw(st.integers(min_value=0, max_value=100)),
"source": draw(st.text(alphabet=string.ascii_lowercase, min_size=3, max_size=20)),
"timestamp": naive_dt.replace(tzinfo=timezone.utc),
}
@st.composite
def trade_setups(draw: st.DrawFn) -> dict[str, Any]:
"""Generate trade setup data (direction, entry, stop, target, rr_ratio, composite_score)."""
direction = draw(st.sampled_from(["long", "short"]))
entry = draw(st.floats(min_value=1.0, max_value=10000.0, allow_nan=False, allow_infinity=False))
atr_dist = draw(st.floats(min_value=0.01, max_value=entry * 0.2, allow_nan=False, allow_infinity=False))
if direction == "long":
stop = entry - atr_dist
target = entry + atr_dist * draw(st.floats(min_value=3.0, max_value=10.0, allow_nan=False, allow_infinity=False))
else:
stop = entry + atr_dist
target = entry - atr_dist * draw(st.floats(min_value=3.0, max_value=10.0, allow_nan=False, allow_infinity=False))
rr_ratio = abs(target - entry) / abs(entry - stop) if abs(entry - stop) > 0 else 0.0
return {
"direction": direction,
"entry_price": entry,
"stop_loss": stop,
"target": target,
"rr_ratio": rr_ratio,
"composite_score": draw(st.floats(min_value=0.0, max_value=100.0, allow_nan=False, allow_infinity=False)),
}

View File

@@ -0,0 +1 @@

1
tests/unit/__init__.py Normal file
View File

@@ -0,0 +1 @@

102
tests/unit/test_cache.py Normal file
View File

@@ -0,0 +1,102 @@
"""Unit tests for app.cache LRU cache wrapper."""
from datetime import date
from app.cache import LRUCache
def _key(ticker: str, indicator: str = "RSI") -> tuple:
return (ticker, date(2024, 1, 1), date(2024, 6, 1), indicator)
class TestLRUCacheBasics:
def test_get_miss_returns_none(self):
cache = LRUCache()
assert cache.get(_key("AAPL")) is None
def test_set_and_get_round_trip(self):
cache = LRUCache()
cache.set(_key("AAPL"), {"score": 72})
assert cache.get(_key("AAPL")) == {"score": 72}
def test_set_overwrites_existing(self):
cache = LRUCache()
cache.set(_key("AAPL"), 1)
cache.set(_key("AAPL"), 2)
assert cache.get(_key("AAPL")) == 2
assert cache.size == 1
def test_size_and_clear(self):
cache = LRUCache(max_size=5)
for i in range(3):
cache.set((f"T{i}", None, None, "RSI"), i)
assert cache.size == 3
cache.clear()
assert cache.size == 0
class TestLRUEviction:
def test_evicts_lru_when_full(self):
cache = LRUCache(max_size=3)
cache.set(_key("A"), 1)
cache.set(_key("B"), 2)
cache.set(_key("C"), 3)
# A is LRU — inserting D should evict A
cache.set(_key("D"), 4)
assert cache.get(_key("A")) is None
assert cache.size == 3
def test_access_promotes_entry(self):
cache = LRUCache(max_size=3)
cache.set(_key("A"), 1)
cache.set(_key("B"), 2)
cache.set(_key("C"), 3)
# Access A so B becomes LRU
cache.get(_key("A"))
cache.set(_key("D"), 4)
assert cache.get(_key("B")) is None
assert cache.get(_key("A")) == 1
def test_update_promotes_entry(self):
cache = LRUCache(max_size=3)
cache.set(_key("A"), 1)
cache.set(_key("B"), 2)
cache.set(_key("C"), 3)
# Update A so B becomes LRU
cache.set(_key("A"), 10)
cache.set(_key("D"), 4)
assert cache.get(_key("B")) is None
assert cache.get(_key("A")) == 10
class TestTickerInvalidation:
def test_invalidate_removes_all_entries_for_ticker(self):
cache = LRUCache()
cache.set(("AAPL", date(2024, 1, 1), date(2024, 6, 1), "RSI"), 1)
cache.set(("AAPL", date(2024, 1, 1), date(2024, 6, 1), "ADX"), 2)
cache.set(("MSFT", date(2024, 1, 1), date(2024, 6, 1), "RSI"), 3)
removed = cache.invalidate_ticker("AAPL")
assert removed == 2
assert cache.get(("AAPL", date(2024, 1, 1), date(2024, 6, 1), "RSI")) is None
assert cache.get(("AAPL", date(2024, 1, 1), date(2024, 6, 1), "ADX")) is None
assert cache.get(("MSFT", date(2024, 1, 1), date(2024, 6, 1), "RSI")) == 3
def test_invalidate_nonexistent_ticker_returns_zero(self):
cache = LRUCache()
cache.set(_key("AAPL"), 1)
assert cache.invalidate_ticker("GOOG") == 0
assert cache.size == 1
def test_invalidate_on_empty_cache(self):
cache = LRUCache()
assert cache.invalidate_ticker("AAPL") == 0
class TestMaxSizeProperty:
def test_default_max_size(self):
cache = LRUCache()
assert cache.max_size == 1000
def test_custom_max_size(self):
cache = LRUCache(max_size=50)
assert cache.max_size == 50

View File

@@ -0,0 +1,201 @@
"""Tests for the exception hierarchy and global exception handlers."""
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.exceptions import (
AppError,
AuthenticationError,
AuthorizationError,
DuplicateError,
NotFoundError,
ProviderError,
RateLimitError,
ValidationError,
)
from app.middleware import register_exception_handlers
from app.schemas.common import APIEnvelope
# ── Exception hierarchy tests ──
def test_app_error_defaults():
err = AppError()
assert err.status_code == 500
assert err.message == "Internal server error"
assert str(err) == "Internal server error"
def test_app_error_custom_message():
err = AppError("something broke")
assert err.message == "something broke"
assert str(err) == "something broke"
@pytest.mark.parametrize(
"cls,code,default_msg",
[
(ValidationError, 400, "Validation error"),
(NotFoundError, 404, "Resource not found"),
(DuplicateError, 409, "Resource already exists"),
(AuthenticationError, 401, "Authentication required"),
(AuthorizationError, 403, "Insufficient permissions"),
(ProviderError, 502, "Market data provider unavailable"),
(RateLimitError, 429, "Rate limited"),
],
)
def test_subclass_defaults(cls, code, default_msg):
err = cls()
assert err.status_code == code
assert err.message == default_msg
def test_subclass_custom_message():
err = NotFoundError("Ticker not found: AAPL")
assert err.status_code == 404
assert err.message == "Ticker not found: AAPL"
def test_all_subclasses_are_app_errors():
for cls in (
ValidationError,
NotFoundError,
DuplicateError,
AuthenticationError,
AuthorizationError,
ProviderError,
RateLimitError,
):
assert issubclass(cls, AppError)
# ── APIEnvelope schema tests ──
def test_envelope_success():
env = APIEnvelope(status="success", data={"id": 1})
assert env.status == "success"
assert env.data == {"id": 1}
assert env.error is None
def test_envelope_error():
env = APIEnvelope(status="error", error="bad request")
assert env.status == "error"
assert env.data is None
assert env.error == "bad request"
# ── Middleware integration tests ──
def _make_app() -> FastAPI:
"""Create a minimal FastAPI app with exception handlers and test routes."""
app = FastAPI()
register_exception_handlers(app)
@app.get("/raise-not-found")
async def _raise_not_found():
raise NotFoundError("Ticker not found: XYZ")
@app.get("/raise-validation")
async def _raise_validation():
raise ValidationError("high < low")
@app.get("/raise-duplicate")
async def _raise_duplicate():
raise DuplicateError("Ticker already exists: AAPL")
@app.get("/raise-auth")
async def _raise_auth():
raise AuthenticationError()
@app.get("/raise-authz")
async def _raise_authz():
raise AuthorizationError()
@app.get("/raise-provider")
async def _raise_provider():
raise ProviderError()
@app.get("/raise-rate-limit")
async def _raise_rate_limit():
raise RateLimitError("Rate limited. Ingested 42 records. Resume available.")
@app.get("/raise-unhandled")
async def _raise_unhandled():
raise RuntimeError("unexpected")
return app
@pytest.fixture
def client():
return TestClient(_make_app())
def test_middleware_not_found(client):
resp = client.get("/raise-not-found")
assert resp.status_code == 404
body = resp.json()
assert body["status"] == "error"
assert body["data"] is None
assert body["error"] == "Ticker not found: XYZ"
def test_middleware_validation(client):
resp = client.get("/raise-validation")
assert resp.status_code == 400
body = resp.json()
assert body["status"] == "error"
assert body["error"] == "high < low"
def test_middleware_duplicate(client):
resp = client.get("/raise-duplicate")
assert resp.status_code == 409
body = resp.json()
assert body["status"] == "error"
assert "already exists" in body["error"]
def test_middleware_authentication(client):
resp = client.get("/raise-auth")
assert resp.status_code == 401
body = resp.json()
assert body["status"] == "error"
def test_middleware_authorization(client):
resp = client.get("/raise-authz")
assert resp.status_code == 403
body = resp.json()
assert body["status"] == "error"
def test_middleware_provider_error(client):
resp = client.get("/raise-provider")
assert resp.status_code == 502
body = resp.json()
assert body["status"] == "error"
def test_middleware_rate_limit(client):
resp = client.get("/raise-rate-limit")
assert resp.status_code == 429
body = resp.json()
assert body["status"] == "error"
assert "42 records" in body["error"]
def test_middleware_unhandled_exception():
app = _make_app()
with TestClient(app, raise_server_exceptions=False) as c:
resp = c.get("/raise-unhandled")
assert resp.status_code == 500
body = resp.json()
assert body["status"] == "error"
assert body["data"] is None
assert body["error"] == "Internal server error"

View File

@@ -0,0 +1,205 @@
"""Unit tests for app.services.indicator_service pure computation functions."""
import pytest
from app.exceptions import ValidationError
from app.services.indicator_service import (
compute_adx,
compute_atr,
compute_ema,
compute_ema_cross,
compute_pivot_points,
compute_rsi,
compute_volume_profile,
)
# ---------------------------------------------------------------------------
# Helpers: generate synthetic OHLCV data
# ---------------------------------------------------------------------------
def _rising_closes(n: int, start: float = 100.0, step: float = 1.0) -> list[float]:
return [start + i * step for i in range(n)]
def _flat_closes(n: int, price: float = 100.0) -> list[float]:
return [price] * n
def _ohlcv_from_closes(closes: list[float], spread: float = 2.0):
"""Generate highs/lows/volumes from a close series."""
highs = [c + spread for c in closes]
lows = [c - spread for c in closes]
volumes = [1000] * len(closes)
return highs, lows, closes, volumes
# ---------------------------------------------------------------------------
# EMA
# ---------------------------------------------------------------------------
class TestComputeEMA:
def test_basic_ema(self):
closes = _rising_closes(25)
result = compute_ema(closes, period=20)
assert "ema" in result
assert "score" in result
assert 0 <= result["score"] <= 100
def test_insufficient_data_raises(self):
closes = _rising_closes(5)
with pytest.raises(ValidationError, match="EMA.*requires at least"):
compute_ema(closes, period=20)
def test_price_above_ema_high_score(self):
# Rising prices → latest close above EMA → score > 50
closes = _rising_closes(30, start=100, step=2)
result = compute_ema(closes, period=20)
assert result["score"] > 50
def test_price_below_ema_low_score(self):
# Falling prices → latest close below EMA → score < 50
closes = list(reversed(_rising_closes(30, start=100, step=2)))
result = compute_ema(closes, period=20)
assert result["score"] < 50
# ---------------------------------------------------------------------------
# RSI
# ---------------------------------------------------------------------------
class TestComputeRSI:
def test_basic_rsi(self):
closes = _rising_closes(20)
result = compute_rsi(closes)
assert "rsi" in result
assert 0 <= result["score"] <= 100
def test_all_gains_rsi_100(self):
closes = _rising_closes(20, step=1)
result = compute_rsi(closes)
assert result["rsi"] == 100.0
def test_all_losses_rsi_0(self):
closes = list(reversed(_rising_closes(20, step=1)))
result = compute_rsi(closes)
assert result["rsi"] == pytest.approx(0.0, abs=0.5)
def test_insufficient_data_raises(self):
with pytest.raises(ValidationError, match="RSI requires"):
compute_rsi([100.0] * 5)
# ---------------------------------------------------------------------------
# ATR
# ---------------------------------------------------------------------------
class TestComputeATR:
def test_basic_atr(self):
closes = _rising_closes(20)
highs, lows, _, _ = _ohlcv_from_closes(closes)
result = compute_atr(highs, lows, closes)
assert "atr" in result
assert result["atr"] > 0
assert 0 <= result["score"] <= 100
def test_insufficient_data_raises(self):
closes = [100.0] * 5
highs, lows, _, _ = _ohlcv_from_closes(closes)
with pytest.raises(ValidationError, match="ATR requires"):
compute_atr(highs, lows, closes)
# ---------------------------------------------------------------------------
# ADX
# ---------------------------------------------------------------------------
class TestComputeADX:
def test_basic_adx(self):
closes = _rising_closes(30)
highs, lows, _, _ = _ohlcv_from_closes(closes)
result = compute_adx(highs, lows, closes)
assert "adx" in result
assert "plus_di" in result
assert "minus_di" in result
assert 0 <= result["score"] <= 100
def test_insufficient_data_raises(self):
closes = _rising_closes(10)
highs, lows, _, _ = _ohlcv_from_closes(closes)
with pytest.raises(ValidationError, match="ADX requires"):
compute_adx(highs, lows, closes)
# ---------------------------------------------------------------------------
# Volume Profile
# ---------------------------------------------------------------------------
class TestComputeVolumeProfile:
def test_basic_volume_profile(self):
closes = _rising_closes(25)
highs, lows, _, volumes = _ohlcv_from_closes(closes)
result = compute_volume_profile(highs, lows, closes, volumes)
assert "poc" in result
assert "value_area_low" in result
assert "value_area_high" in result
assert "hvn" in result
assert "lvn" in result
assert 0 <= result["score"] <= 100
def test_insufficient_data_raises(self):
closes = [100.0] * 10
highs, lows, _, volumes = _ohlcv_from_closes(closes)
with pytest.raises(ValidationError, match="Volume Profile requires"):
compute_volume_profile(highs, lows, closes, volumes)
# ---------------------------------------------------------------------------
# Pivot Points
# ---------------------------------------------------------------------------
class TestComputePivotPoints:
def test_basic_pivot_points(self):
# Create data with clear swing highs/lows
closes = [10, 15, 20, 15, 10, 15, 20, 15, 10, 15]
highs = [c + 1 for c in closes]
lows = [c - 1 for c in closes]
result = compute_pivot_points(highs, lows, closes)
assert "swing_highs" in result
assert "swing_lows" in result
assert 0 <= result["score"] <= 100
def test_insufficient_data_raises(self):
with pytest.raises(ValidationError, match="Pivot Points requires"):
compute_pivot_points([1, 2], [0, 1], [0.5, 1.5])
# ---------------------------------------------------------------------------
# EMA Cross
# ---------------------------------------------------------------------------
class TestComputeEMACross:
def test_bullish_signal(self):
# Rising prices → short EMA > long EMA → bullish
closes = _rising_closes(60, step=2)
result = compute_ema_cross(closes, short_period=20, long_period=50)
assert result["signal"] == "bullish"
assert result["short_ema"] > result["long_ema"]
def test_bearish_signal(self):
# Falling prices → short EMA < long EMA → bearish
closes = list(reversed(_rising_closes(60, step=2)))
result = compute_ema_cross(closes, short_period=20, long_period=50)
assert result["signal"] == "bearish"
assert result["short_ema"] < result["long_ema"]
def test_neutral_signal(self):
# Flat prices → EMAs converge → neutral
closes = _flat_closes(60)
result = compute_ema_cross(closes, short_period=20, long_period=50)
assert result["signal"] == "neutral"
def test_insufficient_data_raises(self):
closes = _rising_closes(30)
with pytest.raises(ValidationError, match="EMA Cross requires"):
compute_ema_cross(closes, short_period=20, long_period=50)

View File

@@ -0,0 +1,95 @@
"""Unit tests for app.scheduler module."""
import pytest
from app.scheduler import (
_is_job_enabled,
_parse_frequency,
_resume_tickers,
_last_successful,
configure_scheduler,
scheduler,
)
class TestParseFrequency:
def test_hourly(self):
assert _parse_frequency("hourly") == {"hours": 1}
def test_daily(self):
assert _parse_frequency("daily") == {"hours": 24}
def test_case_insensitive(self):
assert _parse_frequency("Hourly") == {"hours": 1}
assert _parse_frequency("DAILY") == {"hours": 24}
def test_unknown_defaults_to_daily(self):
assert _parse_frequency("weekly") == {"hours": 24}
assert _parse_frequency("") == {"hours": 24}
class TestResumeTickers:
def test_no_previous_returns_full_list(self):
symbols = ["AAPL", "GOOG", "MSFT"]
_last_successful["test_job"] = None
result = _resume_tickers(symbols, "test_job")
assert result == ["AAPL", "GOOG", "MSFT"]
def test_resume_after_first(self):
symbols = ["AAPL", "GOOG", "MSFT"]
_last_successful["test_job"] = "AAPL"
result = _resume_tickers(symbols, "test_job")
# Should start from GOOG, then wrap around
assert result == ["GOOG", "MSFT", "AAPL"]
def test_resume_after_middle(self):
symbols = ["AAPL", "GOOG", "MSFT", "TSLA"]
_last_successful["test_job"] = "GOOG"
result = _resume_tickers(symbols, "test_job")
assert result == ["MSFT", "TSLA", "AAPL", "GOOG"]
def test_resume_after_last(self):
symbols = ["AAPL", "GOOG", "MSFT"]
_last_successful["test_job"] = "MSFT"
result = _resume_tickers(symbols, "test_job")
# All already processed, wraps to full list
assert result == ["AAPL", "GOOG", "MSFT"]
def test_unknown_last_returns_full_list(self):
symbols = ["AAPL", "GOOG", "MSFT"]
_last_successful["test_job"] = "NVDA"
result = _resume_tickers(symbols, "test_job")
assert result == ["AAPL", "GOOG", "MSFT"]
def test_empty_list(self):
_last_successful["test_job"] = "AAPL"
result = _resume_tickers([], "test_job")
assert result == []
class TestConfigureScheduler:
def test_configure_adds_four_jobs(self):
# Remove any existing jobs first
scheduler.remove_all_jobs()
configure_scheduler()
jobs = scheduler.get_jobs()
job_ids = {j.id for j in jobs}
assert job_ids == {
"data_collector",
"sentiment_collector",
"fundamental_collector",
"rr_scanner",
}
def test_configure_is_idempotent(self):
scheduler.remove_all_jobs()
configure_scheduler()
configure_scheduler() # Should replace, not duplicate
job_ids = [j.id for j in scheduler.get_jobs()]
# Each ID should appear exactly once
assert sorted(job_ids) == sorted([
"data_collector",
"fundamental_collector",
"rr_scanner",
"sentiment_collector",
])