Files
signal-platform/tests/conftest.py
Dennis Thiessen 61ab24490d
Some checks failed
Deploy / lint (push) Failing after 7s
Deploy / test (push) Has been skipped
Deploy / deploy (push) Has been skipped
first commit
2026-02-20 17:31:01 +01:00

261 lines
9.5 KiB
Python

"""Shared test fixtures and hypothesis strategies for the stock-data-backend test suite."""
from __future__ import annotations

import string
from collections.abc import AsyncIterator
from datetime import date, datetime, timedelta, timezone
from typing import Any

import pytest
from httpx import ASGITransport, AsyncClient
from hypothesis import strategies as st
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

from app.database import Base
from app.providers.protocol import OHLCVData
# ---------------------------------------------------------------------------
# Test database (SQLite in-memory, async via aiosqlite)
# ---------------------------------------------------------------------------
TEST_DATABASE_URL = "sqlite+aiosqlite://"

# A single engine and session factory shared by the DB fixtures in this file.
# The URL has no database path, i.e. SQLite's in-memory database.
_test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)
_test_session_factory = async_sessionmaker(
    bind=_test_engine,
    class_=AsyncSession,
    # Do not expire ORM instances on commit, so attributes stay readable
    # afterwards without another round-trip.
    expire_on_commit=False,
)
@pytest.fixture(autouse=True)
async def _setup_db():
    """Build the full schema before every test and remove it afterwards.

    Applied automatically to each test. Tables come from ``Base.metadata`` and
    are created and dropped on the shared ``_test_engine``, each inside its own
    ``engine.begin()`` transaction.
    """
    schema = Base.metadata
    async with _test_engine.begin() as connection:
        await connection.run_sync(schema.create_all)
    yield
    async with _test_engine.begin() as connection:
        await connection.run_sync(schema.drop_all)
@pytest.fixture
async def db_session() -> AsyncIterator[AsyncSession]:
    """Provide a DB session on the test engine that is rolled back after the test.

    The session is deliberately *not* wrapped in ``session.begin()``. Inside
    such a block, SQLAlchemy 2.x rejects any further use of the session once
    code has called ``commit()`` ("Can't operate on closed transaction inside
    context manager"). That would break application code under test that
    commits and then keeps querying. Without the wrapper, the session simply
    autobegins a new transaction on next use.

    On teardown, anything still uncommitted is rolled back. Committed rows
    disappear when the autouse ``_setup_db`` fixture drops all tables.
    """
    async with _test_session_factory() as session:
        try:
            yield session
        finally:
            await session.rollback()
# ---------------------------------------------------------------------------
# FastAPI test client
# ---------------------------------------------------------------------------
@pytest.fixture
async def client(db_session: AsyncSession) -> AsyncIterator[AsyncClient]:
    """Async HTTP test client wired to the FastAPI app with the test DB session.

    The app's ``get_db`` dependency is overridden to yield the ``db_session``
    fixture, so requests made through the client share the test's session.
    Requests are served in-process via ``ASGITransport`` at base URL
    ``http://test``. On teardown the whole ``app.dependency_overrides`` mapping
    is cleared.
    """
    # Imported inside the fixture (presumably so the app is only loaded by
    # tests that request this client).
    from app.dependencies import get_db
    from app.main import app

    async def _override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = _override_get_db
    # pytest does not run post-yield teardown when setup raises before the
    # yield. The try/finally guarantees the override never leaks into later
    # tests, even if building the client fails.
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
    finally:
        app.dependency_overrides.clear()
# ---------------------------------------------------------------------------
# Mock MarketDataProvider
# ---------------------------------------------------------------------------
class MockMarketDataProvider:
    """Configurable mock that satisfies the MarketDataProvider protocol.

    Every ``fetch_ohlcv`` call is appended to ``calls`` *before* any configured
    error is raised. The call then either raises ``error`` (when set) or
    returns the configured records whose ``ticker`` equals the requested one.
    The date-range arguments are recorded but not used for filtering.
    """

    def __init__(
        self,
        ohlcv_data: list[OHLCVData] | None = None,
        error: Exception | None = None,
    ) -> None:
        # Check ``is None`` rather than truthiness. ``ohlcv_data or []`` would
        # silently swap a caller's *empty* list for a new one, so records the
        # test appends to that list later would never be returned. A non-empty
        # list, by contrast, was already shared by reference.
        self.ohlcv_data: list[OHLCVData] = (
            ohlcv_data if ohlcv_data is not None else []
        )
        self.error = error
        # One dict per fetch_ohlcv call: keys "ticker", "start_date", "end_date".
        self.calls: list[dict[str, Any]] = []

    async def fetch_ohlcv(
        self, ticker: str, start_date: date, end_date: date
    ) -> list[OHLCVData]:
        """Record the call, then raise ``self.error`` or return *ticker*'s records."""
        self.calls.append(
            {"ticker": ticker, "start_date": start_date, "end_date": end_date}
        )
        if self.error is not None:
            raise self.error
        return [r for r in self.ohlcv_data if r.ticker == ticker]
@pytest.fixture
def mock_provider() -> MockMarketDataProvider:
    """Give each test its own default-configured MockMarketDataProvider."""
    provider = MockMarketDataProvider()
    return provider
# ---------------------------------------------------------------------------
# Hypothesis custom strategies
# ---------------------------------------------------------------------------
_TICKER_ALPHABET = "".join((string.ascii_uppercase, string.digits))


@st.composite
def valid_ticker_symbols(draw: st.DrawFn) -> str:
    """Draw a ticker of 1-10 characters taken from A-Z and 0-9."""
    ticker_text = st.text(alphabet=_TICKER_ALPHABET, min_size=1, max_size=10)
    symbol = draw(ticker_text)
    return symbol
@st.composite
def whitespace_strings(draw: st.DrawFn) -> str:
    """Draw a string of 0-20 whitespace characters (the empty string included).

    Characters come from: space, tab, newline, carriage return, vertical tab
    and form feed.
    """
    blank_chars = " \t\n\r\x0b\x0c"
    blank_text = st.text(alphabet=blank_chars, min_size=0, max_size=20)
    return draw(blank_text)
@st.composite
def valid_ohlcv_records(draw: st.DrawFn) -> OHLCVData:
    """Draw a well-formed OHLCV record.

    Guarantees ``low <= open <= high`` and ``low <= close <= high``. All prices
    are finite and within ``[0.01, 10000.0]``, volume is within ``[0, 10**12]``,
    and the date lies between 2000-01-01 and today.
    """

    def finite_between(lo: float, hi: float) -> float:
        # Finite float drawn from the closed interval [lo, hi].
        return draw(
            st.floats(min_value=lo, max_value=hi, allow_nan=False, allow_infinity=False)
        )

    symbol = draw(valid_ticker_symbols())
    bar_low = finite_between(0.01, 10000.0)
    bar_high = finite_between(bar_low, 10000.0)
    bar_open = finite_between(bar_low, bar_high)
    bar_close = finite_between(bar_low, bar_high)
    shares = draw(st.integers(min_value=0, max_value=10**12))
    session_day = draw(st.dates(min_value=date(2000, 1, 1), max_value=date.today()))
    return OHLCVData(
        ticker=symbol,
        date=session_day,
        open=bar_open,
        high=bar_high,
        low=bar_low,
        close=bar_close,
        volume=shares,
    )
@st.composite
def invalid_ohlcv_records(draw: st.DrawFn) -> OHLCVData:
    """Draw an OHLCV record that breaks at least one validity rule.

    Exactly one violation kind is picked:
      * ``high_lt_low``     - low is strictly above high (dated today);
      * ``negative_price``  - the open is negative (dated today);
      * ``negative_volume`` - volume is negative (dated today);
      * ``future_date``     - the date is 1-365 days after today.
    """

    def finite_between(lo: float, hi: float) -> float:
        # Finite float drawn from the closed interval [lo, hi].
        return draw(
            st.floats(min_value=lo, max_value=hi, allow_nan=False, allow_infinity=False)
        )

    symbol = draw(valid_ticker_symbols())
    kind = draw(
        st.sampled_from(["high_lt_low", "negative_price", "negative_volume", "future_date"])
    )

    if kind == "high_lt_low":
        top = finite_between(0.01, 100.0)
        bottom = finite_between(top + 0.01, 200.0)
        return OHLCVData(
            ticker=symbol,
            date=date.today(),
            open=top,
            high=top,
            low=bottom,
            close=top,
            volume=100,
        )

    if kind == "negative_price":
        negative = finite_between(-10000.0, -0.01)
        magnitude = abs(negative)
        return OHLCVData(
            ticker=symbol,
            date=date.today(),
            open=negative,
            high=magnitude,
            low=magnitude,
            close=magnitude,
            volume=100,
        )

    if kind == "negative_volume":
        flat = finite_between(0.01, 100.0)
        bad_volume = draw(st.integers(min_value=-10**9, max_value=-1))
        return OHLCVData(
            ticker=symbol,
            date=date.today(),
            open=flat,
            high=flat,
            low=flat,
            close=flat,
            volume=bad_volume,
        )

    # Only "future_date" is left.
    days_ahead = draw(st.integers(min_value=1, max_value=365))
    later = date.today() + timedelta(days=days_ahead)
    flat = finite_between(0.01, 100.0)
    return OHLCVData(
        ticker=symbol,
        date=later,
        open=flat,
        high=flat,
        low=flat,
        close=flat,
        volume=100,
    )
_DIMENSIONS = ["technical", "sr_quality", "sentiment", "fundamental", "momentum"]


@st.composite
def dimension_scores(draw: st.DrawFn) -> float:
    """Draw a finite dimension score from the closed interval [0, 100]."""
    score_range = st.floats(
        min_value=0.0, max_value=100.0, allow_nan=False, allow_infinity=False
    )
    return draw(score_range)
@st.composite
def weight_configs(draw: st.DrawFn) -> dict[str, float]:
    """Draw a mapping of 1-5 distinct scoring dimensions to weights in [0.01, 10]."""
    chosen = draw(
        st.lists(st.sampled_from(_DIMENSIONS), min_size=1, max_size=5, unique=True)
    )
    weight_range = st.floats(
        min_value=0.01, max_value=10.0, allow_nan=False, allow_infinity=False
    )
    # One weight is drawn per dimension, in the order the dimensions were drawn.
    return {name: draw(weight_range) for name in chosen}
@st.composite
def sr_levels(draw: st.DrawFn) -> dict[str, Any]:
    """Draw one support/resistance level as a plain dict.

    Keys are ``price_level`` (finite float in [0.01, 10000]), ``type``
    ("support" or "resistance"), ``strength`` (int in [0, 100]) and
    ``detection_method`` ("volume_profile", "pivot_point" or "merged").
    """
    price = draw(
        st.floats(min_value=0.01, max_value=10000.0, allow_nan=False, allow_infinity=False)
    )
    side = draw(st.sampled_from(["support", "resistance"]))
    strength = draw(st.integers(min_value=0, max_value=100))
    method = draw(st.sampled_from(["volume_profile", "pivot_point", "merged"]))
    level = {
        "price_level": price,
        "type": side,
        "strength": strength,
        "detection_method": method,
    }
    return level
@st.composite
def sentiment_scores(draw: st.DrawFn) -> dict[str, Any]:
    """Generate sentiment data (classification, confidence, source, timestamp).

    Returns a dict with keys:
      * ``classification`` - "bullish", "bearish" or "neutral";
      * ``confidence``     - int in [0, 100];
      * ``source``         - 3-20 lowercase ASCII letters;
      * ``timestamp``      - timezone-aware UTC datetime between 2020-01-01
        and the current UTC time.
    """
    # Hypothesis takes naive bounds here. Express "now" in UTC rather than the
    # host's local time, because the drawn value is tagged as UTC below. With
    # datetime.now() a host ahead of UTC could yield timestamps in the future.
    utc_now_naive = datetime.now(timezone.utc).replace(tzinfo=None)
    naive_dt = draw(
        st.datetimes(
            min_value=datetime(2020, 1, 1),
            max_value=utc_now_naive,
        )
    )
    return {
        "classification": draw(st.sampled_from(["bullish", "bearish", "neutral"])),
        "confidence": draw(st.integers(min_value=0, max_value=100)),
        "source": draw(st.text(alphabet=string.ascii_lowercase, min_size=3, max_size=20)),
        "timestamp": naive_dt.replace(tzinfo=timezone.utc),
    }
@st.composite
def trade_setups(draw: st.DrawFn) -> dict[str, Any]:
    """Generate trade setup data (direction, entry, stop, target, rr_ratio, composite_score).

    * ``entry`` is in [1, 10000].
    * The stop sits ``atr_dist`` (0.01 up to 20% of entry) on the losing side
      of the entry.
    * The target sits a 3x-10x multiple of ``atr_dist`` on the winning side,
      so ``rr_ratio`` equals that multiple up to float rounding.
    * Every price stays strictly positive.
    * ``composite_score`` is a finite float in [0, 100].
    """
    direction = draw(st.sampled_from(["long", "short"]))
    entry = draw(st.floats(min_value=1.0, max_value=10000.0, allow_nan=False, allow_infinity=False))
    atr_dist = draw(st.floats(min_value=0.01, max_value=entry * 0.2, allow_nan=False, allow_infinity=False))
    max_multiple = 10.0
    if direction == "short":
        # A short target is entry - atr_dist * multiple. With atr_dist up to
        # 20% of entry, an uncapped multiple of 10 could push the target down
        # to -entry. The cap keeps the target at about 1% of entry or more.
        # Since atr_dist <= 0.2 * entry, the cap is always >= ~4.95, so the
        # 3.0 lower bound stays valid.
        max_multiple = min(max_multiple, 0.99 * entry / atr_dist)
    multiple = draw(st.floats(min_value=3.0, max_value=max_multiple, allow_nan=False, allow_infinity=False))
    if direction == "long":
        stop = entry - atr_dist
        target = entry + atr_dist * multiple
    else:
        stop = entry + atr_dist
        target = entry - atr_dist * multiple
    rr_ratio = abs(target - entry) / abs(entry - stop) if abs(entry - stop) > 0 else 0.0
    return {
        "direction": direction,
        "entry_price": entry,
        "stop_loss": stop,
        "target": target,
        "rr_ratio": rr_ratio,
        "composite_score": draw(st.floats(min_value=0.0, max_value=100.0, allow_nan=False, allow_infinity=False)),
    }