"""Technical Analysis service.

Computes indicators from OHLCV data.

Each ``compute_*`` function is a pure function that takes parallel lists of
OHLCV values (highs, lows, closes, ...) and returns raw values plus, for the
scored indicators, a normalized 0-100 ``score``. The service layer handles
DB fetching, caching, and minimum-data validation.
"""

from __future__ import annotations

from datetime import date
from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession

from app.cache import indicator_cache
from app.exceptions import ValidationError
from app.services.price_service import query_ohlcv

# ---------------------------------------------------------------------------
# Minimum data requirements per indicator
# ---------------------------------------------------------------------------

MIN_BARS: dict[str, int] = {
    "adx": 28,  # 2 * default ADX period (14)
    "ema": 0,  # dynamic: period + 1
    "rsi": 15,  # default RSI period (14) + 1
    "atr": 15,  # default ATR period (14) + 1
    "volume_profile": 20,
    "pivot_points": 5,
}

DEFAULT_PERIODS: dict[str, int] = {
    "adx": 14,
    "ema": 20,
    "rsi": 14,
    "atr": 14,
}


# ---------------------------------------------------------------------------
# Pure computation helpers
# ---------------------------------------------------------------------------


def _check_positive(name: str, value: int) -> None:
    """Raise ValidationError unless *value* is >= 1.

    A period or bin count of zero (or below) would otherwise surface as a
    ZeroDivisionError deep inside the smoothing / binning math.
    """
    if value < 1:
        raise ValidationError(f"{name} must be >= 1, got {value}")


def _ema(values: list[float], period: int) -> list[float]:
    """Compute an EMA series seeded with the SMA of the first *period* values.

    The result has ``len(values) - period + 1`` entries: the first is the
    seed SMA (aligned with ``values[period - 1]``), and each later entry
    corresponds to the next input value. Returns ``[]`` when fewer than
    *period* values are supplied.
    """
    if len(values) < period:
        return []
    k = 2.0 / (period + 1)
    ema_vals: list[float] = [sum(values[:period]) / period]
    for v in values[period:]:
        ema_vals.append(v * k + ema_vals[-1] * (1 - k))
    return ema_vals


def compute_adx(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    period: int = 14,
) -> dict[str, Any]:
    """Compute Wilder's ADX from high/low/close arrays.

    Args:
        highs, lows, closes: Parallel per-bar price series, oldest first.
        period: Wilder smoothing period.

    Returns:
        Dict with the latest ``adx``, ``plus_di``, ``minus_di`` and a
        ``score`` equal to ADX clamped to 0-100.

    Raises:
        ValidationError: If *period* < 1 or fewer than ``2 * period`` bars.
    """
    _check_positive("ADX period", period)
    n = len(closes)
    if n < 2 * period:
        raise ValidationError(
            f"ADX requires at least {2 * period} bars, got {n}"
        )

    # True Range, +DM, -DM (one entry per bar after the first)
    tr_list: list[float] = []
    plus_dm: list[float] = []
    minus_dm: list[float] = []
    for i in range(1, n):
        h, low_val, pc = highs[i], lows[i], closes[i - 1]
        tr_list.append(max(h - low_val, abs(h - pc), abs(low_val - pc)))
        up = highs[i] - highs[i - 1]
        down = lows[i - 1] - lows[i]
        plus_dm.append(up if up > down and up > 0 else 0.0)
        minus_dm.append(down if down > up and down > 0 else 0.0)

    def _smooth(vals: list[float], p: int) -> list[float]:
        """Wilder running-sum smoothing.

        Seeds with the sum of the first *p* values, then applies
        ``S_t = S_{t-1} - S_{t-1} / p + v_t``. The output stays on a *sum*
        scale (roughly ``p`` times the average).
        """
        s = [sum(vals[:p])]
        for v in vals[p:]:
            s.append(s[-1] - s[-1] / p + v)
        return s

    s_tr = _smooth(tr_list, period)
    s_plus = _smooth(plus_dm, period)
    s_minus = _smooth(minus_dm, period)

    # +DI, -DI, DX. The sum scale cancels out in the DM / TR ratios.
    dx_list: list[float] = []
    plus_di_last = 0.0
    minus_di_last = 0.0
    for tr_s, plus_s, minus_s in zip(s_tr, s_plus, s_minus):
        tr_v = tr_s if tr_s != 0 else 1e-10  # flat bars: avoid div by zero
        pdi = 100.0 * plus_s / tr_v
        mdi = 100.0 * minus_s / tr_v
        denom = pdi + mdi if (pdi + mdi) != 0 else 1e-10
        dx_list.append(100.0 * abs(pdi - mdi) / denom)
        plus_di_last = pdi
        minus_di_last = mdi

    # ADX is the Wilder-smoothed *average* of DX. _smooth() yields a running
    # sum, so divide by period to bring the value back onto the 0-100 DX
    # scale; using the raw sum would inflate ADX roughly period-fold.
    if len(dx_list) < period:
        # Not reachable given the 2 * period guard above; kept defensively.
        adx_val = sum(dx_list) / len(dx_list) if dx_list else 0.0
    else:
        adx_val = _smooth(dx_list, period)[-1] / period

    score = max(0.0, min(100.0, adx_val))
    return {
        "adx": round(adx_val, 4),
        "plus_di": round(plus_di_last, 4),
        "minus_di": round(minus_di_last, 4),
        "score": round(score, 4),
    }


def compute_ema(
    closes: list[float],
    period: int = 20,
) -> dict[str, Any]:
    """Compute EMA for *closes* with given *period*.

    Score: position of the latest close relative to the EMA, mapped
    linearly so that 50 = at the EMA, 100 = 5%+ above, 0 = 5%+ below.

    Raises:
        ValidationError: If *period* < 1 or fewer than ``period + 1`` closes.
    """
    _check_positive("EMA period", period)
    min_bars = period + 1
    if len(closes) < min_bars:
        raise ValidationError(
            f"EMA({period}) requires at least {min_bars} bars, got {len(closes)}"
        )

    ema_vals = _ema(closes, period)
    latest_ema = ema_vals[-1]
    latest_close = closes[-1]

    # Score: 50 = at EMA, 100 = 5%+ above, 0 = 5%+ below
    if latest_ema == 0:
        pct = 0.0
    else:
        pct = (latest_close - latest_ema) / latest_ema * 100.0
    score = max(0.0, min(100.0, 50.0 + pct * 10.0))

    return {
        "ema": round(latest_ema, 4),
        "period": period,
        "latest_close": round(latest_close, 4),
        "score": round(score, 4),
    }


def compute_rsi(
    closes: list[float],
    period: int = 14,
) -> dict[str, Any]:
    """Compute Wilder's RSI. Score = RSI value (already 0-100).

    Raises:
        ValidationError: If *period* < 1 or fewer than ``period + 1`` closes.
    """
    _check_positive("RSI period", period)
    n = len(closes)
    if n < period + 1:
        raise ValidationError(
            f"RSI requires at least {period + 1} bars, got {n}"
        )

    deltas = [closes[i] - closes[i - 1] for i in range(1, n)]
    gains = [d if d > 0 else 0.0 for d in deltas]
    losses = [-d if d < 0 else 0.0 for d in deltas]

    # Seed with simple averages, then apply Wilder smoothing.
    avg_gain = sum(gains[:period]) / period
    avg_loss = sum(losses[:period]) / period
    for i in range(period, len(deltas)):
        avg_gain = (avg_gain * (period - 1) + gains[i]) / period
        avg_loss = (avg_loss * (period - 1) + losses[i]) / period

    if avg_loss == 0:
        rsi = 100.0
    else:
        rs = avg_gain / avg_loss
        rsi = 100.0 - 100.0 / (1.0 + rs)

    score = max(0.0, min(100.0, rsi))
    return {
        "rsi": round(rsi, 4),
        "period": period,
        "score": round(score, 4),
    }


def compute_atr(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    period: int = 14,
) -> dict[str, Any]:
    """Compute Wilder's ATR. Score = normalized inverse (lower ATR = higher).

    Raises:
        ValidationError: If *period* < 1 or fewer than ``period + 1`` bars.
    """
    _check_positive("ATR period", period)
    n = len(closes)
    if n < period + 1:
        raise ValidationError(
            f"ATR requires at least {period + 1} bars, got {n}"
        )

    tr_list: list[float] = []
    for i in range(1, n):
        h, low_val, pc = highs[i], lows[i], closes[i - 1]
        tr_list.append(max(h - low_val, abs(h - pc), abs(low_val - pc)))

    # Wilder smoothing
    atr = sum(tr_list[:period]) / period
    for tr in tr_list[period:]:
        atr = (atr * (period - 1) + tr) / period

    # Score: inverse normalized. ATR as % of price; lower = higher score.
    latest_close = closes[-1]
    if latest_close == 0:
        atr_pct = 0.0
    else:
        atr_pct = atr / latest_close * 100.0
    # 0% ATR -> 100 score, 10%+ ATR -> 0 score
    score = max(0.0, min(100.0, 100.0 - atr_pct * 10.0))

    return {
        "atr": round(atr, 4),
        "period": period,
        "atr_percent": round(atr_pct, 4),
        "score": round(score, 4),
    }


def compute_volume_profile(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    volumes: list[int],
    num_bins: int = 20,
) -> dict[str, Any]:
    """Compute Volume Profile: POC, Value Area, HVN, LVN.

    The [min low, max high] price range is split into *num_bins* equal bins.
    Each bar adds its *full* volume to every bin its [low, high] range
    touches (volume is not split between bins).

    Score: proximity of latest close to POC (closer = higher).

    Raises:
        ValidationError: If fewer than 20 bars or *num_bins* < 1.
    """
    n = len(closes)
    if n < 20:
        raise ValidationError(
            f"Volume Profile requires at least 20 bars, got {n}"
        )
    _check_positive("Volume Profile num_bins", num_bins)

    price_min = min(lows)
    price_max = max(highs)
    if price_max == price_min:
        price_max = price_min + 1.0  # avoid zero-width range

    bin_width = (price_max - price_min) / num_bins
    bins: list[float] = [0.0] * num_bins
    bin_prices: list[float] = [
        price_min + (i + 0.5) * bin_width for i in range(num_bins)
    ]

    for bar_low, bar_high, vol in zip(lows, highs, volumes):
        for b in range(num_bins):
            bl = price_min + b * bin_width
            bh = bl + bin_width
            if bar_high >= bl and bar_low <= bh:
                bins[b] += vol

    total_vol = sum(bins)
    if total_vol == 0:
        total_vol = 1.0

    # POC = bin with highest volume
    poc_idx = bins.index(max(bins))
    poc = round(bin_prices[poc_idx], 4)

    # Value Area: take bins in descending volume order (not necessarily
    # contiguous) until they hold 70% of binned volume; report the price
    # span from the lowest to the highest selected bin.
    sorted_bins = sorted(range(num_bins), key=lambda i: bins[i], reverse=True)
    va_vol = 0.0
    va_indices: list[int] = []
    for idx in sorted_bins:
        va_vol += bins[idx]
        va_indices.append(idx)
        if va_vol >= total_vol * 0.7:
            break
    va_low = round(price_min + min(va_indices) * bin_width, 4)
    va_high = round(price_min + (max(va_indices) + 1) * bin_width, 4)

    # HVN / LVN: bins above/below average volume
    avg_vol = total_vol / num_bins
    hvn = [round(bin_prices[i], 4) for i in range(num_bins) if bins[i] > avg_vol]
    lvn = [round(bin_prices[i], 4) for i in range(num_bins) if bins[i] < avg_vol]

    # Score: proximity of latest close to POC, relative to the full range
    latest = closes[-1]
    price_range = price_max - price_min
    if price_range == 0:
        score = 100.0
    else:
        dist_pct = abs(latest - poc) / price_range
        score = max(0.0, min(100.0, 100.0 * (1.0 - dist_pct)))

    return {
        "poc": poc,
        "value_area_low": va_low,
        "value_area_high": va_high,
        "hvn": hvn,
        "lvn": lvn,
        "score": round(score, 4),
    }


def compute_pivot_points(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    window: int = 2,
) -> dict[str, Any]:
    """Detect swing highs/lows as pivot points.

    A swing high at index *i* means ``highs[i] >= all highs`` in
    ``[i - window, i + window]`` (swing lows mirror this with ``<=``).
    Only indices with a full window on both sides are considered.

    Score: percentage of detected pivots lying within 2% of the latest
    close (0 when there are no pivots or the latest close is 0).

    Raises:
        ValidationError: If fewer than 5 bars are supplied.
    """
    n = len(closes)
    if n < 5:
        raise ValidationError(
            f"Pivot Points requires at least 5 bars, got {n}"
        )

    swing_highs: list[float] = []
    swing_lows: list[float] = []

    for i in range(window, n - window):
        neighbours = range(i - window, i + window + 1)
        if all(highs[i] >= highs[j] for j in neighbours):
            swing_highs.append(round(highs[i], 4))
        if all(lows[i] <= lows[j] for j in neighbours):
            swing_lows.append(round(lows[i], 4))

    all_pivots = swing_highs + swing_lows
    latest = closes[-1]

    # Score: fraction of pivots within 2% of current price -> 0-100
    if not all_pivots or latest == 0:
        score = 0.0
    else:
        near = sum(1 for p in all_pivots if abs(p - latest) / latest <= 0.02)
        score = min(100.0, near / len(all_pivots) * 100.0)

    return {
        "swing_highs": swing_highs,
        "swing_lows": swing_lows,
        "pivot_count": len(all_pivots),
        "score": round(score, 4),
    }


def compute_ema_cross(
    closes: list[float],
    short_period: int = 20,
    long_period: int = 50,
    tolerance: float = 1e-6,
) -> dict[str, Any]:
    """Compare short EMA vs long EMA.

    Returns signal: bullish (short > long), bearish (short < long),
    neutral (absolute difference within *tolerance*).

    Raises:
        ValidationError: If either period < 1 or there are not enough closes
            to compute both EMAs (``max(short, long) + 1``).
    """
    _check_positive("EMA Cross short period", short_period)
    _check_positive("EMA Cross long period", long_period)
    # Size the requirement by the larger period so both EMAs are defined even
    # when the caller passes short_period > long_period.
    min_bars = max(short_period, long_period) + 1
    if len(closes) < min_bars:
        raise ValidationError(
            f"EMA Cross requires at least {min_bars} bars, got {len(closes)}"
        )

    short_ema = _ema(closes, short_period)[-1]
    long_ema = _ema(closes, long_period)[-1]

    diff = short_ema - long_ema
    if abs(diff) <= tolerance:
        signal = "neutral"
    elif diff > 0:
        signal = "bullish"
    else:
        signal = "bearish"

    return {
        "short_ema": round(short_ema, 4),
        "long_ema": round(long_ema, 4),
        "short_period": short_period,
        "long_period": long_period,
        "signal": signal,
    }


# ---------------------------------------------------------------------------
# Supported indicator types
# ---------------------------------------------------------------------------

INDICATOR_TYPES = {"adx", "ema", "rsi", "atr", "volume_profile", "pivot_points"}


# ---------------------------------------------------------------------------
# Service-layer functions (DB + cache + validation)
# ---------------------------------------------------------------------------


def _extract_ohlcv(records: list) -> tuple[
    list[float], list[float], list[float], list[float], list[int]
]:
    """Extract parallel (opens, highs, lows, closes, volumes) arrays.

    Expects each record to expose ``open``/``high``/``low``/``close``/
    ``volume`` attributes (presumably OHLCVRecord rows from query_ohlcv).
    """
    opens = [float(r.open) for r in records]
    highs = [float(r.high) for r in records]
    lows = [float(r.low) for r in records]
    closes = [float(r.close) for r in records]
    volumes = [int(r.volume) for r in records]
    return opens, highs, lows, closes, volumes


async def get_indicator(
    db: AsyncSession,
    symbol: str,
    indicator_type: str,
    start_date: date | None = None,
    end_date: date | None = None,
    period: int | None = None,
) -> dict[str, Any]:
    """Compute a single indicator for *symbol*.

    Checks the cache first and stores the result after computing. The cache
    key includes the effective look-back period, so different periods for
    the same indicator do not collide.

    *period* applies only to adx/ema/rsi/atr; a falsy value selects the
    default from DEFAULT_PERIODS. It is ignored for other indicator types.

    Raises:
        ValidationError: For unknown indicator types or insufficient data.
    """
    indicator_type = indicator_type.lower()
    if indicator_type not in INDICATOR_TYPES:
        raise ValidationError(
            f"Unknown indicator type: {indicator_type}. "
            f"Supported: {', '.join(sorted(INDICATOR_TYPES))}"
        )

    # Resolve the effective period before building the cache key; otherwise
    # e.g. RSI(7) and RSI(14) for the same symbol/range would share an entry.
    default_period = DEFAULT_PERIODS.get(indicator_type)
    p = (period or default_period) if default_period is not None else None
    cache_tag = indicator_type if p is None else f"{indicator_type}_{p}"
    cache_key = (symbol.upper(), str(start_date), str(end_date), cache_tag)
    cached = indicator_cache.get(cache_key)
    if cached is not None:
        return cached

    records = await query_ohlcv(db, symbol, start_date, end_date)
    _, highs, lows, closes, volumes = _extract_ohlcv(records)
    n = len(records)

    if indicator_type == "adx":
        result = compute_adx(highs, lows, closes, period=p)
    elif indicator_type == "ema":
        result = compute_ema(closes, period=p)
    elif indicator_type == "rsi":
        result = compute_rsi(closes, period=p)
    elif indicator_type == "atr":
        result = compute_atr(highs, lows, closes, period=p)
    elif indicator_type == "volume_profile":
        result = compute_volume_profile(highs, lows, closes, volumes)
    elif indicator_type == "pivot_points":
        result = compute_pivot_points(highs, lows, closes)
    else:
        raise ValidationError(f"Unknown indicator type: {indicator_type}")

    response = {
        "indicator_type": indicator_type,
        "values": {k: v for k, v in result.items() if k != "score"},
        "score": result["score"],
        "bars_used": n,
    }
    indicator_cache.set(cache_key, response)
    return response


async def get_ema_cross(
    db: AsyncSession,
    symbol: str,
    start_date: date | None = None,
    end_date: date | None = None,
    short_period: int = 20,
    long_period: int = 50,
) -> dict[str, Any]:
    """Compute the EMA cross signal for *symbol*, using the indicator cache.

    The cache key encodes both periods, so different period pairs are
    cached separately.
    """
    cache_key = (
        symbol.upper(),
        str(start_date),
        str(end_date),
        f"ema_cross_{short_period}_{long_period}",
    )
    cached = indicator_cache.get(cache_key)
    if cached is not None:
        return cached

    records = await query_ohlcv(db, symbol, start_date, end_date)
    _, _, _, closes, _ = _extract_ohlcv(records)

    result = compute_ema_cross(closes, short_period, long_period)
    indicator_cache.set(cache_key, result)
    return result