Big refactoring
This commit is contained in:
72
tests/unit/test_fundamentals_chain_provider.py
Normal file
72
tests/unit/test_fundamentals_chain_provider.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Unit tests for chained fundamentals provider fallback behavior."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions import ProviderError
|
||||
from app.providers.fundamentals_chain import ChainedFundamentalProvider
|
||||
from app.providers.protocol import FundamentalData
|
||||
|
||||
|
||||
class _FailProvider:
|
||||
def __init__(self, message: str) -> None:
|
||||
self._message = message
|
||||
|
||||
async def fetch_fundamentals(self, ticker: str) -> FundamentalData:
|
||||
raise ProviderError(f"{self._message} ({ticker})")
|
||||
|
||||
|
||||
class _DataProvider:
|
||||
def __init__(self, data: FundamentalData) -> None:
|
||||
self._data = data
|
||||
|
||||
async def fetch_fundamentals(self, ticker: str) -> FundamentalData:
|
||||
return FundamentalData(
|
||||
ticker=ticker,
|
||||
pe_ratio=self._data.pe_ratio,
|
||||
revenue_growth=self._data.revenue_growth,
|
||||
earnings_surprise=self._data.earnings_surprise,
|
||||
market_cap=self._data.market_cap,
|
||||
fetched_at=self._data.fetched_at,
|
||||
unavailable_fields=self._data.unavailable_fields,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chained_provider_uses_fallback_provider_on_primary_failure():
|
||||
fallback_data = FundamentalData(
|
||||
ticker="AAPL",
|
||||
pe_ratio=25.0,
|
||||
revenue_growth=None,
|
||||
earnings_surprise=None,
|
||||
market_cap=1_000_000.0,
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
unavailable_fields={},
|
||||
)
|
||||
|
||||
provider = ChainedFundamentalProvider([
|
||||
("primary", _FailProvider("primary down")),
|
||||
("fallback", _DataProvider(fallback_data)),
|
||||
])
|
||||
|
||||
result = await provider.fetch_fundamentals("AAPL")
|
||||
|
||||
assert result.pe_ratio == 25.0
|
||||
assert result.market_cap == 1_000_000.0
|
||||
assert result.unavailable_fields.get("provider") == "fallback"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chained_provider_raises_when_all_providers_fail():
|
||||
provider = ChainedFundamentalProvider([
|
||||
("p1", _FailProvider("p1 failed")),
|
||||
("p2", _FailProvider("p2 failed")),
|
||||
])
|
||||
|
||||
with pytest.raises(ProviderError) as exc:
|
||||
await provider.fetch_fundamentals("MSFT")
|
||||
|
||||
assert "All fundamentals providers failed" in str(exc.value)
|
||||
Reference in New Issue
Block a user