Big refactoring
This commit is contained in:
@@ -33,6 +33,24 @@ Rules:
|
||||
- reasoning should cite specific recent news or events you found
|
||||
"""
|
||||
|
||||
_SENTIMENT_BATCH_PROMPT = """\
|
||||
Search the web for the LATEST news, analyst opinions, and market developments \
|
||||
about each stock ticker from the past 24-48 hours.
|
||||
|
||||
Tickers:
|
||||
{tickers_csv}
|
||||
|
||||
Respond ONLY with a JSON array (no markdown, no extra text), one object per ticker:
|
||||
[{{"ticker":"AAPL","classification":"bullish|bearish|neutral","confidence":0-100,"reasoning":"brief explanation"}}]
|
||||
|
||||
Rules:
|
||||
- Include every ticker exactly once
|
||||
- ticker must be uppercase symbol
|
||||
- classification must be exactly one of: bullish, bearish, neutral
|
||||
- confidence must be an integer from 0 to 100
|
||||
- reasoning should cite specific recent news or events you found
|
||||
"""
|
||||
|
||||
VALID_CLASSIFICATIONS = {"bullish", "bearish", "neutral"}
|
||||
|
||||
|
||||
@@ -49,6 +67,59 @@ class OpenAISentimentProvider:
|
||||
self._client = AsyncOpenAI(api_key=api_key, http_client=http_client)
|
||||
self._model = model
|
||||
|
||||
@staticmethod
|
||||
def _extract_raw_text(response: object, ticker_context: str) -> str:
|
||||
raw_text = ""
|
||||
for item in response.output:
|
||||
if item.type == "message" and item.content:
|
||||
for block in item.content:
|
||||
if hasattr(block, "text") and block.text:
|
||||
raw_text = block.text
|
||||
break
|
||||
if raw_text:
|
||||
break
|
||||
|
||||
if not raw_text:
|
||||
raise ProviderError(f"No text output from OpenAI for {ticker_context}")
|
||||
|
||||
clean = raw_text.strip()
|
||||
if clean.startswith("```"):
|
||||
clean = clean.split("\n", 1)[1] if "\n" in clean else clean[3:]
|
||||
if clean.endswith("```"):
|
||||
clean = clean[:-3]
|
||||
return clean.strip()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_single_result(parsed: dict, ticker: str, citations: list[dict[str, str]]) -> SentimentData:
|
||||
classification = str(parsed.get("classification", "")).lower()
|
||||
if classification not in VALID_CLASSIFICATIONS:
|
||||
raise ProviderError(
|
||||
f"Invalid classification '{classification}' from OpenAI for {ticker}"
|
||||
)
|
||||
|
||||
confidence = int(parsed.get("confidence", 50))
|
||||
confidence = max(0, min(100, confidence))
|
||||
reasoning = str(parsed.get("reasoning", ""))
|
||||
|
||||
if reasoning:
|
||||
logger.info(
|
||||
"OpenAI sentiment for %s: %s (confidence=%d) — %s",
|
||||
ticker,
|
||||
classification,
|
||||
confidence,
|
||||
reasoning,
|
||||
)
|
||||
|
||||
return SentimentData(
|
||||
ticker=ticker,
|
||||
classification=classification,
|
||||
confidence=confidence,
|
||||
source="openai",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
reasoning=reasoning,
|
||||
citations=citations,
|
||||
)
|
||||
|
||||
async def fetch_sentiment(self, ticker: str) -> SentimentData:
|
||||
"""Use the Responses API with web_search_preview to get live sentiment."""
|
||||
try:
|
||||
@@ -58,48 +129,10 @@ class OpenAISentimentProvider:
|
||||
instructions="You are a financial sentiment analyst. Always respond with valid JSON only, no markdown fences.",
|
||||
input=_SENTIMENT_PROMPT.format(ticker=ticker),
|
||||
)
|
||||
|
||||
# Extract text from the ResponseOutputMessage in the output
|
||||
raw_text = ""
|
||||
for item in response.output:
|
||||
if item.type == "message" and item.content:
|
||||
for block in item.content:
|
||||
if hasattr(block, "text") and block.text:
|
||||
raw_text = block.text
|
||||
break
|
||||
if raw_text:
|
||||
break
|
||||
|
||||
if not raw_text:
|
||||
raise ProviderError(f"No text output from OpenAI for {ticker}")
|
||||
|
||||
raw_text = raw_text.strip()
|
||||
logger.debug("OpenAI raw response for %s: %s", ticker, raw_text)
|
||||
|
||||
# Strip markdown fences if present
|
||||
clean = raw_text
|
||||
if clean.startswith("```"):
|
||||
clean = clean.split("\n", 1)[1] if "\n" in clean else clean[3:]
|
||||
if clean.endswith("```"):
|
||||
clean = clean[:-3]
|
||||
clean = clean.strip()
|
||||
|
||||
clean = self._extract_raw_text(response, ticker)
|
||||
logger.debug("OpenAI raw response for %s: %s", ticker, clean)
|
||||
parsed = json.loads(clean)
|
||||
|
||||
classification = parsed.get("classification", "").lower()
|
||||
if classification not in VALID_CLASSIFICATIONS:
|
||||
raise ProviderError(
|
||||
f"Invalid classification '{classification}' from OpenAI for {ticker}"
|
||||
)
|
||||
|
||||
confidence = int(parsed.get("confidence", 50))
|
||||
confidence = max(0, min(100, confidence))
|
||||
|
||||
reasoning = parsed.get("reasoning", "")
|
||||
if reasoning:
|
||||
logger.info("OpenAI sentiment for %s: %s (confidence=%d) — %s",
|
||||
ticker, classification, confidence, reasoning)
|
||||
|
||||
# Extract url_citation annotations from response output
|
||||
citations: list[dict[str, str]] = []
|
||||
for item in response.output:
|
||||
@@ -112,19 +145,10 @@ class OpenAISentimentProvider:
|
||||
"url": getattr(annotation, "url", ""),
|
||||
"title": getattr(annotation, "title", ""),
|
||||
})
|
||||
|
||||
return SentimentData(
|
||||
ticker=ticker,
|
||||
classification=classification,
|
||||
confidence=confidence,
|
||||
source="openai",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
reasoning=reasoning,
|
||||
citations=citations,
|
||||
)
|
||||
return self._normalize_single_result(parsed, ticker, citations)
|
||||
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.error("Failed to parse OpenAI JSON for %s: %s — raw: %s", ticker, exc, raw_text)
|
||||
logger.error("Failed to parse OpenAI JSON for %s: %s", ticker, exc)
|
||||
raise ProviderError(f"Invalid JSON from OpenAI for {ticker}") from exc
|
||||
except ProviderError:
|
||||
raise
|
||||
@@ -134,3 +158,49 @@ class OpenAISentimentProvider:
|
||||
raise RateLimitError(f"OpenAI rate limit hit for {ticker}") from exc
|
||||
logger.error("OpenAI provider error for %s: %s", ticker, exc)
|
||||
raise ProviderError(f"OpenAI provider error for {ticker}: {exc}") from exc
|
||||
|
||||
async def fetch_sentiment_batch(self, tickers: list[str]) -> dict[str, SentimentData]:
|
||||
"""Fetch sentiment for multiple tickers in one OpenAI request.
|
||||
|
||||
Returns a map keyed by uppercase ticker symbol. Invalid/missing rows are skipped.
|
||||
"""
|
||||
normalized = [t.strip().upper() for t in tickers if t and t.strip()]
|
||||
if not normalized:
|
||||
return {}
|
||||
|
||||
ticker_context = ",".join(normalized)
|
||||
try:
|
||||
response = await self._client.responses.create(
|
||||
model=self._model,
|
||||
tools=[{"type": "web_search_preview"}],
|
||||
instructions="You are a financial sentiment analyst. Always respond with valid JSON only, no markdown fences.",
|
||||
input=_SENTIMENT_BATCH_PROMPT.format(tickers_csv=", ".join(normalized)),
|
||||
)
|
||||
clean = self._extract_raw_text(response, ticker_context)
|
||||
logger.debug("OpenAI batch raw response for %s: %s", ticker_context, clean)
|
||||
parsed = json.loads(clean)
|
||||
if not isinstance(parsed, list):
|
||||
raise ProviderError("Batch sentiment response must be a JSON array")
|
||||
|
||||
out: dict[str, SentimentData] = {}
|
||||
requested = set(normalized)
|
||||
for row in parsed:
|
||||
if not isinstance(row, dict):
|
||||
continue
|
||||
symbol = str(row.get("ticker", "")).strip().upper()
|
||||
if symbol not in requested:
|
||||
continue
|
||||
try:
|
||||
out[symbol] = self._normalize_single_result(row, symbol, citations=[])
|
||||
except Exception:
|
||||
continue
|
||||
return out
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ProviderError(f"Invalid batch JSON from OpenAI for {ticker_context}") from exc
|
||||
except ProviderError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
msg = str(exc).lower()
|
||||
if "429" in msg or "rate" in msg or "quota" in msg:
|
||||
raise RateLimitError(f"OpenAI rate limit hit for batch {ticker_context}") from exc
|
||||
raise ProviderError(f"OpenAI batch provider error for {ticker_context}: {exc}") from exc
|
||||
|
||||
Reference in New Issue
Block a user