106 lines
4.3 KiB
Python
106 lines
4.3 KiB
Python
"""Gemini sentiment provider using google-genai with search grounding."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
from google import genai
|
|
from google.genai import types
|
|
|
|
from app.exceptions import ProviderError, RateLimitError
|
|
from app.providers.protocol import SentimentData
|
|
|
|
logger = logging.getLogger(__name__)


# Make aiohttp trust the corporate CA bundle named by the SSL_CERT_FILE
# environment variable.
#
# aiohttp creates its cached verifying SSL context (_SSL_CONTEXT_VERIFIED) when
# aiohttp.connector is imported, so the bundle has to be loaded into that
# context after import for google-genai's aiohttp session to trust our proxy CA.
#
# Best effort only: any failure (aiohttp not installed, the private attribute
# missing or None, an unloadable bundle) is logged with its traceback and the
# module continues importing.
_CA_BUNDLE = os.getenv("SSL_CERT_FILE", "")

if _CA_BUNDLE and Path(_CA_BUNDLE).exists():
    try:
        import aiohttp.connector as _aio_conn

        # getattr with a default covers both "attribute absent" and "None".
        if getattr(_aio_conn, "_SSL_CONTEXT_VERIFIED", None) is not None:
            _aio_conn._SSL_CONTEXT_VERIFIED.load_verify_locations(cafile=_CA_BUNDLE)
            logger.debug("Patched aiohttp _SSL_CONTEXT_VERIFIED with %s", _CA_BUNDLE)
    except Exception:
        logger.warning("Could not patch aiohttp SSL context", exc_info=True)
|
|
|
|
# Prompt template sent to Gemini for each ticker. It is filled in with
# str.format(ticker=...), which is why the braces of the literal JSON example
# are doubled ({{ ... }}): they must come out as single braces after formatting.
_SENTIMENT_PROMPT = """\
Analyze the current market sentiment for the stock ticker {ticker}.
Search the web for recent news articles, social media mentions, and analyst opinions.

Respond ONLY with a JSON object in this exact format (no markdown, no extra text):
{{"classification": "<bullish|bearish|neutral>", "confidence": <0-100>, "reasoning": "<brief explanation>"}}

Rules:
- classification must be exactly one of: bullish, bearish, neutral
- confidence must be an integer from 0 to 100
- reasoning should be a brief one-sentence explanation
"""

# Accepted values for the "classification" field of the model's reply; the
# provider compares the lower-cased field against this set and rejects others.
VALID_CLASSIFICATIONS = {"bullish", "bearish", "neutral"}
|
|
|
|
|
|
class GeminiSentimentProvider:
    """Fetches sentiment analysis from Gemini with search grounding.

    The model is prompted (``_SENTIMENT_PROMPT``) to answer with a single JSON
    object. The reply is validated and converted into ``SentimentData``.
    """

    def __init__(self, api_key: str, model: str = "gemini-2.0-flash") -> None:
        """Create the provider.

        Args:
            api_key: Gemini API key; must be non-empty.
            model: Model name passed to ``generate_content``.

        Raises:
            ProviderError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ProviderError("Gemini API key is required")
        self._client = genai.Client(api_key=api_key)
        self._model = model

    async def fetch_sentiment(self, ticker: str) -> SentimentData:
        """Ask Gemini, grounded with Google Search, for the sentiment of *ticker*.

        The model's ``reasoning`` text is only logged; it is not part of the
        returned value.

        Args:
            ticker: Stock ticker symbol interpolated into the prompt.

        Returns:
            ``SentimentData`` with source ``"gemini"`` and a UTC timestamp.

        Raises:
            RateLimitError: If the call fails with what looks like an HTTP 429
                or quota-exhausted error.
            ProviderError: For every other failure: API error, empty or
                non-JSON reply, unknown classification, or non-numeric
                confidence.
        """
        try:
            response = await self._client.aio.models.generate_content(
                model=self._model,
                contents=_SENTIMENT_PROMPT.format(ticker=ticker),
                # JSON mode (response_mime_type="application/json") is
                # deliberately NOT requested: Gemini 2.x rejects it when combined
                # with tool use such as Google Search. The prompt asks for bare
                # JSON instead, and _extract_json tolerates Markdown fences or
                # surrounding prose in the reply.
                config=types.GenerateContentConfig(
                    tools=[types.Tool(google_search=types.GoogleSearch())],
                ),
            )

            # ``response.text`` may be None when the reply has no text parts;
            # treat that as empty so it is reported as invalid JSON below.
            raw_text = (response.text or "").strip()
            logger.debug("Gemini raw response for %s: %s", ticker, raw_text)

            parsed = json.loads(self._extract_json(raw_text))
            if not isinstance(parsed, dict):
                raise ProviderError(
                    f"Expected a JSON object from Gemini for {ticker}, "
                    f"got {type(parsed).__name__}"
                )

            # str() keeps a non-string value (e.g. null) from crashing on
            # .lower(); it then fails validation with a readable message.
            classification = str(parsed.get("classification", "")).strip().lower()
            if classification not in VALID_CLASSIFICATIONS:
                raise ProviderError(
                    f"Invalid classification '{classification}' from Gemini for {ticker}"
                )

            confidence = self._parse_confidence(parsed.get("confidence", 50), ticker)

            reasoning = parsed.get("reasoning", "")
            if reasoning:
                logger.info(
                    "Gemini sentiment for %s: %s (confidence=%d) — %s",
                    ticker, classification, confidence, reasoning,
                )

            return SentimentData(
                ticker=ticker,
                classification=classification,
                confidence=confidence,
                source="gemini",
                timestamp=datetime.now(timezone.utc),
            )

        except json.JSONDecodeError as exc:
            logger.error("Failed to parse Gemini JSON for %s: %s", ticker, exc)
            raise ProviderError(f"Invalid JSON from Gemini for {ticker}") from exc
        except ProviderError:
            raise
        except Exception as exc:
            if self._is_rate_limit_error(exc):
                raise RateLimitError(f"Gemini rate limit hit for {ticker}") from exc
            logger.error("Gemini provider error for %s: %s", ticker, exc)
            raise ProviderError(f"Gemini provider error for {ticker}: {exc}") from exc

    @staticmethod
    def _extract_json(text: str) -> str:
        """Return the JSON-object text embedded in a model reply.

        Narrows *text* to the span from the first ``{`` to the last ``}``.
        This strips Markdown code fences and any prose around the object.
        Text without such a span is returned stripped, so ``json.loads``
        reports it as invalid.
        """
        stripped = text.strip()
        start = stripped.find("{")
        end = stripped.rfind("}")
        if start == -1 or end < start:
            return stripped
        return stripped[start : end + 1]

    @staticmethod
    def _parse_confidence(value: object, ticker: str) -> int:
        """Convert the model's confidence value to an int clamped to [0, 100].

        Accepts ints, floats and numeric strings, optionally with a trailing
        ``%`` (e.g. ``"85%"``). Fractions are truncated toward zero, as
        ``int()`` does.

        Raises:
            ProviderError: If *value* is not numeric, or is NaN/infinite.
        """
        candidate = value.strip().rstrip("%").strip() if isinstance(value, str) else value
        try:
            number = int(float(candidate))  # type: ignore[arg-type]
        except (TypeError, ValueError, OverflowError) as exc:
            raise ProviderError(
                f"Invalid confidence {value!r} from Gemini for {ticker}"
            ) from exc
        return max(0, min(100, number))

    @staticmethod
    def _is_rate_limit_error(exc: BaseException) -> bool:
        """Return True if *exc* looks like an HTTP 429 / quota-exhausted error.

        First checks a numeric ``code`` attribute: google-genai's API errors
        expose the HTTP status there (assumed; confirm against the installed SDK
        version). Then it falls back to specific phrases in the message. It
        deliberately avoids a loose "'rate' and 'limit' anywhere" test, which
        also matches messages such as "generateContent ... token limit".
        """
        if getattr(exc, "code", None) == 429:
            return True
        message = str(exc).lower()
        markers = (
            "429",
            "resource exhausted",
            "resource_exhausted",
            "quota",
            "rate limit",
            "rate-limit",
            "rate_limit",
            "ratelimit",
            "too many requests",
        )
        return any(marker in message for marker in markers)
|