Big refactoring
This commit is contained in:
@@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions import ProviderError
|
||||
from app.providers.openai_sentiment import OpenAISentimentProvider
|
||||
|
||||
|
||||
@@ -160,3 +161,42 @@ class TestCitationsExtraction:
|
||||
|
||||
assert result.citations == []
|
||||
assert result.reasoning == "Quiet day"
|
||||
|
||||
|
||||
class TestBatchSentiment:
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_sentiment_parses_multiple_tickers(self, provider):
|
||||
json_text = (
|
||||
'[{"ticker":"AAPL","classification":"bullish","confidence":81,"reasoning":"Positive earnings"},'
|
||||
'{"ticker":"MSFT","classification":"neutral","confidence":52,"reasoning":"Mixed guidance"}]'
|
||||
)
|
||||
mock_response = _build_response(json_text)
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await provider.fetch_sentiment_batch(["AAPL", "MSFT"])
|
||||
|
||||
assert set(result.keys()) == {"AAPL", "MSFT"}
|
||||
assert result["AAPL"].classification == "bullish"
|
||||
assert result["MSFT"].classification == "neutral"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_sentiment_skips_invalid_rows(self, provider):
|
||||
json_text = (
|
||||
'[{"ticker":"AAPL","classification":"bullish","confidence":81,"reasoning":"Positive earnings"},'
|
||||
'{"ticker":"TSLA","classification":"invalid","confidence":95,"reasoning":"Bad shape"}]'
|
||||
)
|
||||
mock_response = _build_response(json_text)
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await provider.fetch_sentiment_batch(["AAPL", "MSFT"])
|
||||
|
||||
assert set(result.keys()) == {"AAPL"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_sentiment_requires_array_json(self, provider):
|
||||
json_text = '{"ticker":"AAPL","classification":"bullish","confidence":81,"reasoning":"Positive earnings"}'
|
||||
mock_response = _build_response(json_text)
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with pytest.raises(ProviderError):
|
||||
await provider.fetch_sentiment_batch(["AAPL", "MSFT"])
|
||||
|
||||
Reference in New Issue
Block a user