Big refactoring
This commit is contained in:
104
tests/property/test_recommendation_properties.py
Normal file
104
tests/property/test_recommendation_properties.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from hypothesis import given, settings, strategies as st
|
||||
|
||||
from app.services.recommendation_service import direction_analyzer, probability_estimator
|
||||
|
||||
|
||||
@settings(max_examples=100, deadline=None)
|
||||
@given(
|
||||
technical=st.floats(min_value=0, max_value=100),
|
||||
momentum=st.floats(min_value=0, max_value=100),
|
||||
fundamental=st.floats(min_value=0, max_value=100),
|
||||
sentiment=st.sampled_from(["bearish", "neutral", "bullish", None]),
|
||||
)
|
||||
def test_property_confidence_bounds(technical, momentum, fundamental, sentiment):
|
||||
"""Feature: intelligent-trade-recommendations, Property 3: Confidence Score Bounds."""
|
||||
scores = {
|
||||
"technical": technical,
|
||||
"momentum": momentum,
|
||||
"fundamental": fundamental,
|
||||
}
|
||||
|
||||
long_conf = direction_analyzer.calculate_confidence("long", scores, sentiment, conflicts=[])
|
||||
short_conf = direction_analyzer.calculate_confidence("short", scores, sentiment, conflicts=[])
|
||||
|
||||
assert 0 <= long_conf <= 100
|
||||
assert 0 <= short_conf <= 100
|
||||
|
||||
|
||||
@settings(max_examples=100, deadline=None)
|
||||
@given(
|
||||
strength_low=st.floats(min_value=0, max_value=50),
|
||||
strength_high=st.floats(min_value=50, max_value=100),
|
||||
)
|
||||
def test_property_strength_monotonic_probability(strength_low, strength_high):
|
||||
"""Feature: intelligent-trade-recommendations, Property 11: S/R Strength Monotonicity."""
|
||||
config = {
|
||||
"recommendation_signal_alignment_weight": 0.15,
|
||||
"recommendation_sr_strength_weight": 0.20,
|
||||
"recommendation_distance_penalty_factor": 0.10,
|
||||
}
|
||||
scores = {"technical": 65.0, "momentum": 65.0}
|
||||
|
||||
base_target = {
|
||||
"classification": "Moderate",
|
||||
"distance_atr_multiple": 3.0,
|
||||
}
|
||||
|
||||
low = probability_estimator.estimate_probability(
|
||||
{**base_target, "sr_strength": strength_low},
|
||||
scores,
|
||||
"bullish",
|
||||
"long",
|
||||
config,
|
||||
)
|
||||
high = probability_estimator.estimate_probability(
|
||||
{**base_target, "sr_strength": strength_high},
|
||||
scores,
|
||||
"bullish",
|
||||
"long",
|
||||
config,
|
||||
)
|
||||
|
||||
assert high >= low
|
||||
|
||||
|
||||
@settings(max_examples=100, deadline=None)
|
||||
@given(
|
||||
near_distance=st.floats(min_value=1.0, max_value=3.0),
|
||||
far_distance=st.floats(min_value=3.1, max_value=8.0),
|
||||
)
|
||||
def test_property_distance_probability_relationship(near_distance, far_distance):
|
||||
"""Feature: intelligent-trade-recommendations, Property 12: Distance Probability Relationship."""
|
||||
config = {
|
||||
"recommendation_signal_alignment_weight": 0.15,
|
||||
"recommendation_sr_strength_weight": 0.20,
|
||||
"recommendation_distance_penalty_factor": 0.10,
|
||||
}
|
||||
scores = {"technical": 65.0, "momentum": 65.0}
|
||||
|
||||
near_prob = probability_estimator.estimate_probability(
|
||||
{
|
||||
"classification": "Conservative",
|
||||
"sr_strength": 60,
|
||||
"distance_atr_multiple": near_distance,
|
||||
},
|
||||
scores,
|
||||
"bullish",
|
||||
"long",
|
||||
config,
|
||||
)
|
||||
far_prob = probability_estimator.estimate_probability(
|
||||
{
|
||||
"classification": "Aggressive",
|
||||
"sr_strength": 60,
|
||||
"distance_atr_multiple": far_distance,
|
||||
},
|
||||
scores,
|
||||
"bullish",
|
||||
"long",
|
||||
config,
|
||||
)
|
||||
|
||||
assert near_prob >= far_prob
|
||||
72
tests/unit/test_fundamentals_chain_provider.py
Normal file
72
tests/unit/test_fundamentals_chain_provider.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Unit tests for chained fundamentals provider fallback behavior."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions import ProviderError
|
||||
from app.providers.fundamentals_chain import ChainedFundamentalProvider
|
||||
from app.providers.protocol import FundamentalData
|
||||
|
||||
|
||||
class _FailProvider:
|
||||
def __init__(self, message: str) -> None:
|
||||
self._message = message
|
||||
|
||||
async def fetch_fundamentals(self, ticker: str) -> FundamentalData:
|
||||
raise ProviderError(f"{self._message} ({ticker})")
|
||||
|
||||
|
||||
class _DataProvider:
|
||||
def __init__(self, data: FundamentalData) -> None:
|
||||
self._data = data
|
||||
|
||||
async def fetch_fundamentals(self, ticker: str) -> FundamentalData:
|
||||
return FundamentalData(
|
||||
ticker=ticker,
|
||||
pe_ratio=self._data.pe_ratio,
|
||||
revenue_growth=self._data.revenue_growth,
|
||||
earnings_surprise=self._data.earnings_surprise,
|
||||
market_cap=self._data.market_cap,
|
||||
fetched_at=self._data.fetched_at,
|
||||
unavailable_fields=self._data.unavailable_fields,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chained_provider_uses_fallback_provider_on_primary_failure():
|
||||
fallback_data = FundamentalData(
|
||||
ticker="AAPL",
|
||||
pe_ratio=25.0,
|
||||
revenue_growth=None,
|
||||
earnings_surprise=None,
|
||||
market_cap=1_000_000.0,
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
unavailable_fields={},
|
||||
)
|
||||
|
||||
provider = ChainedFundamentalProvider([
|
||||
("primary", _FailProvider("primary down")),
|
||||
("fallback", _DataProvider(fallback_data)),
|
||||
])
|
||||
|
||||
result = await provider.fetch_fundamentals("AAPL")
|
||||
|
||||
assert result.pe_ratio == 25.0
|
||||
assert result.market_cap == 1_000_000.0
|
||||
assert result.unavailable_fields.get("provider") == "fallback"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chained_provider_raises_when_all_providers_fail():
|
||||
provider = ChainedFundamentalProvider([
|
||||
("p1", _FailProvider("p1 failed")),
|
||||
("p2", _FailProvider("p2 failed")),
|
||||
])
|
||||
|
||||
with pytest.raises(ProviderError) as exc:
|
||||
await provider.fetch_fundamentals("MSFT")
|
||||
|
||||
assert "All fundamentals providers failed" in str(exc.value)
|
||||
@@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions import ProviderError
|
||||
from app.providers.openai_sentiment import OpenAISentimentProvider
|
||||
|
||||
|
||||
@@ -160,3 +161,42 @@ class TestCitationsExtraction:
|
||||
|
||||
assert result.citations == []
|
||||
assert result.reasoning == "Quiet day"
|
||||
|
||||
|
||||
class TestBatchSentiment:
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_sentiment_parses_multiple_tickers(self, provider):
|
||||
json_text = (
|
||||
'[{"ticker":"AAPL","classification":"bullish","confidence":81,"reasoning":"Positive earnings"},'
|
||||
'{"ticker":"MSFT","classification":"neutral","confidence":52,"reasoning":"Mixed guidance"}]'
|
||||
)
|
||||
mock_response = _build_response(json_text)
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await provider.fetch_sentiment_batch(["AAPL", "MSFT"])
|
||||
|
||||
assert set(result.keys()) == {"AAPL", "MSFT"}
|
||||
assert result["AAPL"].classification == "bullish"
|
||||
assert result["MSFT"].classification == "neutral"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_sentiment_skips_invalid_rows(self, provider):
|
||||
json_text = (
|
||||
'[{"ticker":"AAPL","classification":"bullish","confidence":81,"reasoning":"Positive earnings"},'
|
||||
'{"ticker":"TSLA","classification":"invalid","confidence":95,"reasoning":"Bad shape"}]'
|
||||
)
|
||||
mock_response = _build_response(json_text)
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await provider.fetch_sentiment_batch(["AAPL", "MSFT"])
|
||||
|
||||
assert set(result.keys()) == {"AAPL"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_sentiment_requires_array_json(self, provider):
|
||||
json_text = '{"ticker":"AAPL","classification":"bullish","confidence":81,"reasoning":"Positive earnings"}'
|
||||
mock_response = _build_response(json_text)
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with pytest.raises(ProviderError):
|
||||
await provider.fetch_sentiment_batch(["AAPL", "MSFT"])
|
||||
|
||||
130
tests/unit/test_recommendation_service.py
Normal file
130
tests/unit/test_recommendation_service.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.services.recommendation_service import (
|
||||
direction_analyzer,
|
||||
probability_estimator,
|
||||
signal_conflict_detector,
|
||||
target_generator,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SRLevelStub:
|
||||
id: int
|
||||
price_level: float
|
||||
type: str
|
||||
strength: int
|
||||
|
||||
|
||||
def test_high_confidence_long_example():
|
||||
dimension_scores = {
|
||||
"technical": 75.0,
|
||||
"momentum": 68.0,
|
||||
"fundamental": 55.0,
|
||||
}
|
||||
|
||||
confidence = direction_analyzer.calculate_confidence(
|
||||
direction="long",
|
||||
dimension_scores=dimension_scores,
|
||||
sentiment_classification="bullish",
|
||||
conflicts=[],
|
||||
)
|
||||
|
||||
assert confidence > 70.0
|
||||
|
||||
|
||||
def test_high_confidence_short_example():
|
||||
dimension_scores = {
|
||||
"technical": 30.0,
|
||||
"momentum": 35.0,
|
||||
"fundamental": 45.0,
|
||||
}
|
||||
|
||||
confidence = direction_analyzer.calculate_confidence(
|
||||
direction="short",
|
||||
dimension_scores=dimension_scores,
|
||||
sentiment_classification="bearish",
|
||||
conflicts=[],
|
||||
)
|
||||
|
||||
assert confidence > 70.0
|
||||
|
||||
|
||||
def test_detects_sentiment_technical_conflict():
|
||||
conflicts = signal_conflict_detector.detect_conflicts(
|
||||
dimension_scores={"technical": 72.0, "momentum": 55.0, "fundamental": 50.0},
|
||||
sentiment_classification="bearish",
|
||||
)
|
||||
|
||||
assert any("sentiment-technical" in conflict for conflict in conflicts)
|
||||
|
||||
|
||||
def test_generate_targets_respects_direction_and_order():
|
||||
sr_levels = [
|
||||
_SRLevelStub(id=1, price_level=110.0, type="resistance", strength=80),
|
||||
_SRLevelStub(id=2, price_level=115.0, type="resistance", strength=70),
|
||||
_SRLevelStub(id=3, price_level=120.0, type="resistance", strength=60),
|
||||
_SRLevelStub(id=4, price_level=95.0, type="support", strength=75),
|
||||
]
|
||||
|
||||
targets = target_generator.generate_targets(
|
||||
direction="long",
|
||||
entry_price=100.0,
|
||||
stop_loss=96.0,
|
||||
sr_levels=sr_levels, # type: ignore[arg-type]
|
||||
atr_value=2.0,
|
||||
)
|
||||
|
||||
assert len(targets) >= 1
|
||||
assert all(target["price"] > 100.0 for target in targets)
|
||||
distances = [target["distance_from_entry"] for target in targets]
|
||||
assert distances == sorted(distances)
|
||||
|
||||
|
||||
def test_probability_ranges_by_classification():
|
||||
config = {
|
||||
"recommendation_signal_alignment_weight": 0.15,
|
||||
"recommendation_sr_strength_weight": 0.20,
|
||||
"recommendation_distance_penalty_factor": 0.10,
|
||||
}
|
||||
dimension_scores = {"technical": 70.0, "momentum": 70.0}
|
||||
|
||||
conservative = probability_estimator.estimate_probability(
|
||||
{
|
||||
"classification": "Conservative",
|
||||
"sr_strength": 80,
|
||||
"distance_atr_multiple": 1.5,
|
||||
},
|
||||
dimension_scores,
|
||||
"bullish",
|
||||
"long",
|
||||
config,
|
||||
)
|
||||
moderate = probability_estimator.estimate_probability(
|
||||
{
|
||||
"classification": "Moderate",
|
||||
"sr_strength": 60,
|
||||
"distance_atr_multiple": 3.0,
|
||||
},
|
||||
dimension_scores,
|
||||
"bullish",
|
||||
"long",
|
||||
config,
|
||||
)
|
||||
aggressive = probability_estimator.estimate_probability(
|
||||
{
|
||||
"classification": "Aggressive",
|
||||
"sr_strength": 40,
|
||||
"distance_atr_multiple": 6.0,
|
||||
},
|
||||
dimension_scores,
|
||||
"bullish",
|
||||
"long",
|
||||
config,
|
||||
)
|
||||
|
||||
assert conservative > 60
|
||||
assert 40 <= moderate <= 70
|
||||
assert aggressive < 50
|
||||
@@ -228,23 +228,23 @@ async def test_scan_ticker_full_flow_quality_selection_and_persistence(
|
||||
)
|
||||
|
||||
# -- Assert: database persistence --
|
||||
# Old dummy setup should be gone, only the 2 new setups should exist
|
||||
# History is preserved: old setup remains, 2 new setups are appended
|
||||
db_result = await scan_session.execute(
|
||||
select(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
|
||||
)
|
||||
persisted = list(db_result.scalars().all())
|
||||
assert len(persisted) == 2, (
|
||||
f"Expected 2 persisted setups (old deleted), got {len(persisted)}"
|
||||
assert len(persisted) == 3, (
|
||||
f"Expected 3 persisted setups (1 old + 2 new), got {len(persisted)}"
|
||||
)
|
||||
|
||||
persisted_directions = sorted(s.direction for s in persisted)
|
||||
assert persisted_directions == ["long", "short"], (
|
||||
f"Expected ['long', 'short'] persisted, got {persisted_directions}"
|
||||
assert persisted_directions == ["long", "long", "short"], (
|
||||
f"Expected ['long', 'long', 'short'] persisted, got {persisted_directions}"
|
||||
)
|
||||
|
||||
# Verify persisted records match returned setups
|
||||
persisted_long = [s for s in persisted if s.direction == "long"][0]
|
||||
persisted_short = [s for s in persisted if s.direction == "short"][0]
|
||||
# Verify latest persisted records match returned setups
|
||||
persisted_long = max((s for s in persisted if s.direction == "long"), key=lambda s: s.id)
|
||||
persisted_short = max((s for s in persisted if s.direction == "short"), key=lambda s: s.id)
|
||||
|
||||
assert persisted_long.target == long_setup.target
|
||||
assert persisted_long.rr_ratio == long_setup.rr_ratio
|
||||
|
||||
@@ -68,7 +68,7 @@ class TestResumeTickers:
|
||||
|
||||
|
||||
class TestConfigureScheduler:
|
||||
def test_configure_adds_four_jobs(self):
|
||||
def test_configure_adds_five_jobs(self):
|
||||
# Remove any existing jobs first
|
||||
scheduler.remove_all_jobs()
|
||||
configure_scheduler()
|
||||
@@ -79,6 +79,7 @@ class TestConfigureScheduler:
|
||||
"sentiment_collector",
|
||||
"fundamental_collector",
|
||||
"rr_scanner",
|
||||
"ticker_universe_sync",
|
||||
}
|
||||
|
||||
def test_configure_is_idempotent(self):
|
||||
@@ -92,4 +93,5 @@ class TestConfigureScheduler:
|
||||
"fundamental_collector",
|
||||
"rr_scanner",
|
||||
"sentiment_collector",
|
||||
"ticker_universe_sync",
|
||||
])
|
||||
|
||||
123
tests/unit/test_ticker_universe_service.py
Normal file
123
tests/unit/test_ticker_universe_service.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Unit tests for ticker_universe_service bootstrap logic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.database import Base
|
||||
from app.exceptions import ProviderError
|
||||
from app.models.settings import SystemSetting
|
||||
from app.models.ticker import Ticker
|
||||
from app.services import ticker_universe_service
|
||||
|
||||
_engine = create_async_engine("sqlite+aiosqlite://", echo=False)
|
||||
_session_factory = async_sessionmaker(_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def _setup_tables() -> AsyncGenerator[None, None]:
|
||||
async with _engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield
|
||||
async with _engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with _session_factory() as s:
|
||||
yield s
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bootstrap_universe_adds_missing_symbols(session: AsyncSession, monkeypatch: pytest.MonkeyPatch):
|
||||
session.add(Ticker(symbol="AAPL"))
|
||||
await session.commit()
|
||||
|
||||
async def _fake_fetch(_db: AsyncSession, _universe: str) -> list[str]:
|
||||
return ["AAPL", "MSFT", "NVDA"]
|
||||
|
||||
monkeypatch.setattr(ticker_universe_service, "fetch_universe_symbols", _fake_fetch)
|
||||
|
||||
result = await ticker_universe_service.bootstrap_universe(session, "sp500")
|
||||
|
||||
assert result["added"] == 2
|
||||
assert result["already_tracked"] == 1
|
||||
assert result["deleted"] == 0
|
||||
|
||||
rows = await session.execute(select(Ticker.symbol).order_by(Ticker.symbol.asc()))
|
||||
assert list(rows.scalars().all()) == ["AAPL", "MSFT", "NVDA"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bootstrap_universe_prunes_missing_symbols(session: AsyncSession, monkeypatch: pytest.MonkeyPatch):
|
||||
session.add_all([Ticker(symbol="AAPL"), Ticker(symbol="MSFT"), Ticker(symbol="TSLA")])
|
||||
await session.commit()
|
||||
|
||||
async def _fake_fetch(_db: AsyncSession, _universe: str) -> list[str]:
|
||||
return ["AAPL", "MSFT"]
|
||||
|
||||
monkeypatch.setattr(ticker_universe_service, "fetch_universe_symbols", _fake_fetch)
|
||||
|
||||
result = await ticker_universe_service.bootstrap_universe(
|
||||
session,
|
||||
"sp500",
|
||||
prune_missing=True,
|
||||
)
|
||||
|
||||
assert result["added"] == 0
|
||||
assert result["already_tracked"] == 2
|
||||
assert result["deleted"] == 1
|
||||
|
||||
rows = await session.execute(select(Ticker.symbol).order_by(Ticker.symbol.asc()))
|
||||
assert list(rows.scalars().all()) == ["AAPL", "MSFT"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_universe_symbols_uses_cached_snapshot_when_live_sources_fail(
|
||||
session: AsyncSession,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
session.add(
|
||||
SystemSetting(
|
||||
key="ticker_universe_cache_sp500",
|
||||
value=json.dumps({"symbols": ["AAPL", "MSFT"], "source": "test"}),
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def _fake_public(_universe: str):
|
||||
return [], ["public failed"], None
|
||||
|
||||
async def _fake_fmp(_universe: str):
|
||||
raise ProviderError("fmp failed")
|
||||
|
||||
monkeypatch.setattr(ticker_universe_service, "_fetch_universe_symbols_from_public", _fake_public)
|
||||
monkeypatch.setattr(ticker_universe_service, "_fetch_universe_symbols_from_fmp", _fake_fmp)
|
||||
|
||||
symbols = await ticker_universe_service.fetch_universe_symbols(session, "sp500")
|
||||
assert symbols == ["AAPL", "MSFT"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_universe_symbols_uses_seed_when_live_and_cache_fail(
|
||||
session: AsyncSession,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
async def _fake_public(_universe: str):
|
||||
return [], ["public failed"], None
|
||||
|
||||
async def _fake_fmp(_universe: str):
|
||||
raise ProviderError("fmp failed")
|
||||
|
||||
monkeypatch.setattr(ticker_universe_service, "_fetch_universe_symbols_from_public", _fake_public)
|
||||
monkeypatch.setattr(ticker_universe_service, "_fetch_universe_symbols_from_fmp", _fake_fmp)
|
||||
|
||||
symbols = await ticker_universe_service.fetch_universe_symbols(session, "sp500")
|
||||
assert "AAPL" in symbols
|
||||
assert len(symbols) > 10
|
||||
Reference in New Issue
Block a user