first commit
This commit is contained in:
274
app/services/sr_service.py
Normal file
274
app/services/sr_service.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""S/R Detector service.
|
||||
|
||||
Detects support/resistance levels from Volume Profile (HVN/LVN) and
|
||||
Pivot Points (swing highs/lows), assigns strength scores, merges nearby
|
||||
levels, tags as support/resistance, and persists to DB.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.exceptions import NotFoundError, ValidationError
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.ticker import Ticker
|
||||
from app.services.indicator_service import (
|
||||
_extract_ohlcv,
|
||||
compute_pivot_points,
|
||||
compute_volume_profile,
|
||||
)
|
||||
from app.services.price_service import query_ohlcv
|
||||
|
||||
DEFAULT_TOLERANCE = 0.005 # 0.5%
|
||||
|
||||
|
||||
async def _get_ticker(db: AsyncSession, symbol: str) -> Ticker:
|
||||
"""Look up a ticker by symbol."""
|
||||
normalised = symbol.strip().upper()
|
||||
result = await db.execute(select(Ticker).where(Ticker.symbol == normalised))
|
||||
ticker = result.scalar_one_or_none()
|
||||
if ticker is None:
|
||||
raise NotFoundError(f"Ticker not found: {normalised}")
|
||||
return ticker
|
||||
|
||||
|
||||
def _count_price_touches(
|
||||
price_level: float,
|
||||
highs: list[float],
|
||||
lows: list[float],
|
||||
closes: list[float],
|
||||
tolerance: float = DEFAULT_TOLERANCE,
|
||||
) -> int:
|
||||
"""Count how many bars touched/respected a price level within tolerance."""
|
||||
count = 0
|
||||
tol = price_level * tolerance if price_level != 0 else tolerance
|
||||
for i in range(len(closes)):
|
||||
# A bar "touches" the level if the level is within the bar's range
|
||||
# (within tolerance)
|
||||
if lows[i] - tol <= price_level <= highs[i] + tol:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def _strength_from_touches(touches: int, total_bars: int) -> int:
|
||||
"""Convert touch count to a 0-100 strength score.
|
||||
|
||||
More touches relative to total bars = higher strength.
|
||||
Cap at 100.
|
||||
"""
|
||||
if total_bars == 0:
|
||||
return 0
|
||||
# Scale: each touch contributes proportionally, with a multiplier
|
||||
# so that a level touched ~20% of bars gets score ~100
|
||||
raw = (touches / total_bars) * 500.0
|
||||
return max(0, min(100, int(round(raw))))
|
||||
|
||||
|
||||
def _extract_candidate_levels(
|
||||
highs: list[float],
|
||||
lows: list[float],
|
||||
closes: list[float],
|
||||
volumes: list[int],
|
||||
) -> list[tuple[float, str]]:
|
||||
"""Extract candidate S/R levels from Volume Profile and Pivot Points.
|
||||
|
||||
Returns list of (price_level, detection_method) tuples.
|
||||
"""
|
||||
candidates: list[tuple[float, str]] = []
|
||||
|
||||
# Volume Profile: HVN and LVN as candidate levels
|
||||
try:
|
||||
vp = compute_volume_profile(highs, lows, closes, volumes)
|
||||
for price in vp.get("hvn", []):
|
||||
candidates.append((price, "volume_profile"))
|
||||
for price in vp.get("lvn", []):
|
||||
candidates.append((price, "volume_profile"))
|
||||
except ValidationError:
|
||||
pass # Not enough data for volume profile
|
||||
|
||||
# Pivot Points: swing highs and lows
|
||||
try:
|
||||
pp = compute_pivot_points(highs, lows, closes)
|
||||
for price in pp.get("swing_highs", []):
|
||||
candidates.append((price, "pivot_point"))
|
||||
for price in pp.get("swing_lows", []):
|
||||
candidates.append((price, "pivot_point"))
|
||||
except ValidationError:
|
||||
pass # Not enough data for pivot points
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
def _merge_levels(
|
||||
levels: list[dict],
|
||||
tolerance: float = DEFAULT_TOLERANCE,
|
||||
) -> list[dict]:
|
||||
"""Merge levels within tolerance into consolidated levels.
|
||||
|
||||
Levels from different methods within tolerance are merged.
|
||||
Merged levels combine strength scores (capped at 100) and get
|
||||
detection_method = "merged".
|
||||
"""
|
||||
if not levels:
|
||||
return []
|
||||
|
||||
# Sort by price
|
||||
sorted_levels = sorted(levels, key=lambda x: x["price_level"])
|
||||
merged: list[dict] = []
|
||||
|
||||
for level in sorted_levels:
|
||||
if not merged:
|
||||
merged.append(dict(level))
|
||||
continue
|
||||
|
||||
last = merged[-1]
|
||||
ref_price = last["price_level"]
|
||||
tol = ref_price * tolerance if ref_price != 0 else tolerance
|
||||
|
||||
if abs(level["price_level"] - ref_price) <= tol:
|
||||
# Merge: average price, combine strength, mark as merged
|
||||
combined_strength = min(100, last["strength"] + level["strength"])
|
||||
avg_price = (last["price_level"] + level["price_level"]) / 2.0
|
||||
method = (
|
||||
"merged"
|
||||
if last["detection_method"] != level["detection_method"]
|
||||
else last["detection_method"]
|
||||
)
|
||||
last["price_level"] = round(avg_price, 4)
|
||||
last["strength"] = combined_strength
|
||||
last["detection_method"] = method
|
||||
else:
|
||||
merged.append(dict(level))
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def _tag_levels(
|
||||
levels: list[dict],
|
||||
current_price: float,
|
||||
) -> list[dict]:
|
||||
"""Tag each level as 'support' or 'resistance' relative to current price."""
|
||||
for level in levels:
|
||||
if level["price_level"] < current_price:
|
||||
level["type"] = "support"
|
||||
else:
|
||||
level["type"] = "resistance"
|
||||
return levels
|
||||
|
||||
|
||||
def detect_sr_levels(
|
||||
highs: list[float],
|
||||
lows: list[float],
|
||||
closes: list[float],
|
||||
volumes: list[int],
|
||||
tolerance: float = DEFAULT_TOLERANCE,
|
||||
) -> list[dict]:
|
||||
"""Detect, score, merge, and tag S/R levels from OHLCV data.
|
||||
|
||||
Returns list of dicts with keys: price_level, type, strength,
|
||||
detection_method — sorted by strength descending.
|
||||
"""
|
||||
if not closes:
|
||||
return []
|
||||
|
||||
candidates = _extract_candidate_levels(highs, lows, closes, volumes)
|
||||
if not candidates:
|
||||
return []
|
||||
|
||||
total_bars = len(closes)
|
||||
current_price = closes[-1]
|
||||
|
||||
# Build level dicts with strength scores
|
||||
raw_levels: list[dict] = []
|
||||
for price, method in candidates:
|
||||
touches = _count_price_touches(price, highs, lows, closes, tolerance)
|
||||
strength = _strength_from_touches(touches, total_bars)
|
||||
raw_levels.append({
|
||||
"price_level": price,
|
||||
"strength": strength,
|
||||
"detection_method": method,
|
||||
"type": "", # will be tagged after merge
|
||||
})
|
||||
|
||||
# Merge nearby levels
|
||||
merged = _merge_levels(raw_levels, tolerance)
|
||||
|
||||
# Tag as support/resistance
|
||||
tagged = _tag_levels(merged, current_price)
|
||||
|
||||
# Sort by strength descending
|
||||
tagged.sort(key=lambda x: x["strength"], reverse=True)
|
||||
|
||||
return tagged
|
||||
|
||||
|
||||
async def recalculate_sr_levels(
|
||||
db: AsyncSession,
|
||||
symbol: str,
|
||||
tolerance: float = DEFAULT_TOLERANCE,
|
||||
) -> list[SRLevel]:
|
||||
"""Recalculate S/R levels for a ticker and persist to DB.
|
||||
|
||||
1. Fetch OHLCV data
|
||||
2. Detect levels
|
||||
3. Delete old levels for ticker
|
||||
4. Insert new levels
|
||||
5. Return new levels sorted by strength desc
|
||||
"""
|
||||
ticker = await _get_ticker(db, symbol)
|
||||
|
||||
records = await query_ohlcv(db, symbol)
|
||||
if not records:
|
||||
# No OHLCV data — clear any existing levels
|
||||
await db.execute(
|
||||
delete(SRLevel).where(SRLevel.ticker_id == ticker.id)
|
||||
)
|
||||
await db.commit()
|
||||
return []
|
||||
|
||||
_, highs, lows, closes, volumes = _extract_ohlcv(records)
|
||||
|
||||
levels = detect_sr_levels(highs, lows, closes, volumes, tolerance)
|
||||
|
||||
# Delete old levels
|
||||
await db.execute(
|
||||
delete(SRLevel).where(SRLevel.ticker_id == ticker.id)
|
||||
)
|
||||
|
||||
# Insert new levels
|
||||
now = datetime.utcnow()
|
||||
new_models: list[SRLevel] = []
|
||||
for lvl in levels:
|
||||
model = SRLevel(
|
||||
ticker_id=ticker.id,
|
||||
price_level=lvl["price_level"],
|
||||
type=lvl["type"],
|
||||
strength=lvl["strength"],
|
||||
detection_method=lvl["detection_method"],
|
||||
created_at=now,
|
||||
)
|
||||
db.add(model)
|
||||
new_models.append(model)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Refresh to get IDs
|
||||
for m in new_models:
|
||||
await db.refresh(m)
|
||||
|
||||
return new_models
|
||||
|
||||
|
||||
async def get_sr_levels(
|
||||
db: AsyncSession,
|
||||
symbol: str,
|
||||
tolerance: float = DEFAULT_TOLERANCE,
|
||||
) -> list[SRLevel]:
|
||||
"""Get S/R levels for a ticker, recalculating on every request (MVP).
|
||||
|
||||
Returns levels sorted by strength descending.
|
||||
"""
|
||||
return await recalculate_sr_levels(db, symbol, tolerance)
|
||||
Reference in New Issue
Block a user