first commit
This commit is contained in:
260
tests/conftest.py
Normal file
260
tests/conftest.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""Shared test fixtures and hypothesis strategies for the stock-data-backend test suite."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from hypothesis import strategies as st
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from app.database import Base
|
||||
from app.providers.protocol import OHLCVData
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test database (SQLite in-memory, async via aiosqlite)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite://"
|
||||
|
||||
_test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)
|
||||
_test_session_factory = async_sessionmaker(
|
||||
_test_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def _setup_db():
|
||||
"""Create all tables before each test and drop them after."""
|
||||
async with _test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield
|
||||
async with _test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session() -> AsyncSession:
|
||||
"""Provide a transactional DB session that rolls back after the test."""
|
||||
async with _test_session_factory() as session:
|
||||
async with session.begin():
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FastAPI test client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(db_session: AsyncSession) -> AsyncClient:
|
||||
"""Async HTTP test client wired to the FastAPI app with the test DB session."""
|
||||
from app.dependencies import get_db
|
||||
from app.main import app
|
||||
|
||||
async def _override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock MarketDataProvider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockMarketDataProvider:
|
||||
"""Configurable mock that satisfies the MarketDataProvider protocol."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ohlcv_data: list[OHLCVData] | None = None,
|
||||
error: Exception | None = None,
|
||||
) -> None:
|
||||
self.ohlcv_data = ohlcv_data or []
|
||||
self.error = error
|
||||
self.calls: list[dict[str, Any]] = []
|
||||
|
||||
async def fetch_ohlcv(
|
||||
self, ticker: str, start_date: date, end_date: date
|
||||
) -> list[OHLCVData]:
|
||||
self.calls.append(
|
||||
{"ticker": ticker, "start_date": start_date, "end_date": end_date}
|
||||
)
|
||||
if self.error is not None:
|
||||
raise self.error
|
||||
return [r for r in self.ohlcv_data if r.ticker == ticker]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider() -> MockMarketDataProvider:
|
||||
"""Return a fresh MockMarketDataProvider instance."""
|
||||
return MockMarketDataProvider()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hypothesis custom strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_TICKER_ALPHABET = string.ascii_uppercase + string.digits
|
||||
|
||||
|
||||
@st.composite
|
||||
def valid_ticker_symbols(draw: st.DrawFn) -> str:
|
||||
"""Generate uppercase alphanumeric ticker symbols (1-10 chars)."""
|
||||
return draw(
|
||||
st.text(alphabet=_TICKER_ALPHABET, min_size=1, max_size=10)
|
||||
)
|
||||
|
||||
|
||||
@st.composite
|
||||
def whitespace_strings(draw: st.DrawFn) -> str:
|
||||
"""Generate strings composed entirely of whitespace (including empty)."""
|
||||
return draw(
|
||||
st.text(alphabet=" \t\n\r\x0b\x0c", min_size=0, max_size=20)
|
||||
)
|
||||
|
||||
|
||||
@st.composite
|
||||
def valid_ohlcv_records(draw: st.DrawFn) -> OHLCVData:
|
||||
"""Generate valid OHLCV records (high >= low, prices >= 0, volume >= 0, date <= today)."""
|
||||
ticker = draw(valid_ticker_symbols())
|
||||
low = draw(st.floats(min_value=0.01, max_value=10000.0, allow_nan=False, allow_infinity=False))
|
||||
high = draw(st.floats(min_value=low, max_value=10000.0, allow_nan=False, allow_infinity=False))
|
||||
open_ = draw(st.floats(min_value=low, max_value=high, allow_nan=False, allow_infinity=False))
|
||||
close = draw(st.floats(min_value=low, max_value=high, allow_nan=False, allow_infinity=False))
|
||||
volume = draw(st.integers(min_value=0, max_value=10**12))
|
||||
record_date = draw(
|
||||
st.dates(min_value=date(2000, 1, 1), max_value=date.today())
|
||||
)
|
||||
return OHLCVData(
|
||||
ticker=ticker,
|
||||
date=record_date,
|
||||
open=open_,
|
||||
high=high,
|
||||
low=low,
|
||||
close=close,
|
||||
volume=volume,
|
||||
)
|
||||
|
||||
|
||||
@st.composite
|
||||
def invalid_ohlcv_records(draw: st.DrawFn) -> OHLCVData:
|
||||
"""Generate OHLCV records that violate at least one constraint."""
|
||||
ticker = draw(valid_ticker_symbols())
|
||||
violation = draw(st.sampled_from(["high_lt_low", "negative_price", "negative_volume", "future_date"]))
|
||||
|
||||
if violation == "high_lt_low":
|
||||
high = draw(st.floats(min_value=0.01, max_value=100.0, allow_nan=False, allow_infinity=False))
|
||||
low = draw(st.floats(min_value=high + 0.01, max_value=200.0, allow_nan=False, allow_infinity=False))
|
||||
return OHLCVData(
|
||||
ticker=ticker, date=date.today(),
|
||||
open=high, high=high, low=low, close=high, volume=100,
|
||||
)
|
||||
elif violation == "negative_price":
|
||||
neg = draw(st.floats(min_value=-10000.0, max_value=-0.01, allow_nan=False, allow_infinity=False))
|
||||
return OHLCVData(
|
||||
ticker=ticker, date=date.today(),
|
||||
open=neg, high=abs(neg), low=abs(neg), close=abs(neg), volume=100,
|
||||
)
|
||||
elif violation == "negative_volume":
|
||||
price = draw(st.floats(min_value=0.01, max_value=100.0, allow_nan=False, allow_infinity=False))
|
||||
neg_vol = draw(st.integers(min_value=-10**9, max_value=-1))
|
||||
return OHLCVData(
|
||||
ticker=ticker, date=date.today(),
|
||||
open=price, high=price, low=price, close=price, volume=neg_vol,
|
||||
)
|
||||
else: # future_date
|
||||
future = date.today() + timedelta(days=draw(st.integers(min_value=1, max_value=365)))
|
||||
price = draw(st.floats(min_value=0.01, max_value=100.0, allow_nan=False, allow_infinity=False))
|
||||
return OHLCVData(
|
||||
ticker=ticker, date=future,
|
||||
open=price, high=price, low=price, close=price, volume=100,
|
||||
)
|
||||
|
||||
|
||||
_DIMENSIONS = ["technical", "sr_quality", "sentiment", "fundamental", "momentum"]
|
||||
|
||||
|
||||
@st.composite
|
||||
def dimension_scores(draw: st.DrawFn) -> float:
|
||||
"""Generate float values in [0, 100] for dimension scores."""
|
||||
return draw(st.floats(min_value=0.0, max_value=100.0, allow_nan=False, allow_infinity=False))
|
||||
|
||||
|
||||
@st.composite
|
||||
def weight_configs(draw: st.DrawFn) -> dict[str, float]:
|
||||
"""Generate dicts of dimension → positive float weight."""
|
||||
dims = draw(st.lists(st.sampled_from(_DIMENSIONS), min_size=1, max_size=5, unique=True))
|
||||
weights: dict[str, float] = {}
|
||||
for dim in dims:
|
||||
weights[dim] = draw(st.floats(min_value=0.01, max_value=10.0, allow_nan=False, allow_infinity=False))
|
||||
return weights
|
||||
|
||||
|
||||
@st.composite
|
||||
def sr_levels(draw: st.DrawFn) -> dict[str, Any]:
|
||||
"""Generate SR level data (price, type, strength, detection_method)."""
|
||||
return {
|
||||
"price_level": draw(st.floats(min_value=0.01, max_value=10000.0, allow_nan=False, allow_infinity=False)),
|
||||
"type": draw(st.sampled_from(["support", "resistance"])),
|
||||
"strength": draw(st.integers(min_value=0, max_value=100)),
|
||||
"detection_method": draw(st.sampled_from(["volume_profile", "pivot_point", "merged"])),
|
||||
}
|
||||
|
||||
|
||||
@st.composite
|
||||
def sentiment_scores(draw: st.DrawFn) -> dict[str, Any]:
|
||||
"""Generate sentiment data (classification, confidence, source, timestamp)."""
|
||||
naive_dt = draw(
|
||||
st.datetimes(
|
||||
min_value=datetime(2020, 1, 1),
|
||||
max_value=datetime.now(),
|
||||
)
|
||||
)
|
||||
return {
|
||||
"classification": draw(st.sampled_from(["bullish", "bearish", "neutral"])),
|
||||
"confidence": draw(st.integers(min_value=0, max_value=100)),
|
||||
"source": draw(st.text(alphabet=string.ascii_lowercase, min_size=3, max_size=20)),
|
||||
"timestamp": naive_dt.replace(tzinfo=timezone.utc),
|
||||
}
|
||||
|
||||
|
||||
@st.composite
|
||||
def trade_setups(draw: st.DrawFn) -> dict[str, Any]:
|
||||
"""Generate trade setup data (direction, entry, stop, target, rr_ratio, composite_score)."""
|
||||
direction = draw(st.sampled_from(["long", "short"]))
|
||||
entry = draw(st.floats(min_value=1.0, max_value=10000.0, allow_nan=False, allow_infinity=False))
|
||||
atr_dist = draw(st.floats(min_value=0.01, max_value=entry * 0.2, allow_nan=False, allow_infinity=False))
|
||||
|
||||
if direction == "long":
|
||||
stop = entry - atr_dist
|
||||
target = entry + atr_dist * draw(st.floats(min_value=3.0, max_value=10.0, allow_nan=False, allow_infinity=False))
|
||||
else:
|
||||
stop = entry + atr_dist
|
||||
target = entry - atr_dist * draw(st.floats(min_value=3.0, max_value=10.0, allow_nan=False, allow_infinity=False))
|
||||
|
||||
rr_ratio = abs(target - entry) / abs(entry - stop) if abs(entry - stop) > 0 else 0.0
|
||||
|
||||
return {
|
||||
"direction": direction,
|
||||
"entry_price": entry,
|
||||
"stop_loss": stop,
|
||||
"target": target,
|
||||
"rr_ratio": rr_ratio,
|
||||
"composite_score": draw(st.floats(min_value=0.0, max_value=100.0, allow_nan=False, allow_infinity=False)),
|
||||
}
|
||||
Reference in New Issue
Block a user