Big refactoring
Some checks failed
Deploy / lint (push) Failing after 21s
Deploy / test (push) Has been skipped
Deploy / deploy (push) Has been skipped

This commit is contained in:
Dennis Thiessen
2026-03-03 15:20:18 +01:00
parent 181cfe6588
commit 0a011d4ce9
55 changed files with 6898 additions and 544 deletions

View File

@@ -0,0 +1,72 @@
"""Unit tests for chained fundamentals provider fallback behavior."""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from app.exceptions import ProviderError
from app.providers.fundamentals_chain import ChainedFundamentalProvider
from app.providers.protocol import FundamentalData
class _FailProvider:
    """Stub provider whose fetch always fails with a ProviderError."""

    def __init__(self, message: str) -> None:
        # Base text for the error raised on every fetch.
        self._message = message

    async def fetch_fundamentals(self, ticker: str) -> FundamentalData:
        # Always fail; the requested ticker is embedded in the error text.
        raise ProviderError(f"{self._message} ({ticker})")
class _DataProvider:
    """Stub provider that serves a fresh copy of canned fundamental data."""

    def __init__(self, data: FundamentalData) -> None:
        # Template record whose fields are copied into every response.
        self._data = data

    async def fetch_fundamentals(self, ticker: str) -> FundamentalData:
        template = self._data
        # Build a new instance per call; only the ticker reflects the request.
        return FundamentalData(
            ticker=ticker,
            pe_ratio=template.pe_ratio,
            revenue_growth=template.revenue_growth,
            earnings_surprise=template.earnings_surprise,
            market_cap=template.market_cap,
            fetched_at=template.fetched_at,
            unavailable_fields=template.unavailable_fields,
        )
@pytest.mark.asyncio
async def test_chained_provider_uses_fallback_provider_on_primary_failure():
    """When the primary provider raises, data is served from the fallback."""
    canned = FundamentalData(
        ticker="AAPL",
        pe_ratio=25.0,
        revenue_growth=None,
        earnings_surprise=None,
        market_cap=1_000_000.0,
        fetched_at=datetime.now(timezone.utc),
        unavailable_fields={},
    )
    chain = ChainedFundamentalProvider(
        [
            ("primary", _FailProvider("primary down")),
            ("fallback", _DataProvider(canned)),
        ]
    )

    fetched = await chain.fetch_fundamentals("AAPL")

    assert fetched.pe_ratio == 25.0
    assert fetched.market_cap == 1_000_000.0
    # The chain records which named provider actually served the data.
    assert fetched.unavailable_fields.get("provider") == "fallback"
@pytest.mark.asyncio
async def test_chained_provider_raises_when_all_providers_fail():
    """If every provider in the chain errors, a summary ProviderError is raised."""
    chain = ChainedFundamentalProvider(
        [
            ("p1", _FailProvider("p1 failed")),
            ("p2", _FailProvider("p2 failed")),
        ]
    )

    with pytest.raises(ProviderError) as exc_info:
        await chain.fetch_fundamentals("MSFT")

    assert "All fundamentals providers failed" in str(exc_info.value)

View File

@@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.exceptions import ProviderError
from app.providers.openai_sentiment import OpenAISentimentProvider
@@ -160,3 +161,42 @@ class TestCitationsExtraction:
assert result.citations == []
assert result.reasoning == "Quiet day"
class TestBatchSentiment:
    """Parsing behavior of OpenAISentimentProvider.fetch_sentiment_batch."""

    @pytest.mark.asyncio
    async def test_batch_sentiment_parses_multiple_tickers(self, provider):
        """A well-formed JSON array yields one entry per ticker."""
        payload = (
            '[{"ticker":"AAPL","classification":"bullish","confidence":81,"reasoning":"Positive earnings"},'
            '{"ticker":"MSFT","classification":"neutral","confidence":52,"reasoning":"Mixed guidance"}]'
        )
        provider._client.responses.create = AsyncMock(
            return_value=_build_response(payload)
        )

        parsed = await provider.fetch_sentiment_batch(["AAPL", "MSFT"])

        assert set(parsed.keys()) == {"AAPL", "MSFT"}
        assert parsed["AAPL"].classification == "bullish"
        assert parsed["MSFT"].classification == "neutral"

    @pytest.mark.asyncio
    async def test_batch_sentiment_skips_invalid_rows(self, provider):
        """Rows with an unrecognized classification are dropped, not fatal."""
        payload = (
            '[{"ticker":"AAPL","classification":"bullish","confidence":81,"reasoning":"Positive earnings"},'
            '{"ticker":"TSLA","classification":"invalid","confidence":95,"reasoning":"Bad shape"}]'
        )
        provider._client.responses.create = AsyncMock(
            return_value=_build_response(payload)
        )

        parsed = await provider.fetch_sentiment_batch(["AAPL", "MSFT"])

        # Only the valid AAPL row survives filtering.
        assert set(parsed.keys()) == {"AAPL"}

    @pytest.mark.asyncio
    async def test_batch_sentiment_requires_array_json(self, provider):
        """A top-level JSON object (not an array) is rejected with ProviderError."""
        payload = '{"ticker":"AAPL","classification":"bullish","confidence":81,"reasoning":"Positive earnings"}'
        provider._client.responses.create = AsyncMock(
            return_value=_build_response(payload)
        )

        with pytest.raises(ProviderError):
            await provider.fetch_sentiment_batch(["AAPL", "MSFT"])

View File

@@ -0,0 +1,130 @@
from __future__ import annotations
from dataclasses import dataclass
from app.services.recommendation_service import (
direction_analyzer,
probability_estimator,
signal_conflict_detector,
target_generator,
)
@dataclass
class _SRLevelStub:
id: int
price_level: float
type: str
strength: int
def test_high_confidence_long_example():
    """Strong scores plus bullish sentiment push a long's confidence above 70."""
    scores = {"technical": 75.0, "momentum": 68.0, "fundamental": 55.0}

    result = direction_analyzer.calculate_confidence(
        direction="long",
        dimension_scores=scores,
        sentiment_classification="bullish",
        conflicts=[],
    )

    assert result > 70.0
def test_high_confidence_short_example():
    """Weak scores plus bearish sentiment push a short's confidence above 70."""
    scores = {"technical": 30.0, "momentum": 35.0, "fundamental": 45.0}

    result = direction_analyzer.calculate_confidence(
        direction="short",
        dimension_scores=scores,
        sentiment_classification="bearish",
        conflicts=[],
    )

    assert result > 70.0
def test_detects_sentiment_technical_conflict():
    """Bearish sentiment against a strong technical score is flagged as a conflict."""
    detected = signal_conflict_detector.detect_conflicts(
        dimension_scores={"technical": 72.0, "momentum": 55.0, "fundamental": 50.0},
        sentiment_classification="bearish",
    )

    assert any("sentiment-technical" in item for item in detected)
def test_generate_targets_respects_direction_and_order():
    """Long targets must all sit above entry, ordered by distance from it."""
    levels = [
        _SRLevelStub(id=1, price_level=110.0, type="resistance", strength=80),
        _SRLevelStub(id=2, price_level=115.0, type="resistance", strength=70),
        _SRLevelStub(id=3, price_level=120.0, type="resistance", strength=60),
        _SRLevelStub(id=4, price_level=95.0, type="support", strength=75),
    ]
    entry = 100.0

    generated = target_generator.generate_targets(
        direction="long",
        entry_price=entry,
        stop_loss=96.0,
        sr_levels=levels,  # type: ignore[arg-type]
        atr_value=2.0,
    )

    assert generated  # at least one target produced
    assert all(target["price"] > entry for target in generated)
    offsets = [target["distance_from_entry"] for target in generated]
    assert offsets == sorted(offsets)
def test_probability_ranges_by_classification():
    """Estimated probability drops as the target classification gets riskier."""
    cfg = {
        "recommendation_signal_alignment_weight": 0.15,
        "recommendation_sr_strength_weight": 0.20,
        "recommendation_distance_penalty_factor": 0.10,
    }
    scores = {"technical": 70.0, "momentum": 70.0}

    def _estimate(classification, sr_strength, distance):
        # One estimate per classification, keeping sentiment/direction fixed.
        return probability_estimator.estimate_probability(
            {
                "classification": classification,
                "sr_strength": sr_strength,
                "distance_atr_multiple": distance,
            },
            scores,
            "bullish",
            "long",
            cfg,
        )

    conservative = _estimate("Conservative", 80, 1.5)
    moderate = _estimate("Moderate", 60, 3.0)
    aggressive = _estimate("Aggressive", 40, 6.0)

    assert conservative > 60
    assert 40 <= moderate <= 70
    assert aggressive < 50

View File

@@ -228,23 +228,23 @@ async def test_scan_ticker_full_flow_quality_selection_and_persistence(
)
# -- Assert: database persistence --
# Old dummy setup should be gone, only the 2 new setups should exist
# History is preserved: old setup remains, 2 new setups are appended
db_result = await scan_session.execute(
select(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
)
persisted = list(db_result.scalars().all())
assert len(persisted) == 2, (
f"Expected 2 persisted setups (old deleted), got {len(persisted)}"
assert len(persisted) == 3, (
f"Expected 3 persisted setups (1 old + 2 new), got {len(persisted)}"
)
persisted_directions = sorted(s.direction for s in persisted)
assert persisted_directions == ["long", "short"], (
f"Expected ['long', 'short'] persisted, got {persisted_directions}"
assert persisted_directions == ["long", "long", "short"], (
f"Expected ['long', 'long', 'short'] persisted, got {persisted_directions}"
)
# Verify persisted records match returned setups
persisted_long = [s for s in persisted if s.direction == "long"][0]
persisted_short = [s for s in persisted if s.direction == "short"][0]
# Verify latest persisted records match returned setups
persisted_long = max((s for s in persisted if s.direction == "long"), key=lambda s: s.id)
persisted_short = max((s for s in persisted if s.direction == "short"), key=lambda s: s.id)
assert persisted_long.target == long_setup.target
assert persisted_long.rr_ratio == long_setup.rr_ratio

View File

@@ -68,7 +68,7 @@ class TestResumeTickers:
class TestConfigureScheduler:
def test_configure_adds_four_jobs(self):
def test_configure_adds_five_jobs(self):
# Remove any existing jobs first
scheduler.remove_all_jobs()
configure_scheduler()
@@ -79,6 +79,7 @@ class TestConfigureScheduler:
"sentiment_collector",
"fundamental_collector",
"rr_scanner",
"ticker_universe_sync",
}
def test_configure_is_idempotent(self):
@@ -92,4 +93,5 @@ class TestConfigureScheduler:
"fundamental_collector",
"rr_scanner",
"sentiment_collector",
"ticker_universe_sync",
])

View File

@@ -0,0 +1,123 @@
"""Unit tests for ticker_universe_service bootstrap logic."""
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.database import Base
from app.exceptions import ProviderError
from app.models.settings import SystemSetting
from app.models.ticker import Ticker
from app.services import ticker_universe_service
# In-memory SQLite engine shared by every test in this module; tables are
# created and dropped around each test by the autouse fixture in this file.
_engine = create_async_engine("sqlite+aiosqlite://", echo=False)
# Session factory bound to the in-memory engine; expire_on_commit=False keeps
# ORM attributes readable after a commit without a refresh round-trip.
_session_factory = async_sessionmaker(_engine, class_=AsyncSession, expire_on_commit=False)
@pytest.fixture(autouse=True)
async def _setup_tables() -> AsyncGenerator[None, None]:
    """Create all ORM tables before each test and drop them afterwards."""
    async with _engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    yield
    async with _engine.begin() as connection:
        await connection.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a fresh AsyncSession bound to the module's in-memory engine."""
    async with _session_factory() as db_session:
        yield db_session
@pytest.mark.asyncio
async def test_bootstrap_universe_adds_missing_symbols(session: AsyncSession, monkeypatch: pytest.MonkeyPatch):
    """Symbols from the universe fetch that are not yet tracked get inserted."""
    session.add(Ticker(symbol="AAPL"))
    await session.commit()

    async def _stub_fetch(_db: AsyncSession, _universe: str) -> list[str]:
        return ["AAPL", "MSFT", "NVDA"]

    monkeypatch.setattr(ticker_universe_service, "fetch_universe_symbols", _stub_fetch)

    summary = await ticker_universe_service.bootstrap_universe(session, "sp500")

    assert summary["added"] == 2
    assert summary["already_tracked"] == 1
    assert summary["deleted"] == 0
    rows = await session.execute(select(Ticker.symbol).order_by(Ticker.symbol.asc()))
    assert list(rows.scalars().all()) == ["AAPL", "MSFT", "NVDA"]
@pytest.mark.asyncio
async def test_bootstrap_universe_prunes_missing_symbols(session: AsyncSession, monkeypatch: pytest.MonkeyPatch):
    """With prune_missing=True, tracked symbols absent from the fetch are deleted."""
    session.add_all([Ticker(symbol="AAPL"), Ticker(symbol="MSFT"), Ticker(symbol="TSLA")])
    await session.commit()

    async def _stub_fetch(_db: AsyncSession, _universe: str) -> list[str]:
        return ["AAPL", "MSFT"]

    monkeypatch.setattr(ticker_universe_service, "fetch_universe_symbols", _stub_fetch)

    summary = await ticker_universe_service.bootstrap_universe(
        session,
        "sp500",
        prune_missing=True,
    )

    assert summary["added"] == 0
    assert summary["already_tracked"] == 2
    assert summary["deleted"] == 1
    rows = await session.execute(select(Ticker.symbol).order_by(Ticker.symbol.asc()))
    assert list(rows.scalars().all()) == ["AAPL", "MSFT"]
@pytest.mark.asyncio
async def test_fetch_universe_symbols_uses_cached_snapshot_when_live_sources_fail(
    session: AsyncSession,
    monkeypatch: pytest.MonkeyPatch,
):
    """When both live sources fail, the persisted cache snapshot is served."""
    cached = SystemSetting(
        key="ticker_universe_cache_sp500",
        value=json.dumps({"symbols": ["AAPL", "MSFT"], "source": "test"}),
    )
    session.add(cached)
    await session.commit()

    async def _broken_public(_universe: str):
        # Mimics the public source returning no symbols plus an error list.
        return [], ["public failed"], None

    async def _broken_fmp(_universe: str):
        raise ProviderError("fmp failed")

    monkeypatch.setattr(ticker_universe_service, "_fetch_universe_symbols_from_public", _broken_public)
    monkeypatch.setattr(ticker_universe_service, "_fetch_universe_symbols_from_fmp", _broken_fmp)

    result = await ticker_universe_service.fetch_universe_symbols(session, "sp500")

    assert result == ["AAPL", "MSFT"]
@pytest.mark.asyncio
async def test_fetch_universe_symbols_uses_seed_when_live_and_cache_fail(
    session: AsyncSession,
    monkeypatch: pytest.MonkeyPatch,
):
    """With no live data and no cache row, the built-in seed list is returned."""
    async def _broken_public(_universe: str):
        return [], ["public failed"], None

    async def _broken_fmp(_universe: str):
        raise ProviderError("fmp failed")

    monkeypatch.setattr(ticker_universe_service, "_fetch_universe_symbols_from_public", _broken_public)
    monkeypatch.setattr(ticker_universe_service, "_fetch_universe_symbols_from_fmp", _broken_fmp)

    result = await ticker_universe_service.fetch_universe_symbols(session, "sp500")

    # Seed list contents aren't pinned here, only that it is non-trivial.
    assert "AAPL" in result
    assert len(result) > 10