"""Shared test fixtures and hypothesis strategies for the stock-data-backend test suite.""" from __future__ import annotations import string from datetime import date, datetime, timedelta, timezone from typing import Any import pytest from httpx import ASGITransport, AsyncClient from hypothesis import strategies as st from sqlalchemy.ext.asyncio import ( AsyncSession, async_sessionmaker, create_async_engine, ) from app.database import Base from app.providers.protocol import OHLCVData # --------------------------------------------------------------------------- # Test database (SQLite in-memory, async via aiosqlite) # --------------------------------------------------------------------------- TEST_DATABASE_URL = "sqlite+aiosqlite://" _test_engine = create_async_engine(TEST_DATABASE_URL, echo=False) _test_session_factory = async_sessionmaker( _test_engine, class_=AsyncSession, expire_on_commit=False, ) @pytest.fixture(autouse=True) async def _setup_db(): """Create all tables before each test and drop them after.""" async with _test_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) yield async with _test_engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) @pytest.fixture async def db_session() -> AsyncSession: """Provide a transactional DB session that rolls back after the test.""" async with _test_session_factory() as session: async with session.begin(): yield session await session.rollback() # --------------------------------------------------------------------------- # FastAPI test client # --------------------------------------------------------------------------- @pytest.fixture async def client(db_session: AsyncSession) -> AsyncClient: """Async HTTP test client wired to the FastAPI app with the test DB session.""" from app.dependencies import get_db from app.main import app async def _override_get_db(): yield db_session app.dependency_overrides[get_db] = _override_get_db transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac app.dependency_overrides.clear() # --------------------------------------------------------------------------- # Mock MarketDataProvider # --------------------------------------------------------------------------- class MockMarketDataProvider: """Configurable mock that satisfies the MarketDataProvider protocol.""" def __init__( self, ohlcv_data: list[OHLCVData] | None = None, error: Exception | None = None, ) -> None: self.ohlcv_data = ohlcv_data or [] self.error = error self.calls: list[dict[str, Any]] = [] async def fetch_ohlcv( self, ticker: str, start_date: date, end_date: date ) -> list[OHLCVData]: self.calls.append( {"ticker": ticker, "start_date": start_date, "end_date": end_date} ) if self.error is not None: raise self.error return [r for r in self.ohlcv_data if r.ticker == ticker] @pytest.fixture def mock_provider() -> MockMarketDataProvider: """Return a fresh MockMarketDataProvider instance.""" return MockMarketDataProvider() # --------------------------------------------------------------------------- # Hypothesis custom strategies # --------------------------------------------------------------------------- _TICKER_ALPHABET = string.ascii_uppercase + string.digits @st.composite def valid_ticker_symbols(draw: st.DrawFn) -> str: """Generate uppercase alphanumeric ticker symbols (1-10 chars).""" return draw( st.text(alphabet=_TICKER_ALPHABET, min_size=1, max_size=10) ) @st.composite def whitespace_strings(draw: st.DrawFn) -> str: """Generate strings composed entirely of whitespace (including empty).""" return draw( st.text(alphabet=" \t\n\r\x0b\x0c", min_size=0, max_size=20) ) @st.composite def valid_ohlcv_records(draw: st.DrawFn) -> OHLCVData: """Generate valid OHLCV records (high >= low, prices >= 0, volume >= 0, date <= today).""" ticker = draw(valid_ticker_symbols()) low = draw(st.floats(min_value=0.01, max_value=10000.0, allow_nan=False, allow_infinity=False)) high = draw(st.floats(min_value=low, max_value=10000.0, allow_nan=False, allow_infinity=False)) open_ = draw(st.floats(min_value=low, max_value=high, allow_nan=False, allow_infinity=False)) close = draw(st.floats(min_value=low, max_value=high, allow_nan=False, allow_infinity=False)) volume = draw(st.integers(min_value=0, max_value=10**12)) record_date = draw( st.dates(min_value=date(2000, 1, 1), max_value=date.today()) ) return OHLCVData( ticker=ticker, date=record_date, open=open_, high=high, low=low, close=close, volume=volume, ) @st.composite def invalid_ohlcv_records(draw: st.DrawFn) -> OHLCVData: """Generate OHLCV records that violate at least one constraint.""" ticker = draw(valid_ticker_symbols()) violation = draw(st.sampled_from(["high_lt_low", "negative_price", "negative_volume", "future_date"])) if violation == "high_lt_low": high = draw(st.floats(min_value=0.01, max_value=100.0, allow_nan=False, allow_infinity=False)) low = draw(st.floats(min_value=high + 0.01, max_value=200.0, allow_nan=False, allow_infinity=False)) return OHLCVData( ticker=ticker, date=date.today(), open=high, high=high, low=low, close=high, volume=100, ) elif violation == "negative_price": neg = draw(st.floats(min_value=-10000.0, max_value=-0.01, allow_nan=False, allow_infinity=False)) return OHLCVData( ticker=ticker, date=date.today(), open=neg, high=abs(neg), low=abs(neg), close=abs(neg), volume=100, ) elif violation == "negative_volume": price = draw(st.floats(min_value=0.01, max_value=100.0, allow_nan=False, allow_infinity=False)) neg_vol = draw(st.integers(min_value=-10**9, max_value=-1)) return OHLCVData( ticker=ticker, date=date.today(), open=price, high=price, low=price, close=price, volume=neg_vol, ) else: # future_date future = date.today() + timedelta(days=draw(st.integers(min_value=1, max_value=365))) price = draw(st.floats(min_value=0.01, max_value=100.0, allow_nan=False, allow_infinity=False)) return OHLCVData( ticker=ticker, date=future, open=price, high=price, low=price, close=price, volume=100, ) _DIMENSIONS = ["technical", "sr_quality", "sentiment", "fundamental", "momentum"] @st.composite def dimension_scores(draw: st.DrawFn) -> float: """Generate float values in [0, 100] for dimension scores.""" return draw(st.floats(min_value=0.0, max_value=100.0, allow_nan=False, allow_infinity=False)) @st.composite def weight_configs(draw: st.DrawFn) -> dict[str, float]: """Generate dicts of dimension → positive float weight.""" dims = draw(st.lists(st.sampled_from(_DIMENSIONS), min_size=1, max_size=5, unique=True)) weights: dict[str, float] = {} for dim in dims: weights[dim] = draw(st.floats(min_value=0.01, max_value=10.0, allow_nan=False, allow_infinity=False)) return weights @st.composite def sr_levels(draw: st.DrawFn) -> dict[str, Any]: """Generate SR level data (price, type, strength, detection_method).""" return { "price_level": draw(st.floats(min_value=0.01, max_value=10000.0, allow_nan=False, allow_infinity=False)), "type": draw(st.sampled_from(["support", "resistance"])), "strength": draw(st.integers(min_value=0, max_value=100)), "detection_method": draw(st.sampled_from(["volume_profile", "pivot_point", "merged"])), } @st.composite def sentiment_scores(draw: st.DrawFn) -> dict[str, Any]: """Generate sentiment data (classification, confidence, source, timestamp).""" naive_dt = draw( st.datetimes( min_value=datetime(2020, 1, 1), max_value=datetime.now(), ) ) return { "classification": draw(st.sampled_from(["bullish", "bearish", "neutral"])), "confidence": draw(st.integers(min_value=0, max_value=100)), "source": draw(st.text(alphabet=string.ascii_lowercase, min_size=3, max_size=20)), "timestamp": naive_dt.replace(tzinfo=timezone.utc), } @st.composite def trade_setups(draw: st.DrawFn) -> dict[str, Any]: """Generate trade setup data (direction, entry, stop, target, rr_ratio, composite_score).""" direction = draw(st.sampled_from(["long", "short"])) entry = draw(st.floats(min_value=1.0, max_value=10000.0, allow_nan=False, allow_infinity=False)) atr_dist = draw(st.floats(min_value=0.01, max_value=entry * 0.2, allow_nan=False, allow_infinity=False)) if direction == "long": stop = entry - atr_dist target = entry + atr_dist * draw(st.floats(min_value=3.0, max_value=10.0, allow_nan=False, allow_infinity=False)) else: stop = entry + atr_dist target = entry - atr_dist * draw(st.floats(min_value=3.0, max_value=10.0, allow_nan=False, allow_infinity=False)) rr_ratio = abs(target - entry) / abs(entry - stop) if abs(entry - stop) > 0 else 0.0 return { "direction": direction, "entry_price": entry, "stop_loss": stop, "target": target, "rr_ratio": rr_ratio, "composite_score": draw(st.floats(min_value=0.0, max_value=100.0, allow_nan=False, allow_infinity=False)), }