first commit
This commit is contained in:
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
260
tests/conftest.py
Normal file
260
tests/conftest.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""Shared test fixtures and hypothesis strategies for the stock-data-backend test suite."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from hypothesis import strategies as st
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from app.database import Base
|
||||
from app.providers.protocol import OHLCVData
|
||||
|
||||
# ---------------------------------------------------------------------------
# Test database (SQLite in-memory, async via aiosqlite)
# ---------------------------------------------------------------------------

# No path component => a private in-memory SQLite database for this engine.
TEST_DATABASE_URL = "sqlite+aiosqlite://"

# Module-level engine and session factory shared by every fixture below.
_test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)
_test_session_factory = async_sessionmaker(
    _test_engine,
    class_=AsyncSession,
    # Keep ORM objects readable after commit so tests can inspect attributes.
    expire_on_commit=False,
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
async def _setup_db():
    """Create all tables before each test and drop them after."""
    # Fresh schema before the test body runs...
    async with _test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    # ...and torn down afterwards so no table state leaks between tests.
    async with _test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest.fixture
async def db_session() -> AsyncSession:
    """Provide a transactional DB session that rolls back after the test."""
    async with _test_session_factory() as session:
        async with session.begin():
            yield session
        # NOTE(review): `session.begin()` commits on clean exit, so this
        # rollback is likely a no-op on the happy path — confirm whether
        # per-test isolation is meant to come from here or from the
        # drop_all in `_setup_db`.
        await session.rollback()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FastAPI test client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
async def client(db_session: AsyncSession) -> AsyncClient:
    """Async HTTP test client wired to the FastAPI app with the test DB session."""
    # Imported lazily so collecting conftest does not import the app module.
    from app.dependencies import get_db
    from app.main import app

    async def _override_get_db():
        # Hand every request the fixture's session instead of a real one.
        yield db_session

    app.dependency_overrides[get_db] = _override_get_db
    # ASGITransport dispatches requests in-process — no real network socket.
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
    # Undo the override so other tests see the real dependency again.
    app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock MarketDataProvider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockMarketDataProvider:
    """Configurable mock that satisfies the MarketDataProvider protocol.

    Records every fetch call, optionally raises a pre-configured error,
    and otherwise returns the seeded rows whose ticker matches.
    """

    def __init__(
        self,
        ohlcv_data: list[OHLCVData] | None = None,
        error: Exception | None = None,
    ) -> None:
        # Seeded rows served by fetch_ohlcv (empty when not provided).
        self.ohlcv_data = ohlcv_data if ohlcv_data else []
        # When set, fetch_ohlcv raises this instead of returning data.
        self.error = error
        # One dict of call arguments per fetch_ohlcv invocation.
        self.calls: list[dict[str, Any]] = []

    async def fetch_ohlcv(
        self, ticker: str, start_date: date, end_date: date
    ) -> list[OHLCVData]:
        """Return seeded rows for *ticker*, after recording the call."""
        call_args = {"ticker": ticker, "start_date": start_date, "end_date": end_date}
        self.calls.append(call_args)
        if self.error is not None:
            raise self.error
        matching = []
        for row in self.ohlcv_data:
            if row.ticker == ticker:
                matching.append(row)
        return matching
|
||||
|
||||
|
||||
@pytest.fixture
def mock_provider() -> MockMarketDataProvider:
    """Return a fresh MockMarketDataProvider instance."""
    # A new instance per test keeps the recorded calls isolated.
    return MockMarketDataProvider()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Hypothesis custom strategies
# ---------------------------------------------------------------------------

# Characters permitted in generated ticker symbols (A-Z plus 0-9).
_TICKER_ALPHABET = string.ascii_uppercase + string.digits
|
||||
|
||||
|
||||
@st.composite
def valid_ticker_symbols(draw: st.DrawFn) -> str:
    """Generate uppercase alphanumeric ticker symbols (1-10 chars)."""
    symbol_strategy = st.text(alphabet=_TICKER_ALPHABET, min_size=1, max_size=10)
    return draw(symbol_strategy)
|
||||
|
||||
|
||||
@st.composite
def whitespace_strings(draw: st.DrawFn) -> str:
    """Generate strings composed entirely of whitespace (including empty)."""
    whitespace_alphabet = " \t\n\r\x0b\x0c"
    return draw(st.text(alphabet=whitespace_alphabet, min_size=0, max_size=20))
|
||||
|
||||
|
||||
@st.composite
def valid_ohlcv_records(draw: st.DrawFn) -> OHLCVData:
    """Generate valid OHLCV records (high >= low, prices >= 0, volume >= 0, date <= today)."""
    ticker = draw(valid_ticker_symbols())
    # Draw low first, then constrain high/open/close so low <= open, close <= high.
    low = draw(st.floats(min_value=0.01, max_value=10000.0, allow_nan=False, allow_infinity=False))
    high = draw(st.floats(min_value=low, max_value=10000.0, allow_nan=False, allow_infinity=False))
    open_ = draw(st.floats(min_value=low, max_value=high, allow_nan=False, allow_infinity=False))
    close = draw(st.floats(min_value=low, max_value=high, allow_nan=False, allow_infinity=False))
    volume = draw(st.integers(min_value=0, max_value=10**12))
    # Dates capped at today so a valid record is never in the future.
    record_date = draw(
        st.dates(min_value=date(2000, 1, 1), max_value=date.today())
    )
    return OHLCVData(
        ticker=ticker,
        date=record_date,
        open=open_,
        high=high,
        low=low,
        close=close,
        volume=volume,
    )
|
||||
|
||||
|
||||
@st.composite
def invalid_ohlcv_records(draw: st.DrawFn) -> OHLCVData:
    """Generate OHLCV records that violate at least one constraint."""
    ticker = draw(valid_ticker_symbols())
    # Pick exactly one constraint to break so a failure is attributable.
    violation = draw(st.sampled_from(["high_lt_low", "negative_price", "negative_volume", "future_date"]))

    if violation == "high_lt_low":
        # low is drawn strictly above high (by at least 0.01).
        high = draw(st.floats(min_value=0.01, max_value=100.0, allow_nan=False, allow_infinity=False))
        low = draw(st.floats(min_value=high + 0.01, max_value=200.0, allow_nan=False, allow_infinity=False))
        return OHLCVData(
            ticker=ticker, date=date.today(),
            open=high, high=high, low=low, close=high, volume=100,
        )
    elif violation == "negative_price":
        # Negative open; the remaining prices stay positive via abs().
        neg = draw(st.floats(min_value=-10000.0, max_value=-0.01, allow_nan=False, allow_infinity=False))
        return OHLCVData(
            ticker=ticker, date=date.today(),
            open=neg, high=abs(neg), low=abs(neg), close=abs(neg), volume=100,
        )
    elif violation == "negative_volume":
        price = draw(st.floats(min_value=0.01, max_value=100.0, allow_nan=False, allow_infinity=False))
        neg_vol = draw(st.integers(min_value=-10**9, max_value=-1))
        return OHLCVData(
            ticker=ticker, date=date.today(),
            open=price, high=price, low=price, close=price, volume=neg_vol,
        )
    else:  # future_date
        # A date 1-365 days beyond today.
        future = date.today() + timedelta(days=draw(st.integers(min_value=1, max_value=365)))
        price = draw(st.floats(min_value=0.01, max_value=100.0, allow_nan=False, allow_infinity=False))
        return OHLCVData(
            ticker=ticker, date=future,
            open=price, high=price, low=price, close=price, volume=100,
        )
|
||||
|
||||
|
||||
# Dimension names sampled by weight_configs when generating scoring weights.
_DIMENSIONS = ["technical", "sr_quality", "sentiment", "fundamental", "momentum"]
|
||||
|
||||
|
||||
@st.composite
def dimension_scores(draw: st.DrawFn) -> float:
    """Generate float values in [0, 100] for dimension scores."""
    score_strategy = st.floats(min_value=0.0, max_value=100.0, allow_nan=False, allow_infinity=False)
    return draw(score_strategy)
|
||||
|
||||
|
||||
@st.composite
def weight_configs(draw: st.DrawFn) -> dict[str, float]:
    """Generate dicts of dimension → positive float weight."""
    dims = draw(st.lists(st.sampled_from(_DIMENSIONS), min_size=1, max_size=5, unique=True))
    weight_strategy = st.floats(min_value=0.01, max_value=10.0, allow_nan=False, allow_infinity=False)
    # One weight draw per selected dimension, in selection order.
    return {dim: draw(weight_strategy) for dim in dims}
|
||||
|
||||
|
||||
@st.composite
def sr_levels(draw: st.DrawFn) -> dict[str, Any]:
    """Generate SR level data (price, type, strength, detection_method)."""
    level: dict[str, Any] = {}
    level["price_level"] = draw(st.floats(min_value=0.01, max_value=10000.0, allow_nan=False, allow_infinity=False))
    level["type"] = draw(st.sampled_from(["support", "resistance"]))
    level["strength"] = draw(st.integers(min_value=0, max_value=100))
    level["detection_method"] = draw(st.sampled_from(["volume_profile", "pivot_point", "merged"]))
    return level
|
||||
|
||||
|
||||
@st.composite
def sentiment_scores(draw: st.DrawFn) -> dict[str, Any]:
    """Generate sentiment data (classification, confidence, source, timestamp)."""
    # st.datetimes yields naive datetimes; drawn first, tagged as UTC below.
    naive_dt = draw(
        st.datetimes(
            min_value=datetime(2020, 1, 1),
            max_value=datetime.now(),
        )
    )
    return {
        "classification": draw(st.sampled_from(["bullish", "bearish", "neutral"])),
        "confidence": draw(st.integers(min_value=0, max_value=100)),
        "source": draw(st.text(alphabet=string.ascii_lowercase, min_size=3, max_size=20)),
        # NOTE(review): replace() relabels the naive wall-clock time as UTC
        # rather than converting it — assumed intentional for test data.
        "timestamp": naive_dt.replace(tzinfo=timezone.utc),
    }
|
||||
|
||||
|
||||
@st.composite
def trade_setups(draw: st.DrawFn) -> dict[str, Any]:
    """Generate trade setup data (direction, entry, stop, target, rr_ratio, composite_score)."""
    direction = draw(st.sampled_from(["long", "short"]))
    entry = draw(st.floats(min_value=1.0, max_value=10000.0, allow_nan=False, allow_infinity=False))
    # Stop distance capped at 20% of entry so long stops stay positive.
    atr_dist = draw(st.floats(min_value=0.01, max_value=entry * 0.2, allow_nan=False, allow_infinity=False))

    # Target sits 3-10 stop-distances beyond entry in the trade direction.
    if direction == "long":
        stop = entry - atr_dist
        target = entry + atr_dist * draw(st.floats(min_value=3.0, max_value=10.0, allow_nan=False, allow_infinity=False))
    else:
        stop = entry + atr_dist
        target = entry - atr_dist * draw(st.floats(min_value=3.0, max_value=10.0, allow_nan=False, allow_infinity=False))

    # Reward:risk ratio, guarded against a zero stop distance from rounding.
    rr_ratio = abs(target - entry) / abs(entry - stop) if abs(entry - stop) > 0 else 0.0

    return {
        "direction": direction,
        "entry_price": entry,
        "stop_loss": stop,
        "target": target,
        "rr_ratio": rr_ratio,
        "composite_score": draw(st.floats(min_value=0.0, max_value=100.0, allow_nan=False, allow_infinity=False)),
    }
|
||||
1
tests/property/__init__.py
Normal file
1
tests/property/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
1
tests/unit/__init__.py
Normal file
1
tests/unit/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
102
tests/unit/test_cache.py
Normal file
102
tests/unit/test_cache.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Unit tests for app.cache LRU cache wrapper."""
|
||||
|
||||
from datetime import date
|
||||
|
||||
from app.cache import LRUCache
|
||||
|
||||
|
||||
def _key(ticker: str, indicator: str = "RSI") -> tuple:
    """Cache key for *ticker*: (ticker, fixed start, fixed end, indicator)."""
    start = date(2024, 1, 1)
    end = date(2024, 6, 1)
    return (ticker, start, end, indicator)
|
||||
|
||||
|
||||
class TestLRUCacheBasics:
    """Get/set/clear semantics of the LRU cache."""

    def test_get_miss_returns_none(self):
        assert LRUCache().get(_key("AAPL")) is None

    def test_set_and_get_round_trip(self):
        lru = LRUCache()
        lru.set(_key("AAPL"), {"score": 72})
        assert lru.get(_key("AAPL")) == {"score": 72}

    def test_set_overwrites_existing(self):
        lru = LRUCache()
        for value in (1, 2):
            lru.set(_key("AAPL"), value)
        # Overwrite keeps a single entry with the newest value.
        assert lru.get(_key("AAPL")) == 2
        assert lru.size == 1

    def test_size_and_clear(self):
        lru = LRUCache(max_size=5)
        for idx in range(3):
            lru.set((f"T{idx}", None, None, "RSI"), idx)
        assert lru.size == 3
        lru.clear()
        assert lru.size == 0
|
||||
|
||||
|
||||
class TestLRUEviction:
    """Eviction order: the least-recently-used entry goes first."""

    def test_evicts_lru_when_full(self):
        lru = LRUCache(max_size=3)
        for name, value in (("A", 1), ("B", 2), ("C", 3)):
            lru.set(_key(name), value)
        # A is LRU — inserting D should evict A
        lru.set(_key("D"), 4)
        assert lru.get(_key("A")) is None
        assert lru.size == 3

    def test_access_promotes_entry(self):
        lru = LRUCache(max_size=3)
        for name, value in (("A", 1), ("B", 2), ("C", 3)):
            lru.set(_key(name), value)
        # Access A so B becomes LRU
        lru.get(_key("A"))
        lru.set(_key("D"), 4)
        assert lru.get(_key("B")) is None
        assert lru.get(_key("A")) == 1

    def test_update_promotes_entry(self):
        lru = LRUCache(max_size=3)
        for name, value in (("A", 1), ("B", 2), ("C", 3)):
            lru.set(_key(name), value)
        # Update A so B becomes LRU
        lru.set(_key("A"), 10)
        lru.set(_key("D"), 4)
        assert lru.get(_key("B")) is None
        assert lru.get(_key("A")) == 10
|
||||
|
||||
|
||||
class TestTickerInvalidation:
    """invalidate_ticker removes exactly the entries for one ticker."""

    def test_invalidate_removes_all_entries_for_ticker(self):
        window = (date(2024, 1, 1), date(2024, 6, 1))
        lru = LRUCache()
        lru.set(("AAPL", *window, "RSI"), 1)
        lru.set(("AAPL", *window, "ADX"), 2)
        lru.set(("MSFT", *window, "RSI"), 3)
        # Both AAPL entries go; the MSFT entry survives.
        assert lru.invalidate_ticker("AAPL") == 2
        assert lru.get(("AAPL", *window, "RSI")) is None
        assert lru.get(("AAPL", *window, "ADX")) is None
        assert lru.get(("MSFT", *window, "RSI")) == 3

    def test_invalidate_nonexistent_ticker_returns_zero(self):
        lru = LRUCache()
        lru.set(_key("AAPL"), 1)
        assert lru.invalidate_ticker("GOOG") == 0
        assert lru.size == 1

    def test_invalidate_on_empty_cache(self):
        assert LRUCache().invalidate_ticker("AAPL") == 0
|
||||
|
||||
|
||||
class TestMaxSizeProperty:
    """max_size reflects the constructor argument (default 1000)."""

    def test_default_max_size(self):
        assert LRUCache().max_size == 1000

    def test_custom_max_size(self):
        assert LRUCache(max_size=50).max_size == 50
|
||||
201
tests/unit/test_exceptions_and_middleware.py
Normal file
201
tests/unit/test_exceptions_and_middleware.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""Tests for the exception hierarchy and global exception handlers."""
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.exceptions import (
|
||||
AppError,
|
||||
AuthenticationError,
|
||||
AuthorizationError,
|
||||
DuplicateError,
|
||||
NotFoundError,
|
||||
ProviderError,
|
||||
RateLimitError,
|
||||
ValidationError,
|
||||
)
|
||||
from app.middleware import register_exception_handlers
|
||||
from app.schemas.common import APIEnvelope
|
||||
|
||||
|
||||
# ── Exception hierarchy tests ──
|
||||
|
||||
|
||||
def test_app_error_defaults():
    """AppError defaults to HTTP 500 with a generic message."""
    default_error = AppError()
    assert default_error.status_code == 500
    assert default_error.message == str(default_error) == "Internal server error"
|
||||
|
||||
|
||||
def test_app_error_custom_message():
    """A custom message overrides the default and feeds __str__."""
    custom = AppError("something broke")
    assert custom.message == str(custom) == "something broke"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "cls,code,default_msg",
    [
        (ValidationError, 400, "Validation error"),
        (NotFoundError, 404, "Resource not found"),
        (DuplicateError, 409, "Resource already exists"),
        (AuthenticationError, 401, "Authentication required"),
        (AuthorizationError, 403, "Insufficient permissions"),
        (ProviderError, 502, "Market data provider unavailable"),
        (RateLimitError, 429, "Rate limited"),
    ],
)
def test_subclass_defaults(cls, code, default_msg):
    """Each AppError subclass carries its expected default status and message."""
    err = cls()
    assert err.status_code == code
    assert err.message == default_msg
|
||||
|
||||
|
||||
def test_subclass_custom_message():
    """A custom message keeps the subclass's status code."""
    not_found = NotFoundError("Ticker not found: AAPL")
    assert not_found.status_code == 404
    assert not_found.message == "Ticker not found: AAPL"
|
||||
|
||||
|
||||
def test_all_subclasses_are_app_errors():
    """Every concrete error type participates in the AppError hierarchy."""
    concrete_errors = [
        ValidationError,
        NotFoundError,
        DuplicateError,
        AuthenticationError,
        AuthorizationError,
        ProviderError,
        RateLimitError,
    ]
    for error_cls in concrete_errors:
        assert issubclass(error_cls, AppError)
|
||||
|
||||
|
||||
# ── APIEnvelope schema tests ──
|
||||
|
||||
|
||||
def test_envelope_success():
    """A success envelope carries data and no error."""
    envelope = APIEnvelope(status="success", data={"id": 1})
    assert envelope.status == "success"
    assert envelope.data == {"id": 1}
    assert envelope.error is None
|
||||
|
||||
|
||||
def test_envelope_error():
    """An error envelope carries a message and no data."""
    envelope = APIEnvelope(status="error", error="bad request")
    assert envelope.status == "error"
    assert envelope.data is None
    assert envelope.error == "bad request"
|
||||
|
||||
|
||||
# ── Middleware integration tests ──
|
||||
|
||||
|
||||
def _make_app() -> FastAPI:
    """Create a minimal FastAPI app with exception handlers and test routes.

    Each route raises exactly one error type so the handler-to-envelope
    mapping can be exercised in isolation.
    """
    app = FastAPI()
    register_exception_handlers(app)

    @app.get("/raise-not-found")
    async def _raise_not_found():
        raise NotFoundError("Ticker not found: XYZ")

    @app.get("/raise-validation")
    async def _raise_validation():
        raise ValidationError("high < low")

    @app.get("/raise-duplicate")
    async def _raise_duplicate():
        raise DuplicateError("Ticker already exists: AAPL")

    @app.get("/raise-auth")
    async def _raise_auth():
        raise AuthenticationError()

    @app.get("/raise-authz")
    async def _raise_authz():
        raise AuthorizationError()

    @app.get("/raise-provider")
    async def _raise_provider():
        raise ProviderError()

    @app.get("/raise-rate-limit")
    async def _raise_rate_limit():
        raise RateLimitError("Rate limited. Ingested 42 records. Resume available.")

    @app.get("/raise-unhandled")
    async def _raise_unhandled():
        # Not an AppError — should hit the catch-all 500 handler.
        raise RuntimeError("unexpected")

    return app
|
||||
|
||||
|
||||
@pytest.fixture
def client():
    """Sync TestClient bound to a fresh app with handlers registered."""
    return TestClient(_make_app())
|
||||
|
||||
|
||||
def test_middleware_not_found(client):
    """NotFoundError surfaces as a 404 error envelope with its message."""
    resp = client.get("/raise-not-found")
    payload = resp.json()
    assert resp.status_code == 404
    assert payload["status"] == "error"
    assert payload["data"] is None
    assert payload["error"] == "Ticker not found: XYZ"
|
||||
|
||||
|
||||
def test_middleware_validation(client):
    """ValidationError maps to HTTP 400 with the original message."""
    resp = client.get("/raise-validation")
    payload = resp.json()
    assert resp.status_code == 400
    assert payload["status"] == "error"
    assert payload["error"] == "high < low"
|
||||
|
||||
|
||||
def test_middleware_duplicate(client):
    """DuplicateError maps to HTTP 409."""
    resp = client.get("/raise-duplicate")
    payload = resp.json()
    assert resp.status_code == 409
    assert payload["status"] == "error"
    assert "already exists" in payload["error"]
|
||||
|
||||
|
||||
def test_middleware_authentication(client):
    """AuthenticationError maps to HTTP 401."""
    resp = client.get("/raise-auth")
    payload = resp.json()
    assert resp.status_code == 401
    assert payload["status"] == "error"
|
||||
|
||||
|
||||
def test_middleware_authorization(client):
    """AuthorizationError maps to HTTP 403."""
    resp = client.get("/raise-authz")
    payload = resp.json()
    assert resp.status_code == 403
    assert payload["status"] == "error"
|
||||
|
||||
|
||||
def test_middleware_provider_error(client):
    """ProviderError maps to HTTP 502 (bad gateway)."""
    resp = client.get("/raise-provider")
    payload = resp.json()
    assert resp.status_code == 502
    assert payload["status"] == "error"
|
||||
|
||||
|
||||
def test_middleware_rate_limit(client):
    """RateLimitError maps to HTTP 429 and keeps ingestion-progress info."""
    resp = client.get("/raise-rate-limit")
    payload = resp.json()
    assert resp.status_code == 429
    assert payload["status"] == "error"
    assert "42 records" in payload["error"]
|
||||
|
||||
|
||||
def test_middleware_unhandled_exception():
    """Unexpected exceptions become a generic 500 error envelope."""
    # raise_server_exceptions=False lets the response come back instead of
    # the RuntimeError propagating into the test.
    with TestClient(_make_app(), raise_server_exceptions=False) as c:
        resp = c.get("/raise-unhandled")
        payload = resp.json()
    assert resp.status_code == 500
    assert payload["status"] == "error"
    assert payload["data"] is None
    assert payload["error"] == "Internal server error"
|
||||
205
tests/unit/test_indicator_service.py
Normal file
205
tests/unit/test_indicator_service.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Unit tests for app.services.indicator_service pure computation functions."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions import ValidationError
|
||||
from app.services.indicator_service import (
|
||||
compute_adx,
|
||||
compute_atr,
|
||||
compute_ema,
|
||||
compute_ema_cross,
|
||||
compute_pivot_points,
|
||||
compute_rsi,
|
||||
compute_volume_profile,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers: generate synthetic OHLCV data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _rising_closes(n: int, start: float = 100.0, step: float = 1.0) -> list[float]:
    """Arithmetic close series: start, start+step, ... (n values)."""
    return [start + step * idx for idx in range(n)]
|
||||
|
||||
|
||||
def _flat_closes(n: int, price: float = 100.0) -> list[float]:
    """Constant close series of length n."""
    return [price for _ in range(n)]
|
||||
|
||||
|
||||
def _ohlcv_from_closes(closes: list[float], spread: float = 2.0):
    """Generate highs/lows/volumes from a close series.

    Highs/lows sit a fixed spread above/below each close; volume is constant.
    """
    highs = [price + spread for price in closes]
    lows = [price - spread for price in closes]
    volumes = [1000 for _ in closes]
    return highs, lows, closes, volumes
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EMA
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestComputeEMA:
    """compute_ema: result shape, score bounds, and minimum-data validation."""

    def test_basic_ema(self):
        closes = _rising_closes(25)
        result = compute_ema(closes, period=20)
        assert "ema" in result
        assert "score" in result
        assert 0 <= result["score"] <= 100

    def test_insufficient_data_raises(self):
        closes = _rising_closes(5)
        with pytest.raises(ValidationError, match="EMA.*requires at least"):
            compute_ema(closes, period=20)

    def test_price_above_ema_high_score(self):
        # Rising prices → latest close above EMA → score > 50
        closes = _rising_closes(30, start=100, step=2)
        result = compute_ema(closes, period=20)
        assert result["score"] > 50

    def test_price_below_ema_low_score(self):
        # Falling prices → latest close below EMA → score < 50
        closes = list(reversed(_rising_closes(30, start=100, step=2)))
        result = compute_ema(closes, period=20)
        assert result["score"] < 50
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RSI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestComputeRSI:
    """compute_rsi: result shape, saturation extremes, and validation."""

    def test_basic_rsi(self):
        closes = _rising_closes(20)
        result = compute_rsi(closes)
        assert "rsi" in result
        assert 0 <= result["score"] <= 100

    def test_all_gains_rsi_100(self):
        # Monotonic gains → RSI saturates at 100.
        closes = _rising_closes(20, step=1)
        result = compute_rsi(closes)
        assert result["rsi"] == 100.0

    def test_all_losses_rsi_0(self):
        # Monotonic losses → RSI at (or numerically near) 0.
        closes = list(reversed(_rising_closes(20, step=1)))
        result = compute_rsi(closes)
        assert result["rsi"] == pytest.approx(0.0, abs=0.5)

    def test_insufficient_data_raises(self):
        with pytest.raises(ValidationError, match="RSI requires"):
            compute_rsi([100.0] * 5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ATR
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestComputeATR:
    """compute_atr: positive ATR on valid data; validation on short input."""

    def test_basic_atr(self):
        closes = _rising_closes(20)
        highs, lows, _, _ = _ohlcv_from_closes(closes)
        result = compute_atr(highs, lows, closes)
        assert "atr" in result
        assert result["atr"] > 0
        assert 0 <= result["score"] <= 100

    def test_insufficient_data_raises(self):
        closes = [100.0] * 5
        highs, lows, _, _ = _ohlcv_from_closes(closes)
        with pytest.raises(ValidationError, match="ATR requires"):
            compute_atr(highs, lows, closes)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ADX
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestComputeADX:
    """compute_adx: result keys (adx, ±DI), score bounds, and validation."""

    def test_basic_adx(self):
        closes = _rising_closes(30)
        highs, lows, _, _ = _ohlcv_from_closes(closes)
        result = compute_adx(highs, lows, closes)
        assert "adx" in result
        assert "plus_di" in result
        assert "minus_di" in result
        assert 0 <= result["score"] <= 100

    def test_insufficient_data_raises(self):
        closes = _rising_closes(10)
        highs, lows, _, _ = _ohlcv_from_closes(closes)
        with pytest.raises(ValidationError, match="ADX requires"):
            compute_adx(highs, lows, closes)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Volume Profile
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestComputeVolumeProfile:
    """compute_volume_profile: POC/value-area/HVN/LVN keys and validation."""

    def test_basic_volume_profile(self):
        closes = _rising_closes(25)
        highs, lows, _, volumes = _ohlcv_from_closes(closes)
        result = compute_volume_profile(highs, lows, closes, volumes)
        assert "poc" in result
        assert "value_area_low" in result
        assert "value_area_high" in result
        assert "hvn" in result
        assert "lvn" in result
        assert 0 <= result["score"] <= 100

    def test_insufficient_data_raises(self):
        closes = [100.0] * 10
        highs, lows, _, volumes = _ohlcv_from_closes(closes)
        with pytest.raises(ValidationError, match="Volume Profile requires"):
            compute_volume_profile(highs, lows, closes, volumes)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pivot Points
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestComputePivotPoints:
    """compute_pivot_points: swing detection keys and validation."""

    def test_basic_pivot_points(self):
        # Create data with clear swing highs/lows
        closes = [10, 15, 20, 15, 10, 15, 20, 15, 10, 15]
        highs = [c + 1 for c in closes]
        lows = [c - 1 for c in closes]
        result = compute_pivot_points(highs, lows, closes)
        assert "swing_highs" in result
        assert "swing_lows" in result
        assert 0 <= result["score"] <= 100

    def test_insufficient_data_raises(self):
        with pytest.raises(ValidationError, match="Pivot Points requires"):
            compute_pivot_points([1, 2], [0, 1], [0.5, 1.5])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EMA Cross
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestComputeEMACross:
    """compute_ema_cross: bullish/bearish/neutral signals and validation."""

    def test_bullish_signal(self):
        # Rising prices → short EMA > long EMA → bullish
        closes = _rising_closes(60, step=2)
        result = compute_ema_cross(closes, short_period=20, long_period=50)
        assert result["signal"] == "bullish"
        assert result["short_ema"] > result["long_ema"]

    def test_bearish_signal(self):
        # Falling prices → short EMA < long EMA → bearish
        closes = list(reversed(_rising_closes(60, step=2)))
        result = compute_ema_cross(closes, short_period=20, long_period=50)
        assert result["signal"] == "bearish"
        assert result["short_ema"] < result["long_ema"]

    def test_neutral_signal(self):
        # Flat prices → EMAs converge → neutral
        closes = _flat_closes(60)
        result = compute_ema_cross(closes, short_period=20, long_period=50)
        assert result["signal"] == "neutral"

    def test_insufficient_data_raises(self):
        closes = _rising_closes(30)
        with pytest.raises(ValidationError, match="EMA Cross requires"):
            compute_ema_cross(closes, short_period=20, long_period=50)
|
||||
95
tests/unit/test_scheduler.py
Normal file
95
tests/unit/test_scheduler.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Unit tests for app.scheduler module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.scheduler import (
|
||||
_is_job_enabled,
|
||||
_parse_frequency,
|
||||
_resume_tickers,
|
||||
_last_successful,
|
||||
configure_scheduler,
|
||||
scheduler,
|
||||
)
|
||||
|
||||
|
||||
class TestParseFrequency:
    """_parse_frequency maps config strings to APScheduler interval kwargs."""

    def test_hourly(self):
        assert _parse_frequency("hourly") == {"hours": 1}

    def test_daily(self):
        assert _parse_frequency("daily") == {"hours": 24}

    def test_case_insensitive(self):
        assert _parse_frequency("Hourly") == {"hours": 1}
        assert _parse_frequency("DAILY") == {"hours": 24}

    def test_unknown_defaults_to_daily(self):
        # Anything unrecognized (including empty) falls back to daily.
        assert _parse_frequency("weekly") == {"hours": 24}
        assert _parse_frequency("") == {"hours": 24}
|
||||
|
||||
|
||||
class TestResumeTickers:
    """_resume_tickers rotates the symbol list past the last success.

    State is injected through the module-level `_last_successful` dict,
    keyed by job name.
    """

    def test_no_previous_returns_full_list(self):
        symbols = ["AAPL", "GOOG", "MSFT"]
        _last_successful["test_job"] = None
        result = _resume_tickers(symbols, "test_job")
        assert result == ["AAPL", "GOOG", "MSFT"]

    def test_resume_after_first(self):
        symbols = ["AAPL", "GOOG", "MSFT"]
        _last_successful["test_job"] = "AAPL"
        result = _resume_tickers(symbols, "test_job")
        # Should start from GOOG, then wrap around
        assert result == ["GOOG", "MSFT", "AAPL"]

    def test_resume_after_middle(self):
        symbols = ["AAPL", "GOOG", "MSFT", "TSLA"]
        _last_successful["test_job"] = "GOOG"
        result = _resume_tickers(symbols, "test_job")
        assert result == ["MSFT", "TSLA", "AAPL", "GOOG"]

    def test_resume_after_last(self):
        symbols = ["AAPL", "GOOG", "MSFT"]
        _last_successful["test_job"] = "MSFT"
        result = _resume_tickers(symbols, "test_job")
        # All already processed, wraps to full list
        assert result == ["AAPL", "GOOG", "MSFT"]

    def test_unknown_last_returns_full_list(self):
        # A last-success symbol no longer in the list resets the rotation.
        symbols = ["AAPL", "GOOG", "MSFT"]
        _last_successful["test_job"] = "NVDA"
        result = _resume_tickers(symbols, "test_job")
        assert result == ["AAPL", "GOOG", "MSFT"]

    def test_empty_list(self):
        _last_successful["test_job"] = "AAPL"
        result = _resume_tickers([], "test_job")
        assert result == []
|
||||
|
||||
|
||||
class TestConfigureScheduler:
    """configure_scheduler registers exactly the four background jobs."""

    def test_configure_adds_four_jobs(self):
        # Remove any existing jobs first
        scheduler.remove_all_jobs()
        configure_scheduler()
        jobs = scheduler.get_jobs()
        job_ids = {j.id for j in jobs}
        assert job_ids == {
            "data_collector",
            "sentiment_collector",
            "fundamental_collector",
            "rr_scanner",
        }

    def test_configure_is_idempotent(self):
        scheduler.remove_all_jobs()
        configure_scheduler()
        configure_scheduler()  # Should replace, not duplicate
        job_ids = [j.id for j in scheduler.get_jobs()]
        # Each ID should appear exactly once
        assert sorted(job_ids) == sorted([
            "data_collector",
            "fundamental_collector",
            "rr_scanner",
            "sentiment_collector",
        ])
|
||||
Reference in New Issue
Block a user