"""Unit tests for chained fundamentals provider fallback behavior.""" from __future__ import annotations from datetime import datetime, timezone import pytest from app.exceptions import ProviderError from app.providers.fundamentals_chain import ChainedFundamentalProvider from app.providers.protocol import FundamentalData class _FailProvider: def __init__(self, message: str) -> None: self._message = message async def fetch_fundamentals(self, ticker: str) -> FundamentalData: raise ProviderError(f"{self._message} ({ticker})") class _DataProvider: def __init__(self, data: FundamentalData) -> None: self._data = data async def fetch_fundamentals(self, ticker: str) -> FundamentalData: return FundamentalData( ticker=ticker, pe_ratio=self._data.pe_ratio, revenue_growth=self._data.revenue_growth, earnings_surprise=self._data.earnings_surprise, market_cap=self._data.market_cap, fetched_at=self._data.fetched_at, unavailable_fields=self._data.unavailable_fields, ) @pytest.mark.asyncio async def test_chained_provider_uses_fallback_provider_on_primary_failure(): fallback_data = FundamentalData( ticker="AAPL", pe_ratio=25.0, revenue_growth=None, earnings_surprise=None, market_cap=1_000_000.0, fetched_at=datetime.now(timezone.utc), unavailable_fields={}, ) provider = ChainedFundamentalProvider([ ("primary", _FailProvider("primary down")), ("fallback", _DataProvider(fallback_data)), ]) result = await provider.fetch_fundamentals("AAPL") assert result.pe_ratio == 25.0 assert result.market_cap == 1_000_000.0 assert result.unavailable_fields.get("provider") == "fallback" @pytest.mark.asyncio async def test_chained_provider_raises_when_all_providers_fail(): provider = ChainedFundamentalProvider([ ("p1", _FailProvider("p1 failed")), ("p2", _FailProvider("p2 failed")), ]) with pytest.raises(ProviderError) as exc: await provider.fetch_fundamentals("MSFT") assert "All fundamentals providers failed" in str(exc.value)