major update
Some checks failed
Deploy / lint (push) Failing after 8s
Deploy / test (push) Has been skipped
Deploy / deploy (push) Has been skipped

This commit is contained in:
Dennis Thiessen
2026-02-27 16:08:09 +01:00
parent 61ab24490d
commit 181cfe6588
71 changed files with 7647 additions and 281 deletions

View File

@@ -0,0 +1,211 @@
"""Unit tests for cluster_sr_zones() in app.services.sr_service."""
from app.services.sr_service import cluster_sr_zones
def _level(price: float, strength: int = 50) -> dict:
"""Helper to build a level dict."""
return {"price_level": price, "strength": strength}
class TestClusterSrZonesEmptyAndEdge:
"""Edge cases: empty input, max_zones boundaries."""
def test_empty_levels_returns_empty(self):
assert cluster_sr_zones([], current_price=100.0) == []
def test_max_zones_zero_returns_empty(self):
levels = [_level(100.0)]
assert cluster_sr_zones(levels, current_price=100.0, max_zones=0) == []
def test_max_zones_negative_returns_empty(self):
levels = [_level(100.0)]
assert cluster_sr_zones(levels, current_price=100.0, max_zones=-1) == []
def test_single_level(self):
levels = [_level(95.0, 60)]
zones = cluster_sr_zones(levels, current_price=100.0)
assert len(zones) == 1
z = zones[0]
assert z["low"] == 95.0
assert z["high"] == 95.0
assert z["midpoint"] == 95.0
assert z["strength"] == 60
assert z["type"] == "support"
assert z["level_count"] == 1
class TestClusterSrZonesMerging:
"""Greedy merge behaviour."""
def test_two_levels_within_tolerance_merge(self):
# 100 and 101 are 1% apart; tolerance=2% → should merge
levels = [_level(100.0, 30), _level(101.0, 40)]
zones = cluster_sr_zones(levels, current_price=200.0, tolerance=0.02)
assert len(zones) == 1
z = zones[0]
assert z["low"] == 100.0
assert z["high"] == 101.0
assert z["midpoint"] == 100.5
assert z["strength"] == 70
assert z["level_count"] == 2
def test_two_levels_outside_tolerance_stay_separate(self):
# 100 and 110 are 10% apart; tolerance=2% → separate
levels = [_level(100.0, 30), _level(110.0, 40)]
zones = cluster_sr_zones(levels, current_price=200.0, tolerance=0.02)
assert len(zones) == 2
def test_all_same_price_merge_into_one(self):
levels = [_level(50.0, 20), _level(50.0, 30), _level(50.0, 10)]
zones = cluster_sr_zones(levels, current_price=100.0)
assert len(zones) == 1
assert zones[0]["strength"] == 60
assert zones[0]["level_count"] == 3
def test_levels_at_tolerance_boundary(self):
# midpoint of cluster starting at 100 is 100. 2% of 100 = 2.
# A level at 102 is exactly at the boundary → should merge
levels = [_level(100.0, 25), _level(102.0, 25)]
zones = cluster_sr_zones(levels, current_price=200.0, tolerance=0.02)
assert len(zones) == 1
class TestClusterSrZonesStrength:
"""Strength capping and computation."""
def test_strength_capped_at_100(self):
levels = [_level(100.0, 80), _level(100.5, 80)]
zones = cluster_sr_zones(levels, current_price=200.0, tolerance=0.02)
assert len(zones) == 1
assert zones[0]["strength"] == 100
def test_strength_sum_when_under_cap(self):
levels = [_level(100.0, 10), _level(100.5, 20)]
zones = cluster_sr_zones(levels, current_price=200.0, tolerance=0.02)
assert zones[0]["strength"] == 30
class TestClusterSrZonesTypeTagging:
"""Support vs resistance tagging."""
def test_support_when_midpoint_below_current(self):
levels = [_level(90.0, 50)]
zones = cluster_sr_zones(levels, current_price=100.0)
assert zones[0]["type"] == "support"
def test_resistance_when_midpoint_above_current(self):
levels = [_level(110.0, 50)]
zones = cluster_sr_zones(levels, current_price=100.0)
assert zones[0]["type"] == "resistance"
def test_resistance_when_midpoint_equals_current(self):
# "else resistance" per spec
levels = [_level(100.0, 50)]
zones = cluster_sr_zones(levels, current_price=100.0)
assert zones[0]["type"] == "resistance"
class TestClusterSrZonesSorting:
"""Sorting by strength descending."""
def test_sorted_by_strength_descending(self):
levels = [_level(50.0, 20), _level(150.0, 80), _level(250.0, 50)]
zones = cluster_sr_zones(levels, current_price=100.0, tolerance=0.001)
strengths = [z["strength"] for z in zones]
assert strengths == sorted(strengths, reverse=True)
class TestClusterSrZonesMaxZones:
"""max_zones filtering."""
def test_max_zones_limits_output(self):
levels = [_level(50.0, 20), _level(150.0, 80), _level(250.0, 50)]
zones = cluster_sr_zones(
levels, current_price=100.0, tolerance=0.001, max_zones=2
)
assert len(zones) == 2
# Balanced selection: 1 support (strength 20) + 1 resistance (strength 80)
types = {z["type"] for z in zones}
assert "support" in types
assert "resistance" in types
assert zones[0]["strength"] == 80
assert zones[1]["strength"] == 20
def test_max_zones_none_returns_all(self):
levels = [_level(50.0, 20), _level(150.0, 80), _level(250.0, 50)]
zones = cluster_sr_zones(
levels, current_price=100.0, tolerance=0.001, max_zones=None
)
assert len(zones) == 3
def test_max_zones_larger_than_count_returns_all(self):
levels = [_level(50.0, 20)]
zones = cluster_sr_zones(
levels, current_price=100.0, max_zones=10
)
assert len(zones) == 1
class TestClusterSrZonesBalancedSelection:
"""Balanced interleave selection behaviour (Requirements 1.1, 1.2, 1.3, 1.5, 1.6)."""
def test_mixed_input_produces_balanced_output(self):
"""3 support + 3 resistance with max_zones=4 → 2 support + 2 resistance."""
levels = [
_level(80.0, 70), # support
_level(85.0, 50), # support
_level(90.0, 30), # support
_level(110.0, 60), # resistance
_level(115.0, 40), # resistance
_level(120.0, 20), # resistance
]
zones = cluster_sr_zones(levels, current_price=100.0, tolerance=0.001, max_zones=4)
assert len(zones) == 4
support_count = sum(1 for z in zones if z["type"] == "support")
resistance_count = sum(1 for z in zones if z["type"] == "resistance")
assert support_count == 2
assert resistance_count == 2
def test_all_support_fills_from_support_only(self):
"""When no resistance levels exist, all slots filled from support."""
levels = [
_level(80.0, 70),
_level(85.0, 50),
_level(90.0, 30),
]
zones = cluster_sr_zones(levels, current_price=200.0, tolerance=0.001, max_zones=2)
assert len(zones) == 2
assert all(z["type"] == "support" for z in zones)
def test_all_resistance_fills_from_resistance_only(self):
"""When no support levels exist, all slots filled from resistance."""
levels = [
_level(110.0, 60),
_level(115.0, 40),
_level(120.0, 20),
]
zones = cluster_sr_zones(levels, current_price=50.0, tolerance=0.001, max_zones=2)
assert len(zones) == 2
assert all(z["type"] == "resistance" for z in zones)
def test_single_zone_edge_case(self):
"""Only 1 level total → returns exactly 1 zone."""
levels = [_level(95.0, 45)]
zones = cluster_sr_zones(levels, current_price=100.0, max_zones=5)
assert len(zones) == 1
assert zones[0]["strength"] == 45
def test_both_types_present_when_max_zones_gte_2(self):
"""When both types exist and max_zones >= 2, at least one of each is present."""
levels = [
_level(70.0, 90), # support (strongest overall)
_level(75.0, 80), # support
_level(80.0, 70), # support
_level(130.0, 10), # resistance (weakest overall)
]
zones = cluster_sr_zones(levels, current_price=100.0, tolerance=0.001, max_zones=2)
types = {z["type"] for z in zones}
assert "support" in types
assert "resistance" in types

View File

@@ -0,0 +1,156 @@
"""Unit tests for FMPFundamentalProvider 402 reason recording."""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from app.providers.fmp import FMPFundamentalProvider
def _mock_response(status_code: int, json_data: object = None) -> httpx.Response:
"""Build a fake httpx.Response."""
resp = httpx.Response(
status_code=status_code,
json=json_data if json_data is not None else {},
request=httpx.Request("GET", "https://example.com"),
)
return resp
@pytest.fixture
def provider() -> FMPFundamentalProvider:
return FMPFundamentalProvider(api_key="test-key")
class TestFetchJsonOptional402Tracking:
"""_fetch_json_optional returns (data, was_402) tuple."""
@pytest.mark.asyncio
async def test_returns_empty_dict_and_true_on_402(self, provider):
mock_client = AsyncMock()
mock_client.get.return_value = _mock_response(402)
data, was_402 = await provider._fetch_json_optional(
mock_client, "ratios-ttm", {}, "AAPL"
)
assert data == {}
assert was_402 is True
@pytest.mark.asyncio
async def test_returns_data_and_false_on_200(self, provider):
mock_client = AsyncMock()
mock_client.get.return_value = _mock_response(
200, [{"priceToEarningsRatioTTM": 25.5}]
)
data, was_402 = await provider._fetch_json_optional(
mock_client, "ratios-ttm", {}, "AAPL"
)
assert data == {"priceToEarningsRatioTTM": 25.5}
assert was_402 is False
class TestFetchFundamentals402Recording:
"""fetch_fundamentals records 402 endpoints in unavailable_fields."""
@pytest.mark.asyncio
async def test_all_402_records_all_fields(self, provider):
"""When all supplementary endpoints return 402, all three fields are recorded."""
profile_resp = _mock_response(200, [{"marketCap": 1_000_000}])
ratios_resp = _mock_response(402)
growth_resp = _mock_response(402)
earnings_resp = _mock_response(402)
async def mock_get(url, params=None):
if "profile" in url:
return profile_resp
if "ratios-ttm" in url:
return ratios_resp
if "financial-growth" in url:
return growth_resp
if "earnings" in url:
return earnings_resp
return _mock_response(200, [{}])
with patch("app.providers.fmp.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.get.side_effect = mock_get
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
result = await provider.fetch_fundamentals("AAPL")
assert result.unavailable_fields == {
"pe_ratio": "requires paid plan",
"revenue_growth": "requires paid plan",
"earnings_surprise": "requires paid plan",
}
@pytest.mark.asyncio
async def test_mixed_200_402_records_only_402_fields(self, provider):
"""When only ratios-ttm returns 402, only pe_ratio is recorded."""
profile_resp = _mock_response(200, [{"marketCap": 2_000_000}])
ratios_resp = _mock_response(402)
growth_resp = _mock_response(200, [{"revenueGrowth": 0.15}])
earnings_resp = _mock_response(200, [{"epsActual": 3.0, "epsEstimated": 2.5}])
async def mock_get(url, params=None):
if "profile" in url:
return profile_resp
if "ratios-ttm" in url:
return ratios_resp
if "financial-growth" in url:
return growth_resp
if "earnings" in url:
return earnings_resp
return _mock_response(200, [{}])
with patch("app.providers.fmp.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.get.side_effect = mock_get
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
result = await provider.fetch_fundamentals("AAPL")
assert result.unavailable_fields == {"pe_ratio": "requires paid plan"}
assert result.revenue_growth == 0.15
assert result.earnings_surprise is not None
@pytest.mark.asyncio
async def test_no_402_empty_unavailable_fields(self, provider):
"""When all endpoints succeed, unavailable_fields is empty."""
profile_resp = _mock_response(200, [{"marketCap": 3_000_000}])
ratios_resp = _mock_response(200, [{"priceToEarningsRatioTTM": 20.0}])
growth_resp = _mock_response(200, [{"revenueGrowth": 0.10}])
earnings_resp = _mock_response(200, [{"epsActual": 2.0, "epsEstimated": 1.8}])
async def mock_get(url, params=None):
if "profile" in url:
return profile_resp
if "ratios-ttm" in url:
return ratios_resp
if "financial-growth" in url:
return growth_resp
if "earnings" in url:
return earnings_resp
return _mock_response(200, [{}])
with patch("app.providers.fmp.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.get.side_effect = mock_get
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
result = await provider.fetch_fundamentals("AAPL")
assert result.unavailable_fields == {}
assert result.pe_ratio == 20.0

View File

@@ -0,0 +1,99 @@
"""Unit tests for fundamental_service — unavailable_fields persistence."""
from __future__ import annotations
import json
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.database import Base
from app.models.ticker import Ticker
from app.services import fundamental_service
# Use a dedicated engine so commit/refresh work without conflicting
# with the conftest transactional session.
_engine = create_async_engine("sqlite+aiosqlite://", echo=False)
_session_factory = async_sessionmaker(_engine, class_=AsyncSession, expire_on_commit=False)
@pytest.fixture(autouse=True)
async def _setup_tables():
async with _engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with _engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def session() -> AsyncSession:
async with _session_factory() as s:
yield s
@pytest.fixture
async def ticker(session: AsyncSession) -> Ticker:
"""Create a test ticker."""
t = Ticker(symbol="AAPL")
session.add(t)
await session.commit()
await session.refresh(t)
return t
@pytest.mark.asyncio
async def test_store_fundamental_persists_unavailable_fields(
session: AsyncSession, ticker: Ticker
):
"""unavailable_fields dict is serialized to JSON and stored."""
fields = {"pe_ratio": "requires paid plan", "revenue_growth": "requires paid plan"}
record = await fundamental_service.store_fundamental(
session,
symbol="AAPL",
pe_ratio=None,
revenue_growth=None,
market_cap=1_000_000.0,
unavailable_fields=fields,
)
assert json.loads(record.unavailable_fields_json) == fields
@pytest.mark.asyncio
async def test_store_fundamental_defaults_to_empty_dict(
session: AsyncSession, ticker: Ticker
):
"""When unavailable_fields is not provided, column defaults to '{}'."""
record = await fundamental_service.store_fundamental(
session,
symbol="AAPL",
pe_ratio=25.0,
)
assert json.loads(record.unavailable_fields_json) == {}
@pytest.mark.asyncio
async def test_store_fundamental_updates_unavailable_fields(
session: AsyncSession, ticker: Ticker
):
"""Updating an existing record also updates unavailable_fields_json."""
# First store
await fundamental_service.store_fundamental(
session,
symbol="AAPL",
pe_ratio=None,
unavailable_fields={"pe_ratio": "requires paid plan"},
)
# Second store — fields now available
record = await fundamental_service.store_fundamental(
session,
symbol="AAPL",
pe_ratio=25.0,
unavailable_fields={},
)
assert json.loads(record.unavailable_fields_json) == {}

View File

@@ -0,0 +1,162 @@
"""Unit tests for OpenAISentimentProvider reasoning and citations extraction."""
from __future__ import annotations
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.providers.openai_sentiment import OpenAISentimentProvider
def _make_annotation(ann_type: str, url: str = "", title: str = "") -> SimpleNamespace:
return SimpleNamespace(type=ann_type, url=url, title=title)
def _make_content_block(text: str, annotations: list | None = None) -> SimpleNamespace:
block = SimpleNamespace(text=text, annotations=annotations or [])
return block
def _make_message_item(content_blocks: list) -> SimpleNamespace:
return SimpleNamespace(type="message", content=content_blocks)
def _make_web_search_item() -> SimpleNamespace:
return SimpleNamespace(type="web_search_call")
def _build_response(json_text: str, annotations: list | None = None) -> SimpleNamespace:
"""Build a mock OpenAI Responses API response object."""
content_block = _make_content_block(json_text, annotations or [])
message_item = _make_message_item([content_block])
items = [message_item]
return SimpleNamespace(output=items)
def _build_response_with_search(
json_text: str, annotations: list | None = None
) -> SimpleNamespace:
"""Build a response with a web_search_call item followed by a message with annotations."""
search_item = _make_web_search_item()
content_block = _make_content_block(json_text, annotations or [])
message_item = _make_message_item([content_block])
return SimpleNamespace(output=[search_item, message_item])
@pytest.fixture
def provider():
"""Create an OpenAISentimentProvider with a mocked client."""
with patch("app.providers.openai_sentiment.AsyncOpenAI"):
p = OpenAISentimentProvider(api_key="test-key")
return p
class TestReasoningExtraction:
"""Tests for extracting reasoning from the parsed JSON response."""
@pytest.mark.asyncio
async def test_reasoning_extracted_from_json(self, provider):
json_text = '{"classification": "bullish", "confidence": 85, "reasoning": "Strong earnings report"}'
mock_response = _build_response(json_text)
provider._client.responses.create = AsyncMock(return_value=mock_response)
result = await provider.fetch_sentiment("AAPL")
assert result.reasoning == "Strong earnings report"
@pytest.mark.asyncio
async def test_empty_reasoning_when_field_missing(self, provider):
json_text = '{"classification": "neutral", "confidence": 50}'
mock_response = _build_response(json_text)
provider._client.responses.create = AsyncMock(return_value=mock_response)
result = await provider.fetch_sentiment("MSFT")
assert result.reasoning == ""
@pytest.mark.asyncio
async def test_empty_reasoning_when_field_is_empty_string(self, provider):
json_text = '{"classification": "bearish", "confidence": 70, "reasoning": ""}'
mock_response = _build_response(json_text)
provider._client.responses.create = AsyncMock(return_value=mock_response)
result = await provider.fetch_sentiment("TSLA")
assert result.reasoning == ""
class TestCitationsExtraction:
"""Tests for extracting url_citation annotations from the response."""
@pytest.mark.asyncio
async def test_citations_extracted_from_annotations(self, provider):
json_text = '{"classification": "bullish", "confidence": 90, "reasoning": "Good news"}'
annotations = [
_make_annotation("url_citation", url="https://example.com/1", title="Article 1"),
_make_annotation("url_citation", url="https://example.com/2", title="Article 2"),
]
mock_response = _build_response(json_text, annotations)
provider._client.responses.create = AsyncMock(return_value=mock_response)
result = await provider.fetch_sentiment("AAPL")
assert len(result.citations) == 2
assert result.citations[0] == {"url": "https://example.com/1", "title": "Article 1"}
assert result.citations[1] == {"url": "https://example.com/2", "title": "Article 2"}
@pytest.mark.asyncio
async def test_empty_citations_when_no_annotations(self, provider):
json_text = '{"classification": "neutral", "confidence": 50, "reasoning": "No news"}'
mock_response = _build_response(json_text)
provider._client.responses.create = AsyncMock(return_value=mock_response)
result = await provider.fetch_sentiment("GOOG")
assert result.citations == []
@pytest.mark.asyncio
async def test_non_url_citation_annotations_ignored(self, provider):
json_text = '{"classification": "bearish", "confidence": 60, "reasoning": "Mixed signals"}'
annotations = [
_make_annotation("file_citation", url="https://file.com", title="File"),
_make_annotation("url_citation", url="https://real.com", title="Real"),
]
mock_response = _build_response(json_text, annotations)
provider._client.responses.create = AsyncMock(return_value=mock_response)
result = await provider.fetch_sentiment("META")
assert len(result.citations) == 1
assert result.citations[0] == {"url": "https://real.com", "title": "Real"}
@pytest.mark.asyncio
async def test_citations_from_response_with_web_search_call(self, provider):
json_text = '{"classification": "bullish", "confidence": 80, "reasoning": "Positive outlook"}'
annotations = [
_make_annotation("url_citation", url="https://news.com/a", title="News A"),
]
mock_response = _build_response_with_search(json_text, annotations)
provider._client.responses.create = AsyncMock(return_value=mock_response)
result = await provider.fetch_sentiment("NVDA")
assert len(result.citations) == 1
assert result.citations[0] == {"url": "https://news.com/a", "title": "News A"}
@pytest.mark.asyncio
async def test_no_error_when_annotations_attr_missing(self, provider):
"""Content blocks without annotations attribute should not cause errors."""
json_text = '{"classification": "neutral", "confidence": 50, "reasoning": "Quiet day"}'
# Block with no annotations attribute at all
block = SimpleNamespace(text=json_text)
message_item = SimpleNamespace(type="message", content=[block])
mock_response = SimpleNamespace(output=[message_item])
provider._client.responses.create = AsyncMock(return_value=mock_response)
result = await provider.fetch_sentiment("AMD")
assert result.citations == []
assert result.reasoning == "Quiet day"

View File

@@ -0,0 +1,273 @@
"""Bug-condition exploration tests for R:R scanner target quality.
These tests confirm the bug described in bugfix.md: the old code always selected
the most distant S/R level (highest raw R:R) regardless of strength or proximity.
The fix replaces max-R:R selection with quality-score selection.
Since the code is already fixed, these tests PASS on the current codebase.
On the unfixed code they would FAIL, confirming the bug.
**Validates: Requirements 1.1, 1.3, 1.4, 2.1, 2.3, 2.4**
"""
from __future__ import annotations
from datetime import date, timedelta
import pytest
from hypothesis import given, settings, HealthCheck, strategies as st
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.ohlcv import OHLCVRecord
from app.models.sr_level import SRLevel
from app.models.ticker import Ticker
from app.services.rr_scanner_service import scan_ticker
# ---------------------------------------------------------------------------
# Session fixture that allows scan_ticker to commit
# ---------------------------------------------------------------------------
# The default db_session fixture wraps in session.begin() which conflicts
# with scan_ticker's internal commit(). We use a plain session instead.
@pytest.fixture
async def scan_session() -> AsyncSession:
"""Provide a DB session compatible with scan_ticker (which commits)."""
from tests.conftest import _test_session_factory
async with _test_session_factory() as session:
yield session
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ohlcv_bars(
ticker_id: int,
num_bars: int = 20,
base_close: float = 100.0,
) -> list[OHLCVRecord]:
"""Generate realistic OHLCV bars with small daily variation.
Produces bars where close ≈ base_close, with enough range for ATR
computation (needs >= 15 bars). The ATR will be roughly 2.0.
"""
bars: list[OHLCVRecord] = []
start = date(2024, 1, 1)
for i in range(num_bars):
close = base_close + (i % 3 - 1) * 0.5 # oscillate ±0.5
bars.append(OHLCVRecord(
ticker_id=ticker_id,
date=start + timedelta(days=i),
open=close - 0.3,
high=close + 1.0,
low=close - 1.0,
close=close,
volume=100_000,
))
return bars
# ---------------------------------------------------------------------------
# Deterministic test: strong-near vs weak-far (long setup)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_long_prefers_strong_near_over_weak_far(scan_session: AsyncSession):
"""With a strong nearby resistance and a weak distant resistance, the
scanner should pick the strong nearby one — NOT the most distant.
On unfixed code this would fail because max-R:R always picks the
farthest level.
"""
ticker = Ticker(symbol="EXPLR")
scan_session.add(ticker)
await scan_session.flush()
# 20 bars closing around 100
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
scan_session.add_all(bars)
# With ATR=2.0 and multiplier=1.5, risk=3.0.
# R:R threshold=1.5 → min reward=4.5 → min target=104.5
# Strong nearby resistance: price=105, strength=90 (R:R≈1.67, quality≈0.66)
near_level = SRLevel(
ticker_id=ticker.id,
price_level=105.0,
type="resistance",
strength=90,
detection_method="volume_profile",
)
# Weak distant resistance: price=130, strength=5 (R:R=10, quality≈0.58)
far_level = SRLevel(
ticker_id=ticker.id,
price_level=130.0,
type="resistance",
strength=5,
detection_method="volume_profile",
)
scan_session.add_all([near_level, far_level])
await scan_session.flush()
setups = await scan_ticker(scan_session, "EXPLR", rr_threshold=1.5)
long_setups = [s for s in setups if s.direction == "long"]
assert len(long_setups) == 1, "Expected exactly one long setup"
selected_target = long_setups[0].target
# The scanner must NOT pick the most distant level (130)
assert selected_target != pytest.approx(130.0, abs=0.01), (
"Bug: scanner picked the weak distant level (130) instead of the "
"strong nearby level (105)"
)
# It should pick the strong nearby level
assert selected_target == pytest.approx(105.0, abs=0.01)
# ---------------------------------------------------------------------------
# Deterministic test: strong-near vs weak-far (short setup)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_short_prefers_strong_near_over_weak_far(scan_session: AsyncSession):
"""Short-side mirror: strong nearby support should be preferred over
weak distant support.
"""
ticker = Ticker(symbol="EXPLS")
scan_session.add(ticker)
await scan_session.flush()
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
scan_session.add_all(bars)
# With ATR=2.0 and multiplier=1.5, risk=3.0.
# R:R threshold=1.5 → min reward=4.5 → min target below 95.5
# Strong nearby support: price=95, strength=85 (R:R≈1.67, quality≈0.64)
near_level = SRLevel(
ticker_id=ticker.id,
price_level=95.0,
type="support",
strength=85,
detection_method="pivot_point",
)
# Weak distant support: price=70, strength=5 (R:R=10, quality≈0.58)
far_level = SRLevel(
ticker_id=ticker.id,
price_level=70.0,
type="support",
strength=5,
detection_method="pivot_point",
)
scan_session.add_all([near_level, far_level])
await scan_session.flush()
setups = await scan_ticker(scan_session, "EXPLS", rr_threshold=1.5)
short_setups = [s for s in setups if s.direction == "short"]
assert len(short_setups) == 1, "Expected exactly one short setup"
selected_target = short_setups[0].target
assert selected_target != pytest.approx(70.0, abs=0.01), (
"Bug: scanner picked the weak distant level (70) instead of the "
"strong nearby level (95)"
)
assert selected_target == pytest.approx(95.0, abs=0.01)
# ---------------------------------------------------------------------------
# Hypothesis property test: selection is NOT always the most distant level
# ---------------------------------------------------------------------------
@st.composite
def strong_near_weak_far_pair(draw: st.DrawFn) -> dict:
"""Generate a (strong-near, weak-far) resistance pair above entry=100.
Guarantees:
- near_price < far_price (both above entry)
- near_strength >> far_strength
- Both meet the R:R threshold of 1.5 given typical ATR ≈ 2 → risk ≈ 3
"""
# Near level: 515 above entry (R:R ≈ 1.75.0 with risk≈3)
near_dist = draw(st.floats(min_value=5.0, max_value=15.0))
near_strength = draw(st.integers(min_value=70, max_value=100))
# Far level: 2560 above entry (R:R ≈ 8.320 with risk≈3)
far_dist = draw(st.floats(min_value=25.0, max_value=60.0))
far_strength = draw(st.integers(min_value=1, max_value=15))
return {
"near_price": 100.0 + near_dist,
"near_strength": near_strength,
"far_price": 100.0 + far_dist,
"far_strength": far_strength,
}
@pytest.mark.asyncio
@given(pair=strong_near_weak_far_pair())
@settings(
max_examples=15,
deadline=None,
suppress_health_check=[HealthCheck.function_scoped_fixture],
)
async def test_property_scanner_does_not_always_pick_most_distant(
pair: dict,
scan_session: AsyncSession,
):
"""**Validates: Requirements 1.1, 1.3, 1.4, 2.1, 2.3, 2.4**
Property: when a strong nearby resistance exists alongside a weak distant
resistance, the scanner does NOT always select the most distant level.
On unfixed code this would fail for every example because max-R:R always
picks the farthest level.
"""
from tests.conftest import _test_engine, _test_session_factory
# Each hypothesis example needs a fresh DB state
async with _test_engine.begin() as conn:
from app.database import Base
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
async with _test_session_factory() as session:
ticker = Ticker(symbol="PROP")
session.add(ticker)
await session.flush()
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
session.add_all(bars)
near_level = SRLevel(
ticker_id=ticker.id,
price_level=pair["near_price"],
type="resistance",
strength=pair["near_strength"],
detection_method="volume_profile",
)
far_level = SRLevel(
ticker_id=ticker.id,
price_level=pair["far_price"],
type="resistance",
strength=pair["far_strength"],
detection_method="volume_profile",
)
session.add_all([near_level, far_level])
await session.commit()
setups = await scan_ticker(session, "PROP", rr_threshold=1.5)
long_setups = [s for s in setups if s.direction == "long"]
assert len(long_setups) == 1, "Expected exactly one long setup"
selected_target = long_setups[0].target
most_distant = round(pair["far_price"], 4)
# The fixed scanner should prefer the strong nearby level, not the
# most distant weak one.
assert selected_target != pytest.approx(most_distant, abs=0.01), (
f"Bug: scanner picked the most distant level ({most_distant}) "
f"with strength={pair['far_strength']} over the nearby level "
f"({round(pair['near_price'], 4)}) with strength={pair['near_strength']}"
)

View File

@@ -0,0 +1,383 @@
"""Fix-checking tests for R:R scanner quality-score selection.
Verify that the fixed scan_ticker selects the candidate with the highest
quality score among all candidates meeting the R:R threshold, for both
long and short setups.
**Validates: Requirements 2.1, 2.2, 2.3, 2.4**
"""
from __future__ import annotations
from datetime import date, timedelta
import pytest
from hypothesis import given, settings, HealthCheck, strategies as st
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.ohlcv import OHLCVRecord
from app.models.sr_level import SRLevel
from app.models.ticker import Ticker
from app.services.rr_scanner_service import scan_ticker, _compute_quality_score
# ---------------------------------------------------------------------------
# Session fixture (plain session, not wrapped in begin())
# ---------------------------------------------------------------------------
@pytest.fixture
async def scan_session() -> AsyncSession:
"""Provide a DB session compatible with scan_ticker (which commits)."""
from tests.conftest import _test_session_factory
async with _test_session_factory() as session:
yield session
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ohlcv_bars(
ticker_id: int,
num_bars: int = 20,
base_close: float = 100.0,
) -> list[OHLCVRecord]:
"""Generate OHLCV bars closing around base_close with ATR ≈ 2.0."""
bars: list[OHLCVRecord] = []
start = date(2024, 1, 1)
for i in range(num_bars):
close = base_close + (i % 3 - 1) * 0.5 # oscillate ±0.5
bars.append(OHLCVRecord(
ticker_id=ticker_id,
date=start + timedelta(days=i),
open=close - 0.3,
high=close + 1.0,
low=close - 1.0,
close=close,
volume=100_000,
))
return bars
# ---------------------------------------------------------------------------
# Hypothesis strategy: multiple resistance levels above entry for longs
# ---------------------------------------------------------------------------
@st.composite
def long_candidate_levels(draw: st.DrawFn) -> list[dict]:
"""Generate 2-5 resistance levels above entry_price=100.
All levels meet the R:R threshold of 1.5 given ATR≈2, risk≈3,
so min reward=4.5, min target=104.5.
"""
num_levels = draw(st.integers(min_value=2, max_value=5))
levels = []
for _ in range(num_levels):
# Distance from entry: 5 to 50 (all above 4.5 threshold)
distance = draw(st.floats(min_value=5.0, max_value=50.0))
strength = draw(st.integers(min_value=0, max_value=100))
levels.append({
"price": 100.0 + distance,
"strength": strength,
})
return levels
@st.composite
def short_candidate_levels(draw: st.DrawFn) -> list[dict]:
"""Generate 2-5 support levels below entry_price=100.
All levels meet the R:R threshold of 1.5 given ATR≈2, risk≈3,
so min reward=4.5, max target=95.5.
"""
num_levels = draw(st.integers(min_value=2, max_value=5))
levels = []
for _ in range(num_levels):
# Distance below entry: 5 to 50 (all above 4.5 threshold)
distance = draw(st.floats(min_value=5.0, max_value=50.0))
strength = draw(st.integers(min_value=0, max_value=100))
levels.append({
"price": 100.0 - distance,
"strength": strength,
})
return levels
# ---------------------------------------------------------------------------
# Property test: long setup selects highest quality score candidate
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@given(levels=long_candidate_levels())
@settings(
max_examples=20,
deadline=None,
suppress_health_check=[HealthCheck.function_scoped_fixture],
)
async def test_property_long_selects_highest_quality(
levels: list[dict],
scan_session: AsyncSession,
):
"""**Validates: Requirements 2.1, 2.3, 2.4**
Property: when multiple resistance levels meet the R:R threshold,
the fixed scan_ticker selects the one with the highest quality score.
"""
from tests.conftest import _test_engine, _test_session_factory
from app.database import Base
# Fresh DB state per hypothesis example
async with _test_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
async with _test_session_factory() as session:
ticker = Ticker(symbol="FIXL")
session.add(ticker)
await session.flush()
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
session.add_all(bars)
sr_levels = []
for lv in levels:
sr_levels.append(SRLevel(
ticker_id=ticker.id,
price_level=lv["price"],
type="resistance",
strength=lv["strength"],
detection_method="volume_profile",
))
session.add_all(sr_levels)
await session.commit()
setups = await scan_ticker(session, "FIXL", rr_threshold=1.5)
long_setups = [s for s in setups if s.direction == "long"]
assert len(long_setups) == 1, "Expected exactly one long setup"
selected_target = long_setups[0].target
# Compute entry_price and risk from the bars (same logic as scan_ticker)
# entry_price = last close ≈ 100.0, ATR ≈ 2.0, risk = ATR * 1.5 = 3.0
entry_price = bars[-1].close
# Use approximate risk; the exact value comes from ATR computation
# We reconstruct it from the setup's entry and stop
risk = long_setups[0].entry_price - long_setups[0].stop_loss
# Compute quality scores for all candidates that meet threshold
best_quality = -1.0
best_target = None
for lv in levels:
distance = lv["price"] - entry_price
if distance > 0:
rr = distance / risk
if rr >= 1.5:
quality = _compute_quality_score(rr, lv["strength"], distance, entry_price)
if quality > best_quality:
best_quality = quality
best_target = round(lv["price"], 4)
assert best_target is not None, "At least one candidate should meet threshold"
assert selected_target == pytest.approx(best_target, abs=0.01), (
f"Selected target {selected_target} != expected best-quality target "
f"{best_target} (quality={best_quality:.4f})"
)
# ---------------------------------------------------------------------------
# Property test: short setup selects highest quality score candidate
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@given(levels=short_candidate_levels())
@settings(
max_examples=20,
deadline=None,
suppress_health_check=[HealthCheck.function_scoped_fixture],
)
async def test_property_short_selects_highest_quality(
levels: list[dict],
scan_session: AsyncSession,
):
"""**Validates: Requirements 2.2, 2.3, 2.4**
Property: when multiple support levels meet the R:R threshold,
the fixed scan_ticker selects the one with the highest quality score.
"""
from tests.conftest import _test_engine, _test_session_factory
from app.database import Base
# Fresh DB state per hypothesis example
async with _test_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
async with _test_session_factory() as session:
ticker = Ticker(symbol="FIXS")
session.add(ticker)
await session.flush()
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
session.add_all(bars)
sr_levels = []
for lv in levels:
sr_levels.append(SRLevel(
ticker_id=ticker.id,
price_level=lv["price"],
type="support",
strength=lv["strength"],
detection_method="pivot_point",
))
session.add_all(sr_levels)
await session.commit()
setups = await scan_ticker(session, "FIXS", rr_threshold=1.5)
short_setups = [s for s in setups if s.direction == "short"]
assert len(short_setups) == 1, "Expected exactly one short setup"
selected_target = short_setups[0].target
entry_price = bars[-1].close
risk = short_setups[0].stop_loss - short_setups[0].entry_price
# Compute quality scores for all candidates that meet threshold
best_quality = -1.0
best_target = None
for lv in levels:
distance = entry_price - lv["price"]
if distance > 0:
rr = distance / risk
if rr >= 1.5:
quality = _compute_quality_score(rr, lv["strength"], distance, entry_price)
if quality > best_quality:
best_quality = quality
best_target = round(lv["price"], 4)
assert best_target is not None, "At least one candidate should meet threshold"
assert selected_target == pytest.approx(best_target, abs=0.01), (
f"Selected target {selected_target} != expected best-quality target "
f"{best_target} (quality={best_quality:.4f})"
)
# ---------------------------------------------------------------------------
# Deterministic test: 3 levels with known quality scores (long)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_deterministic_long_three_levels(scan_session: AsyncSession):
"""**Validates: Requirements 2.1, 2.3, 2.4**
Concrete example with 3 resistance levels of known quality scores.
Entry=100, ATR≈2, risk≈3.
Level A: price=105, strength=90 → rr=5/3≈1.67, dist=5
quality = 0.35*(1.67/10) + 0.35*(90/100) + 0.30*(1-5/100)
= 0.35*0.167 + 0.35*0.9 + 0.30*0.95
= 0.0585 + 0.315 + 0.285 = 0.6585
Level B: price=112, strength=50 → rr=12/3=4.0, dist=12
quality = 0.35*(4/10) + 0.35*(50/100) + 0.30*(1-12/100)
= 0.35*0.4 + 0.35*0.5 + 0.30*0.88
= 0.14 + 0.175 + 0.264 = 0.579
Level C: price=130, strength=10 → rr=30/3=10.0, dist=30
quality = 0.35*(10/10) + 0.35*(10/100) + 0.30*(1-30/100)
= 0.35*1.0 + 0.35*0.1 + 0.30*0.7
= 0.35 + 0.035 + 0.21 = 0.595
Expected winner: Level A (quality=0.6585)
"""
ticker = Ticker(symbol="DET3L")
scan_session.add(ticker)
await scan_session.flush()
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
scan_session.add_all(bars)
level_a = SRLevel(
ticker_id=ticker.id, price_level=105.0, type="resistance",
strength=90, detection_method="volume_profile",
)
level_b = SRLevel(
ticker_id=ticker.id, price_level=112.0, type="resistance",
strength=50, detection_method="volume_profile",
)
level_c = SRLevel(
ticker_id=ticker.id, price_level=130.0, type="resistance",
strength=10, detection_method="volume_profile",
)
scan_session.add_all([level_a, level_b, level_c])
await scan_session.flush()
setups = await scan_ticker(scan_session, "DET3L", rr_threshold=1.5)
long_setups = [s for s in setups if s.direction == "long"]
assert len(long_setups) == 1, "Expected exactly one long setup"
# Level A (105, strength=90) should win with highest quality
assert long_setups[0].target == pytest.approx(105.0, abs=0.01), (
f"Expected target=105.0 (highest quality), got {long_setups[0].target}"
)
# ---------------------------------------------------------------------------
# Deterministic test: 3 levels with known quality scores (short)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_deterministic_short_three_levels(scan_session: AsyncSession):
"""**Validates: Requirements 2.2, 2.3, 2.4**
Concrete example with 3 support levels of known quality scores.
Entry=100, ATR≈2, risk≈3.
Level A: price=95, strength=85 → rr=5/3≈1.67, dist=5
quality = 0.35*(1.67/10) + 0.35*(85/100) + 0.30*(1-5/100)
= 0.0585 + 0.2975 + 0.285 = 0.641
Level B: price=88, strength=45 → rr=12/3=4.0, dist=12
quality = 0.35*(4/10) + 0.35*(45/100) + 0.30*(1-12/100)
= 0.14 + 0.1575 + 0.264 = 0.5615
Level C: price=70, strength=8 → rr=30/3=10.0, dist=30
quality = 0.35*(10/10) + 0.35*(8/100) + 0.30*(1-30/100)
= 0.35 + 0.028 + 0.21 = 0.588
Expected winner: Level A (quality=0.641)
"""
ticker = Ticker(symbol="DET3S")
scan_session.add(ticker)
await scan_session.flush()
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
scan_session.add_all(bars)
level_a = SRLevel(
ticker_id=ticker.id, price_level=95.0, type="support",
strength=85, detection_method="pivot_point",
)
level_b = SRLevel(
ticker_id=ticker.id, price_level=88.0, type="support",
strength=45, detection_method="pivot_point",
)
level_c = SRLevel(
ticker_id=ticker.id, price_level=70.0, type="support",
strength=8, detection_method="pivot_point",
)
scan_session.add_all([level_a, level_b, level_c])
await scan_session.flush()
setups = await scan_ticker(scan_session, "DET3S", rr_threshold=1.5)
short_setups = [s for s in setups if s.direction == "short"]
assert len(short_setups) == 1, "Expected exactly one short setup"
# Level A (95, strength=85) should win with highest quality
assert short_setups[0].target == pytest.approx(95.0, abs=0.01), (
f"Expected target=95.0 (highest quality), got {short_setups[0].target}"
)

View File

@@ -0,0 +1,259 @@
"""Integration tests for R:R scanner full flow with quality-based target selection.
Verifies the complete scan_ticker pipeline: quality-based S/R level selection,
correct TradeSetup field population, and database persistence.
**Validates: Requirements 2.1, 2.2, 2.3, 2.4, 3.4**
"""
from __future__ import annotations
from datetime import date, datetime, timedelta, timezone
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.ohlcv import OHLCVRecord
from app.models.score import CompositeScore
from app.models.sr_level import SRLevel
from app.models.ticker import Ticker
from app.models.trade_setup import TradeSetup
from app.services.rr_scanner_service import scan_ticker, _compute_quality_score
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
async def scan_session() -> AsyncSession:
"""Provide a DB session compatible with scan_ticker (which commits)."""
from tests.conftest import _test_session_factory
async with _test_session_factory() as session:
yield session
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ohlcv_bars(
ticker_id: int,
num_bars: int = 20,
base_close: float = 100.0,
) -> list[OHLCVRecord]:
"""Generate OHLCV bars closing around base_close with ATR ≈ 2.0."""
bars: list[OHLCVRecord] = []
start = date(2024, 1, 1)
for i in range(num_bars):
close = base_close + (i % 3 - 1) * 0.5 # oscillate ±0.5
bars.append(OHLCVRecord(
ticker_id=ticker_id,
date=start + timedelta(days=i),
open=close - 0.3,
high=close + 1.0,
low=close - 1.0,
close=close,
volume=100_000,
))
return bars
# ===========================================================================
# 8.1 Integration test: full scan_ticker flow with quality-based selection,
# correct TradeSetup fields, and database persistence
# ===========================================================================
@pytest.mark.asyncio
async def test_scan_ticker_full_flow_quality_selection_and_persistence(
scan_session: AsyncSession,
):
"""Integration test for the complete scan_ticker pipeline.
Scenario:
- Entry ≈ 100, ATR ≈ 2.0, risk ≈ 3.0 (atr_multiplier=1.5)
- 3 resistance levels above (long candidates):
A: price=105, strength=90 (strong, near) → highest quality
B: price=115, strength=40 (medium, mid)
C: price=135, strength=5 (weak, far)
- 3 support levels below (short candidates):
D: price=95, strength=85 (strong, near) → highest quality
E: price=85, strength=35 (medium, mid)
F: price=65, strength=8 (weak, far)
- CompositeScore: 72.5
Verifies:
1. Both long and short setups are produced
2. Long target = Level A (highest quality, not most distant)
3. Short target = Level D (highest quality, not most distant)
4. All TradeSetup fields are correct and rounded to 4 decimals
5. rr_ratio is the actual R:R of the selected level
6. Old setups are deleted, new ones persisted
"""
# -- Setup: create ticker --
ticker = Ticker(symbol="INTEG")
scan_session.add(ticker)
await scan_session.flush()
# -- Setup: OHLCV bars (20 bars, close ≈ 100, ATR ≈ 2.0) --
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
scan_session.add_all(bars)
# -- Setup: S/R levels --
sr_levels = [
# Long candidates (resistance above entry)
SRLevel(ticker_id=ticker.id, price_level=105.0, type="resistance",
strength=90, detection_method="volume_profile"),
SRLevel(ticker_id=ticker.id, price_level=115.0, type="resistance",
strength=40, detection_method="volume_profile"),
SRLevel(ticker_id=ticker.id, price_level=135.0, type="resistance",
strength=5, detection_method="pivot_point"),
# Short candidates (support below entry)
SRLevel(ticker_id=ticker.id, price_level=95.0, type="support",
strength=85, detection_method="volume_profile"),
SRLevel(ticker_id=ticker.id, price_level=85.0, type="support",
strength=35, detection_method="pivot_point"),
SRLevel(ticker_id=ticker.id, price_level=65.0, type="support",
strength=8, detection_method="volume_profile"),
]
scan_session.add_all(sr_levels)
# -- Setup: CompositeScore --
comp = CompositeScore(
ticker_id=ticker.id,
score=72.5,
is_stale=False,
weights_json="{}",
computed_at=datetime.now(timezone.utc),
)
scan_session.add(comp)
# -- Setup: dummy old setups that should be deleted --
old_setup = TradeSetup(
ticker_id=ticker.id,
direction="long",
entry_price=99.0,
stop_loss=96.0,
target=120.0,
rr_ratio=7.0,
composite_score=50.0,
detected_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
)
scan_session.add(old_setup)
await scan_session.commit()
# Verify old setup exists before scan
pre_result = await scan_session.execute(
select(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
)
pre_setups = list(pre_result.scalars().all())
assert len(pre_setups) == 1, "Dummy old setup should exist before scan"
# -- Act: run scan_ticker --
setups = await scan_ticker(scan_session, "INTEG", rr_threshold=1.5, atr_multiplier=1.5)
# -- Assert: both directions produced --
assert len(setups) == 2, f"Expected 2 setups (long + short), got {len(setups)}"
long_setups = [s for s in setups if s.direction == "long"]
short_setups = [s for s in setups if s.direction == "short"]
assert len(long_setups) == 1, f"Expected 1 long setup, got {len(long_setups)}"
assert len(short_setups) == 1, f"Expected 1 short setup, got {len(short_setups)}"
long_setup = long_setups[0]
short_setup = short_setups[0]
# -- Assert: long target is Level A (highest quality, not most distant) --
# Level A: price=105 (strong, near) should beat Level C: price=135 (weak, far)
assert long_setup.target == pytest.approx(105.0, abs=0.01), (
f"Long target should be 105.0 (highest quality), got {long_setup.target}"
)
# -- Assert: short target is Level D (highest quality, not most distant) --
# Level D: price=95 (strong, near) should beat Level F: price=65 (weak, far)
assert short_setup.target == pytest.approx(95.0, abs=0.01), (
f"Short target should be 95.0 (highest quality), got {short_setup.target}"
)
# -- Assert: entry_price is the last close (≈ 100) --
# Last bar: index 19, close = 100 + (19 % 3 - 1) * 0.5 = 100 + 0*0.5 = 100.0
expected_entry = 100.0
assert long_setup.entry_price == pytest.approx(expected_entry, abs=0.5)
assert short_setup.entry_price == pytest.approx(expected_entry, abs=0.5)
entry = long_setup.entry_price # actual entry for R:R calculations
# -- Assert: stop_loss values --
# ATR ≈ 2.0, risk = ATR × 1.5 = 3.0
# Long stop = entry - risk, Short stop = entry + risk
risk = long_setup.entry_price - long_setup.stop_loss
assert risk > 0, "Long risk must be positive"
assert short_setup.stop_loss > short_setup.entry_price, "Short stop must be above entry"
# -- Assert: rr_ratio is the actual R:R of the selected level --
long_reward = long_setup.target - long_setup.entry_price
long_expected_rr = round(long_reward / risk, 4)
assert long_setup.rr_ratio == pytest.approx(long_expected_rr, abs=0.01), (
f"Long rr_ratio should be actual R:R={long_expected_rr}, got {long_setup.rr_ratio}"
)
short_risk = short_setup.stop_loss - short_setup.entry_price
short_reward = short_setup.entry_price - short_setup.target
short_expected_rr = round(short_reward / short_risk, 4)
assert short_setup.rr_ratio == pytest.approx(short_expected_rr, abs=0.01), (
f"Short rr_ratio should be actual R:R={short_expected_rr}, got {short_setup.rr_ratio}"
)
# -- Assert: composite_score matches --
assert long_setup.composite_score == pytest.approx(72.5, abs=0.01)
assert short_setup.composite_score == pytest.approx(72.5, abs=0.01)
# -- Assert: ticker_id is correct --
assert long_setup.ticker_id == ticker.id
assert short_setup.ticker_id == ticker.id
# -- Assert: detected_at is set --
assert long_setup.detected_at is not None
assert short_setup.detected_at is not None
# -- Assert: fields are rounded to 4 decimal places --
for setup in [long_setup, short_setup]:
for field_name in ("entry_price", "stop_loss", "target", "rr_ratio", "composite_score"):
val = getattr(setup, field_name)
rounded = round(val, 4)
assert val == pytest.approx(rounded, abs=1e-6), (
f"{setup.direction} {field_name}={val} not rounded to 4 decimals"
)
# -- Assert: database persistence --
# Old dummy setup should be gone, only the 2 new setups should exist
db_result = await scan_session.execute(
select(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
)
persisted = list(db_result.scalars().all())
assert len(persisted) == 2, (
f"Expected 2 persisted setups (old deleted), got {len(persisted)}"
)
persisted_directions = sorted(s.direction for s in persisted)
assert persisted_directions == ["long", "short"], (
f"Expected ['long', 'short'] persisted, got {persisted_directions}"
)
# Verify persisted records match returned setups
persisted_long = [s for s in persisted if s.direction == "long"][0]
persisted_short = [s for s in persisted if s.direction == "short"][0]
assert persisted_long.target == long_setup.target
assert persisted_long.rr_ratio == long_setup.rr_ratio
assert persisted_long.entry_price == long_setup.entry_price
assert persisted_long.stop_loss == long_setup.stop_loss
assert persisted_long.composite_score == long_setup.composite_score
assert persisted_short.target == short_setup.target
assert persisted_short.rr_ratio == short_setup.rr_ratio
assert persisted_short.entry_price == short_setup.entry_price
assert persisted_short.stop_loss == short_setup.stop_loss
assert persisted_short.composite_score == short_setup.composite_score

View File

@@ -0,0 +1,433 @@
"""Preservation tests for R:R scanner target quality bugfix.
Verify that the fix does NOT change behavior for zero-candidate and
single-candidate scenarios, and that get_trade_setups sorting is unchanged.
The fix only changes selection logic when MULTIPLE candidates exist.
Zero-candidate and single-candidate scenarios must produce identical results.
**Validates: Requirements 3.1, 3.2, 3.3, 3.5**
"""
from __future__ import annotations
from datetime import date, datetime, timedelta, timezone
import pytest
from hypothesis import given, settings, HealthCheck, strategies as st
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.ohlcv import OHLCVRecord
from app.models.sr_level import SRLevel
from app.models.ticker import Ticker
from app.models.trade_setup import TradeSetup
from app.models.score import CompositeScore
from app.services.rr_scanner_service import scan_ticker, get_trade_setups
# ---------------------------------------------------------------------------
# Session fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
async def scan_session() -> AsyncSession:
"""Provide a DB session compatible with scan_ticker (which commits)."""
from tests.conftest import _test_session_factory
async with _test_session_factory() as session:
yield session
@pytest.fixture
async def db_session() -> AsyncSession:
"""Provide a transactional DB session for get_trade_setups tests."""
from tests.conftest import _test_session_factory
async with _test_session_factory() as session:
async with session.begin():
yield session
await session.rollback()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ohlcv_bars(
ticker_id: int,
num_bars: int = 20,
base_close: float = 100.0,
) -> list[OHLCVRecord]:
"""Generate OHLCV bars closing around base_close with ATR ≈ 2.0."""
bars: list[OHLCVRecord] = []
start = date(2024, 1, 1)
for i in range(num_bars):
close = base_close + (i % 3 - 1) * 0.5 # oscillate ±0.5
bars.append(OHLCVRecord(
ticker_id=ticker_id,
date=start + timedelta(days=i),
open=close - 0.3,
high=close + 1.0,
low=close - 1.0,
close=close,
volume=100_000,
))
return bars
# ===========================================================================
# 7.1 [PBT-preservation] Property test: zero-candidate and single-candidate
# scenarios produce the same output as the original code.
# ===========================================================================
@st.composite
def zero_candidate_scenario(draw: st.DrawFn) -> dict:
"""Generate a scenario where no S/R levels qualify as candidates.
Variants:
- No SR levels at all
- All levels below entry (no long targets) and all above entry (no short targets)
but all below the R:R threshold for their respective directions
- Levels in the right direction but below R:R threshold
Note: scan_ticker does NOT filter by SR level type — it only checks whether
the price_level is above or below entry. So "wrong side" means all levels
are clustered near entry and below threshold in both directions.
"""
variant = draw(st.sampled_from(["no_levels", "below_threshold"]))
if variant == "no_levels":
return {"variant": variant, "levels": []}
else: # below_threshold
# All levels close to entry so R:R < 1.5 with risk ≈ 3
# For longs: reward < 4.5 → price < 104.5
# For shorts: reward < 4.5 → price > 95.5
# Place all levels in the 96104 band (below threshold both ways)
num = draw(st.integers(min_value=1, max_value=3))
levels = []
for _ in range(num):
price = draw(st.floats(min_value=100.5, max_value=103.5))
levels.append({
"price": price,
"type": draw(st.sampled_from(["resistance", "support"])),
"strength": draw(st.integers(min_value=10, max_value=100)),
})
# Also add some below entry but still below threshold
for _ in range(draw(st.integers(min_value=0, max_value=2))):
price = draw(st.floats(min_value=96.5, max_value=99.5))
levels.append({
"price": price,
"type": draw(st.sampled_from(["resistance", "support"])),
"strength": draw(st.integers(min_value=10, max_value=100)),
})
return {"variant": variant, "levels": levels}
@st.composite
def single_candidate_scenario(draw: st.DrawFn) -> dict:
"""Generate a scenario with exactly one S/R level that meets the R:R threshold.
For longs: one resistance above entry with R:R >= 1.5 (price >= 104.5 with risk ≈ 3).
"""
direction = draw(st.sampled_from(["long", "short"]))
if direction == "long":
# Single resistance above entry meeting threshold
price = draw(st.floats(min_value=105.0, max_value=150.0))
strength = draw(st.integers(min_value=1, max_value=100))
return {
"direction": direction,
"level": {"price": price, "type": "resistance", "strength": strength},
}
else:
# Single support below entry meeting threshold
price = draw(st.floats(min_value=50.0, max_value=95.0))
strength = draw(st.integers(min_value=1, max_value=100))
return {
"direction": direction,
"level": {"price": price, "type": "support", "strength": strength},
}
@pytest.mark.asyncio
@given(scenario=zero_candidate_scenario())
@settings(
max_examples=15,
deadline=None,
suppress_health_check=[HealthCheck.function_scoped_fixture],
)
async def test_property_zero_candidates_produce_no_setup(
scenario: dict,
scan_session: AsyncSession,
):
"""**Validates: Requirements 3.1, 3.2**
Property: when zero candidate S/R levels exist (no levels, wrong side,
or below threshold), scan_ticker produces no setup — unchanged from
original behavior.
"""
from tests.conftest import _test_engine, _test_session_factory
from app.database import Base
async with _test_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
async with _test_session_factory() as session:
ticker = Ticker(symbol="PRSV0")
session.add(ticker)
await session.flush()
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
session.add_all(bars)
for lv_data in scenario.get("levels", []):
session.add(SRLevel(
ticker_id=ticker.id,
price_level=lv_data["price"],
type=lv_data["type"],
strength=lv_data["strength"],
detection_method="volume_profile",
))
await session.commit()
setups = await scan_ticker(session, "PRSV0", rr_threshold=1.5)
assert setups == [], (
f"Expected no setups for zero-candidate scenario "
f"(variant={scenario.get('variant', 'unknown')}), got {len(setups)}"
)
@pytest.mark.asyncio
@given(scenario=single_candidate_scenario())
@settings(
max_examples=15,
deadline=None,
suppress_health_check=[HealthCheck.function_scoped_fixture],
)
async def test_property_single_candidate_selected_unchanged(
scenario: dict,
scan_session: AsyncSession,
):
"""**Validates: Requirements 3.3**
Property: when exactly one candidate S/R level meets the R:R threshold,
scan_ticker selects it — same as the original code would.
"""
from tests.conftest import _test_engine, _test_session_factory
from app.database import Base
async with _test_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
async with _test_session_factory() as session:
ticker = Ticker(symbol="PRSV1")
session.add(ticker)
await session.flush()
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
session.add_all(bars)
lv = scenario["level"]
session.add(SRLevel(
ticker_id=ticker.id,
price_level=lv["price"],
type=lv["type"],
strength=lv["strength"],
detection_method="volume_profile",
))
await session.commit()
setups = await scan_ticker(session, "PRSV1", rr_threshold=1.5)
direction = scenario["direction"]
dir_setups = [s for s in setups if s.direction == direction]
assert len(dir_setups) == 1, (
f"Expected exactly one {direction} setup for single candidate, "
f"got {len(dir_setups)}"
)
selected_target = dir_setups[0].target
expected_target = round(lv["price"], 4)
assert selected_target == pytest.approx(expected_target, abs=0.01), (
f"Single candidate: expected target={expected_target}, "
f"got {selected_target}"
)
# ===========================================================================
# 7.2 Unit test: no S/R levels → no setup produced
# ===========================================================================
@pytest.mark.asyncio
async def test_no_sr_levels_produces_no_setup(scan_session: AsyncSession):
"""**Validates: Requirements 3.1**
When a ticker has OHLCV data but no S/R levels at all,
scan_ticker should return an empty list.
"""
ticker = Ticker(symbol="NOSRL")
scan_session.add(ticker)
await scan_session.flush()
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
scan_session.add_all(bars)
await scan_session.flush()
setups = await scan_ticker(scan_session, "NOSRL", rr_threshold=1.5)
assert setups == [], (
f"Expected no setups when no SR levels exist, got {len(setups)}"
)
# ===========================================================================
# 7.3 Unit test: single candidate meets threshold → selected
# ===========================================================================
@pytest.mark.asyncio
async def test_single_resistance_above_threshold_selected(scan_session: AsyncSession):
"""**Validates: Requirements 3.3**
When exactly one resistance level above entry meets the R:R threshold,
it should be selected as the long setup target.
Entry ≈ 100, ATR ≈ 2, risk ≈ 3. Resistance at 110 → R:R ≈ 3.33 (>= 1.5).
"""
ticker = Ticker(symbol="SINGL")
scan_session.add(ticker)
await scan_session.flush()
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
scan_session.add_all(bars)
level = SRLevel(
ticker_id=ticker.id,
price_level=110.0,
type="resistance",
strength=60,
detection_method="volume_profile",
)
scan_session.add(level)
await scan_session.flush()
setups = await scan_ticker(scan_session, "SINGL", rr_threshold=1.5)
long_setups = [s for s in setups if s.direction == "long"]
assert len(long_setups) == 1, (
f"Expected exactly one long setup, got {len(long_setups)}"
)
assert long_setups[0].target == pytest.approx(110.0, abs=0.01), (
f"Expected target=110.0, got {long_setups[0].target}"
)
@pytest.mark.asyncio
async def test_single_support_below_threshold_selected(scan_session: AsyncSession):
"""**Validates: Requirements 3.3**
When exactly one support level below entry meets the R:R threshold,
it should be selected as the short setup target.
Entry ≈ 100, ATR ≈ 2, risk ≈ 3. Support at 90 → R:R ≈ 3.33 (>= 1.5).
"""
ticker = Ticker(symbol="SINGS")
scan_session.add(ticker)
await scan_session.flush()
bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
scan_session.add_all(bars)
level = SRLevel(
ticker_id=ticker.id,
price_level=90.0,
type="support",
strength=55,
detection_method="pivot_point",
)
scan_session.add(level)
await scan_session.flush()
setups = await scan_ticker(scan_session, "SINGS", rr_threshold=1.5)
short_setups = [s for s in setups if s.direction == "short"]
assert len(short_setups) == 1, (
f"Expected exactly one short setup, got {len(short_setups)}"
)
assert short_setups[0].target == pytest.approx(90.0, abs=0.01), (
f"Expected target=90.0, got {short_setups[0].target}"
)
# ===========================================================================
# 7.4 Unit test: get_trade_setups sorting is unchanged (R:R desc, composite desc)
# ===========================================================================
@pytest.mark.asyncio
async def test_get_trade_setups_sorting_rr_desc_composite_desc(db_session: AsyncSession):
"""**Validates: Requirements 3.5**
get_trade_setups must return results sorted by rr_ratio descending,
with composite_score descending as the secondary sort key.
"""
now = datetime.now(timezone.utc)
# Create tickers for each setup
ticker_a = Ticker(symbol="SORTA")
ticker_b = Ticker(symbol="SORTB")
ticker_c = Ticker(symbol="SORTC")
ticker_d = Ticker(symbol="SORTD")
db_session.add_all([ticker_a, ticker_b, ticker_c, ticker_d])
await db_session.flush()
# Create setups with different rr_ratio and composite_score values
# Expected order: D (rr=5.0), C (rr=3.0, comp=80), B (rr=3.0, comp=50), A (rr=1.5)
setup_a = TradeSetup(
ticker_id=ticker_a.id, direction="long",
entry_price=100.0, stop_loss=97.0, target=104.5,
rr_ratio=1.5, composite_score=90.0, detected_at=now,
)
setup_b = TradeSetup(
ticker_id=ticker_b.id, direction="long",
entry_price=100.0, stop_loss=97.0, target=109.0,
rr_ratio=3.0, composite_score=50.0, detected_at=now,
)
setup_c = TradeSetup(
ticker_id=ticker_c.id, direction="short",
entry_price=100.0, stop_loss=103.0, target=91.0,
rr_ratio=3.0, composite_score=80.0, detected_at=now,
)
setup_d = TradeSetup(
ticker_id=ticker_d.id, direction="long",
entry_price=100.0, stop_loss=97.0, target=115.0,
rr_ratio=5.0, composite_score=30.0, detected_at=now,
)
db_session.add_all([setup_a, setup_b, setup_c, setup_d])
await db_session.flush()
results = await get_trade_setups(db_session)
assert len(results) == 4, f"Expected 4 setups, got {len(results)}"
# Verify ordering: rr_ratio desc, then composite_score desc
rr_values = [r["rr_ratio"] for r in results]
assert rr_values == [5.0, 3.0, 3.0, 1.5], (
f"Expected rr_ratio order [5.0, 3.0, 3.0, 1.5], got {rr_values}"
)
# For the two setups with rr_ratio=3.0, composite_score should be desc
tied_composites = [r["composite_score"] for r in results if r["rr_ratio"] == 3.0]
assert tied_composites == [80.0, 50.0], (
f"Expected composite_score order [80.0, 50.0] for tied R:R, "
f"got {tied_composites}"
)
# Verify symbols match expected order
symbols = [r["symbol"] for r in results]
assert symbols == ["SORTD", "SORTC", "SORTB", "SORTA"], (
f"Expected symbol order ['SORTD', 'SORTC', 'SORTB', 'SORTA'], "
f"got {symbols}"
)

View File

@@ -0,0 +1,159 @@
"""Unit tests for _compute_quality_score in rr_scanner_service."""
import pytest
from app.services.rr_scanner_service import _compute_quality_score
# ---------------------------------------------------------------------------
# 4.1 — Known inputs with hand-computed expected outputs
# ---------------------------------------------------------------------------
class TestKnownInputs:
def test_typical_candidate(self):
# rr=5, strength=80, distance=3, entry_price=100
# norm_rr = 5/10 = 0.5
# norm_strength = 80/100 = 0.8
# norm_proximity = 1 - 3/100 = 0.97
# score = 0.35*0.5 + 0.35*0.8 + 0.30*0.97 = 0.175 + 0.28 + 0.291 = 0.746
result = _compute_quality_score(rr=5.0, strength=80, distance=3.0, entry_price=100.0)
assert result == pytest.approx(0.746, abs=1e-9)
def test_weak_distant_candidate(self):
# rr=8, strength=10, distance=50, entry_price=100
# norm_rr = 8/10 = 0.8
# norm_strength = 10/100 = 0.1
# norm_proximity = 1 - 50/100 = 0.5
# score = 0.35*0.8 + 0.35*0.1 + 0.30*0.5 = 0.28 + 0.035 + 0.15 = 0.465
result = _compute_quality_score(rr=8.0, strength=10, distance=50.0, entry_price=100.0)
assert result == pytest.approx(0.465, abs=1e-9)
def test_strong_near_candidate(self):
# rr=2, strength=95, distance=5, entry_price=200
# norm_rr = 2/10 = 0.2
# norm_strength = 95/100 = 0.95
# norm_proximity = 1 - 5/200 = 0.975
# score = 0.35*0.2 + 0.35*0.95 + 0.30*0.975 = 0.07 + 0.3325 + 0.2925 = 0.695
result = _compute_quality_score(rr=2.0, strength=95, distance=5.0, entry_price=200.0)
assert result == pytest.approx(0.695, abs=1e-9)
def test_custom_weights(self):
# rr=4, strength=50, distance=10, entry_price=100
# norm_rr = 4/10 = 0.4, norm_strength = 0.5, norm_proximity = 1 - 10/100 = 0.9
# With w_rr=0.5, w_strength=0.3, w_proximity=0.2:
# score = 0.5*0.4 + 0.3*0.5 + 0.2*0.9 = 0.2 + 0.15 + 0.18 = 0.53
result = _compute_quality_score(
rr=4.0, strength=50, distance=10.0, entry_price=100.0,
w_rr=0.5, w_strength=0.3, w_proximity=0.2,
)
assert result == pytest.approx(0.53, abs=1e-9)
def test_custom_rr_cap(self):
# rr=3, strength=60, distance=8, entry_price=100, rr_cap=5
# norm_rr = 3/5 = 0.6
# norm_strength = 60/100 = 0.6
# norm_proximity = 1 - 8/100 = 0.92
# score = 0.35*0.6 + 0.35*0.6 + 0.30*0.92 = 0.21 + 0.21 + 0.276 = 0.696
result = _compute_quality_score(
rr=3.0, strength=60, distance=8.0, entry_price=100.0, rr_cap=5.0,
)
assert result == pytest.approx(0.696, abs=1e-9)
# ---------------------------------------------------------------------------
# 4.2 — Edge cases
# ---------------------------------------------------------------------------
class TestEdgeCases:
def test_strength_zero(self):
# strength=0 → norm_strength=0
# rr=5, distance=10, entry_price=100
# norm_rr=0.5, norm_strength=0.0, norm_proximity=0.9
# score = 0.35*0.5 + 0.35*0.0 + 0.30*0.9 = 0.175 + 0.0 + 0.27 = 0.445
result = _compute_quality_score(rr=5.0, strength=0, distance=10.0, entry_price=100.0)
assert result == pytest.approx(0.445, abs=1e-9)
def test_strength_100(self):
# strength=100 → norm_strength=1.0
# rr=5, distance=10, entry_price=100
# norm_rr=0.5, norm_strength=1.0, norm_proximity=0.9
# score = 0.35*0.5 + 0.35*1.0 + 0.30*0.9 = 0.175 + 0.35 + 0.27 = 0.795
result = _compute_quality_score(rr=5.0, strength=100, distance=10.0, entry_price=100.0)
assert result == pytest.approx(0.795, abs=1e-9)
def test_distance_zero(self):
# distance=0 → norm_proximity = 1 - 0/100 = 1.0
# rr=5, strength=50, entry_price=100
# norm_rr=0.5, norm_strength=0.5, norm_proximity=1.0
# score = 0.35*0.5 + 0.35*0.5 + 0.30*1.0 = 0.175 + 0.175 + 0.3 = 0.65
result = _compute_quality_score(rr=5.0, strength=50, distance=0.0, entry_price=100.0)
assert result == pytest.approx(0.65, abs=1e-9)
def test_rr_at_cap(self):
# rr=10 (== rr_cap) → norm_rr = min(10/10, 1.0) = 1.0
# strength=50, distance=10, entry_price=100
# norm_strength=0.5, norm_proximity=0.9
# score = 0.35*1.0 + 0.35*0.5 + 0.30*0.9 = 0.35 + 0.175 + 0.27 = 0.795
result = _compute_quality_score(rr=10.0, strength=50, distance=10.0, entry_price=100.0)
assert result == pytest.approx(0.795, abs=1e-9)
def test_rr_above_cap(self):
# rr=15 (> rr_cap=10) → norm_rr = min(15/10, 1.0) = 1.0 (capped)
# Same result as rr_at_cap since norm_rr is capped at 1.0
result = _compute_quality_score(rr=15.0, strength=50, distance=10.0, entry_price=100.0)
assert result == pytest.approx(0.795, abs=1e-9)
def test_rr_above_cap_equals_rr_at_cap(self):
# Explicitly verify capping: rr=15 and rr=10 produce the same score
at_cap = _compute_quality_score(rr=10.0, strength=50, distance=10.0, entry_price=100.0)
above_cap = _compute_quality_score(rr=15.0, strength=50, distance=10.0, entry_price=100.0)
assert at_cap == pytest.approx(above_cap, abs=1e-9)
# ---------------------------------------------------------------------------
# 4.3 — Normalized components stay in 01 range
# ---------------------------------------------------------------------------
class TestNormalizedComponentsRange:
"""Verify that each normalized component stays within [0, 1]."""
@pytest.mark.parametrize("rr, rr_cap", [
(0.0, 10.0),
(5.0, 10.0),
(10.0, 10.0),
(15.0, 10.0),
(100.0, 10.0),
(3.0, 5.0),
(7.0, 5.0),
])
def test_norm_rr_in_range(self, rr, rr_cap):
norm_rr = min(rr / rr_cap, 1.0)
assert 0.0 <= norm_rr <= 1.0
@pytest.mark.parametrize("strength", [0, 1, 50, 99, 100])
def test_norm_strength_in_range(self, strength):
norm_strength = strength / 100.0
assert 0.0 <= norm_strength <= 1.0
@pytest.mark.parametrize("distance, entry_price", [
(0.0, 100.0),
(1.0, 100.0),
(50.0, 100.0),
(100.0, 100.0),
(200.0, 100.0), # distance > entry_price → clamped
(0.5, 50.0),
])
def test_norm_proximity_in_range(self, distance, entry_price):
norm_proximity = 1.0 - min(distance / entry_price, 1.0)
assert 0.0 <= norm_proximity <= 1.0
@pytest.mark.parametrize("rr, strength, distance, entry_price", [
(0.0, 0, 0.0, 100.0), # all minimums
(10.0, 100, 0.0, 100.0), # all maximums
(15.0, 100, 200.0, 100.0), # rr above cap, distance > entry
(5.0, 50, 50.0, 100.0), # mid-range values
(1.0, 1, 99.0, 100.0), # near-minimum non-zero
])
def test_total_score_in_zero_one(self, rr, strength, distance, entry_price):
score = _compute_quality_score(rr=rr, strength=strength, distance=distance, entry_price=entry_price)
assert 0.0 <= score <= 1.0

View File

@@ -0,0 +1,205 @@
"""Unit tests for get_score composite breakdown and dimension breakdown wiring."""
from __future__ import annotations
from datetime import date
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.database import Base
from app.models.ticker import Ticker
from app.services.scoring_service import get_score, _DIMENSION_COMPUTERS
TEST_DATABASE_URL = "sqlite+aiosqlite://"
@pytest.fixture
async def fresh_db():
"""Provide a non-transactional session so get_score can commit."""
engine = create_async_engine(TEST_DATABASE_URL, echo=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
async with session_factory() as session:
yield session
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await engine.dispose()
def _make_ohlcv_records(n: int, base_close: float = 100.0) -> list:
"""Create n mock OHLCV records with realistic price data."""
records = []
for i in range(n):
price = base_close + (i * 0.5)
records.append(
SimpleNamespace(
date=date(2024, 1, 1),
open=price - 0.5,
high=price + 1.0,
low=price - 1.0,
close=price,
volume=1000000,
)
)
return records
def _mock_none_computer():
"""Return an AsyncMock that returns (None, None) — simulates missing dimension data."""
return AsyncMock(return_value=(None, None))
def _mock_score_computer(score: float, breakdown: dict | None = None):
"""Return an AsyncMock that returns a fixed (score, breakdown) tuple."""
bd = breakdown or {
"sub_scores": [{"name": "mock", "score": score, "weight": 1.0, "raw_value": score, "description": "mock"}],
"formula": "mock formula",
"unavailable": [],
}
return AsyncMock(return_value=(score, bd))
async def _seed_ticker(session: AsyncSession, symbol: str = "AAPL") -> Ticker:
"""Insert a ticker row and return it."""
ticker = Ticker(symbol=symbol)
session.add(ticker)
await session.commit()
return ticker
@pytest.mark.asyncio
async def test_get_score_returns_composite_breakdown(fresh_db):
"""get_score should include a composite_breakdown dict with weights and re-normalization info."""
await _seed_ticker(fresh_db, "AAPL")
original = dict(_DIMENSION_COMPUTERS)
try:
_DIMENSION_COMPUTERS["technical"] = _mock_score_computer(70.0)
_DIMENSION_COMPUTERS["momentum"] = _mock_score_computer(60.0)
_DIMENSION_COMPUTERS["sentiment"] = _mock_none_computer()
_DIMENSION_COMPUTERS["fundamental"] = _mock_none_computer()
_DIMENSION_COMPUTERS["sr_quality"] = _mock_none_computer()
result = await get_score(fresh_db, "AAPL")
finally:
_DIMENSION_COMPUTERS.update(original)
assert "composite_breakdown" in result
cb = result["composite_breakdown"]
assert cb is not None
assert "weights" in cb
assert "available_dimensions" in cb
assert "missing_dimensions" in cb
assert "renormalized_weights" in cb
assert "formula" in cb
@pytest.mark.asyncio
async def test_get_score_composite_breakdown_has_correct_available_missing(fresh_db):
"""Composite breakdown should correctly list available and missing dimensions."""
await _seed_ticker(fresh_db, "AAPL")
original = dict(_DIMENSION_COMPUTERS)
try:
_DIMENSION_COMPUTERS["technical"] = _mock_score_computer(70.0)
_DIMENSION_COMPUTERS["momentum"] = _mock_score_computer(60.0)
_DIMENSION_COMPUTERS["sentiment"] = _mock_none_computer()
_DIMENSION_COMPUTERS["fundamental"] = _mock_none_computer()
_DIMENSION_COMPUTERS["sr_quality"] = _mock_none_computer()
result = await get_score(fresh_db, "AAPL")
finally:
_DIMENSION_COMPUTERS.update(original)
cb = result["composite_breakdown"]
assert "technical" in cb["available_dimensions"]
assert "momentum" in cb["available_dimensions"]
assert "sentiment" in cb["missing_dimensions"]
assert "fundamental" in cb["missing_dimensions"]
assert "sr_quality" in cb["missing_dimensions"]
@pytest.mark.asyncio
async def test_get_score_renormalized_weights_sum_to_one(fresh_db):
"""Re-normalized weights should sum to 1.0 when at least one dimension is available."""
await _seed_ticker(fresh_db, "AAPL")
original = dict(_DIMENSION_COMPUTERS)
try:
_DIMENSION_COMPUTERS["technical"] = _mock_score_computer(70.0)
_DIMENSION_COMPUTERS["momentum"] = _mock_score_computer(60.0)
_DIMENSION_COMPUTERS["sentiment"] = _mock_none_computer()
_DIMENSION_COMPUTERS["fundamental"] = _mock_none_computer()
_DIMENSION_COMPUTERS["sr_quality"] = _mock_none_computer()
result = await get_score(fresh_db, "AAPL")
finally:
_DIMENSION_COMPUTERS.update(original)
cb = result["composite_breakdown"]
assert cb["renormalized_weights"]
total = sum(cb["renormalized_weights"].values())
assert abs(total - 1.0) < 1e-9
@pytest.mark.asyncio
async def test_get_score_dimensions_include_breakdowns(fresh_db):
"""Each available dimension entry should include a breakdown dict."""
await _seed_ticker(fresh_db, "AAPL")
tech_breakdown = {
"sub_scores": [
{"name": "ADX", "score": 72.0, "weight": 0.4, "raw_value": 72.0, "description": "ADX value"},
{"name": "EMA", "score": 65.0, "weight": 0.3, "raw_value": 1.5, "description": "EMA diff"},
{"name": "RSI", "score": 62.0, "weight": 0.3, "raw_value": 62.0, "description": "RSI value"},
],
"formula": "Weighted average: 0.4*ADX + 0.3*EMA + 0.3*RSI",
"unavailable": [],
}
original = dict(_DIMENSION_COMPUTERS)
try:
_DIMENSION_COMPUTERS["technical"] = _mock_score_computer(68.2, tech_breakdown)
_DIMENSION_COMPUTERS["momentum"] = _mock_score_computer(55.0)
_DIMENSION_COMPUTERS["sentiment"] = _mock_none_computer()
_DIMENSION_COMPUTERS["fundamental"] = _mock_none_computer()
_DIMENSION_COMPUTERS["sr_quality"] = _mock_none_computer()
result = await get_score(fresh_db, "AAPL")
finally:
_DIMENSION_COMPUTERS.update(original)
tech_dim = next((d for d in result["dimensions"] if d["dimension"] == "technical"), None)
assert tech_dim is not None
assert "breakdown" in tech_dim
assert tech_dim["breakdown"] is not None
assert len(tech_dim["breakdown"]["sub_scores"]) == 3
names = [s["name"] for s in tech_dim["breakdown"]["sub_scores"]]
assert "ADX" in names
assert "EMA" in names
assert "RSI" in names
@pytest.mark.asyncio
async def test_get_score_all_dimensions_missing(fresh_db):
"""When all dimensions return None, composite_breakdown should list all as missing."""
await _seed_ticker(fresh_db, "AAPL")
original = dict(_DIMENSION_COMPUTERS)
try:
for dim in _DIMENSION_COMPUTERS:
_DIMENSION_COMPUTERS[dim] = _mock_none_computer()
result = await get_score(fresh_db, "AAPL")
finally:
_DIMENSION_COMPUTERS.update(original)
cb = result["composite_breakdown"]
assert cb["available_dimensions"] == []
assert len(cb["missing_dimensions"]) == 5
assert cb["renormalized_weights"] == {}
assert result["composite_score"] is None

View File

@@ -0,0 +1,150 @@
"""Unit tests for _compute_momentum_score breakdown refactor."""
from __future__ import annotations
from datetime import date
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from app.services.scoring_service import _compute_momentum_score
def _make_ohlcv_records(n: int, base_close: float = 100.0) -> list:
"""Create n mock OHLCV records with incrementing close prices."""
records = []
for i in range(n):
price = base_close + (i * 0.5)
records.append(
SimpleNamespace(
date=date(2024, 1, 1),
open=price - 0.5,
high=price + 1.0,
low=price - 1.0,
close=price,
volume=1000000,
)
)
return records
@pytest.mark.asyncio
async def test_returns_none_none_when_no_records(db_session):
"""When no OHLCV data exists, returns (None, None)."""
with patch(
"app.services.price_service.query_ohlcv",
new_callable=AsyncMock,
return_value=[],
):
score, breakdown = await _compute_momentum_score(db_session, "AAPL")
assert score is None
assert breakdown is None
@pytest.mark.asyncio
async def test_returns_none_none_when_fewer_than_6_records(db_session):
"""With fewer than 6 records, returns (None, None)."""
records = _make_ohlcv_records(5)
with patch(
"app.services.price_service.query_ohlcv",
new_callable=AsyncMock,
return_value=records,
):
score, breakdown = await _compute_momentum_score(db_session, "AAPL")
assert score is None
assert breakdown is None
@pytest.mark.asyncio
async def test_returns_only_5day_roc_when_fewer_than_21_records(db_session):
"""With 6-20 records, returns only 5-day ROC sub-score; 20-day ROC is unavailable."""
records = _make_ohlcv_records(10)
with patch(
"app.services.price_service.query_ohlcv",
new_callable=AsyncMock,
return_value=records,
):
score, breakdown = await _compute_momentum_score(db_session, "AAPL")
assert score is not None
assert 0.0 <= score <= 100.0
assert breakdown is not None
names = [s["name"] for s in breakdown["sub_scores"]]
assert "5-day ROC" in names
assert "20-day ROC" not in names
# 20-day ROC should be unavailable
unavail_names = [u["name"] for u in breakdown["unavailable"]]
assert "20-day ROC" in unavail_names
@pytest.mark.asyncio
async def test_returns_both_sub_scores_with_enough_data(db_session):
"""With 21+ records, returns both 5-day and 20-day ROC sub-scores."""
records = _make_ohlcv_records(30)
with patch(
"app.services.price_service.query_ohlcv",
new_callable=AsyncMock,
return_value=records,
):
score, breakdown = await _compute_momentum_score(db_session, "AAPL")
assert score is not None
assert 0.0 <= score <= 100.0
assert breakdown is not None
names = [s["name"] for s in breakdown["sub_scores"]]
assert "5-day ROC" in names
assert "20-day ROC" in names
# Verify weights
weight_map = {s["name"]: s["weight"] for s in breakdown["sub_scores"]}
assert weight_map["5-day ROC"] == 0.5
assert weight_map["20-day ROC"] == 0.5
# Verify raw_value is present and numeric
for sub in breakdown["sub_scores"]:
assert sub["raw_value"] is not None
assert isinstance(sub["raw_value"], (int, float))
assert sub["description"]
assert breakdown["unavailable"] == []
@pytest.mark.asyncio
async def test_formula_string_present(db_session):
"""Breakdown always includes the formula description."""
records = _make_ohlcv_records(30)
with patch(
"app.services.price_service.query_ohlcv",
new_callable=AsyncMock,
return_value=records,
):
_, breakdown = await _compute_momentum_score(db_session, "AAPL")
assert "formula" in breakdown
assert "ROC_5" in breakdown["formula"]
assert "ROC_20" in breakdown["formula"]
@pytest.mark.asyncio
async def test_raw_values_are_roc_percentages(db_session):
"""Raw values should be ROC percentages matching the actual price change."""
records = _make_ohlcv_records(30, base_close=100.0)
with patch(
"app.services.price_service.query_ohlcv",
new_callable=AsyncMock,
return_value=records,
):
_, breakdown = await _compute_momentum_score(db_session, "AAPL")
closes = [100.0 + i * 0.5 for i in range(30)]
latest = closes[-1]
expected_roc_5 = (latest - closes[-6]) / closes[-6] * 100.0
expected_roc_20 = (latest - closes[-21]) / closes[-21] * 100.0
roc_map = {s["name"]: s["raw_value"] for s in breakdown["sub_scores"]}
assert abs(roc_map["5-day ROC"] - round(expected_roc_5, 4)) < 1e-6
assert abs(roc_map["20-day ROC"] - round(expected_roc_20, 4)) < 1e-6

View File

@@ -0,0 +1,138 @@
"""Unit tests for _compute_sentiment_score breakdown refactor."""
from __future__ import annotations
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from app.services.scoring_service import _compute_sentiment_score
def _make_sentiment_records(n: int, classification: str = "bullish", confidence: int = 80) -> list:
"""Create n mock sentiment records with recent timestamps."""
now = datetime.now(timezone.utc)
records = []
for i in range(n):
records.append(
SimpleNamespace(
classification=classification,
confidence=confidence,
source="test",
timestamp=now,
reasoning="",
citations_json="[]",
)
)
return records
@pytest.mark.asyncio
async def test_returns_none_with_breakdown_when_no_records(db_session):
"""When no sentiment records exist, returns (None, breakdown) with unavailable entry."""
with patch(
"app.services.sentiment_service.get_sentiment_scores",
new_callable=AsyncMock,
return_value=[],
):
score, breakdown = await _compute_sentiment_score(db_session, "AAPL")
assert score is None
assert breakdown is not None
assert breakdown["sub_scores"] == []
assert len(breakdown["unavailable"]) == 1
assert breakdown["unavailable"][0]["name"] == "sentiment_records"
assert "formula" in breakdown
@pytest.mark.asyncio
async def test_returns_none_none_when_get_scores_raises(db_session):
"""When get_sentiment_scores raises, returns (None, None)."""
with patch(
"app.services.sentiment_service.get_sentiment_scores",
new_callable=AsyncMock,
side_effect=Exception("DB error"),
):
score, breakdown = await _compute_sentiment_score(db_session, "AAPL")
assert score is None
assert breakdown is None
@pytest.mark.asyncio
async def test_returns_breakdown_with_sub_scores(db_session):
"""With sentiment records, returns score and breakdown with expected sub-scores."""
records = _make_sentiment_records(3)
with patch(
"app.services.sentiment_service.get_sentiment_scores",
new_callable=AsyncMock,
return_value=records,
), patch(
"app.services.sentiment_service.compute_sentiment_dimension_score",
new_callable=AsyncMock,
return_value=75.0,
):
score, breakdown = await _compute_sentiment_score(db_session, "AAPL")
assert score == 75.0
assert breakdown is not None
assert "sub_scores" in breakdown
assert "formula" in breakdown
assert "unavailable" in breakdown
names = [s["name"] for s in breakdown["sub_scores"]]
assert "record_count" in names
assert "decay_rate" in names
assert "lookback_window" in names
# Verify raw values
raw_map = {s["name"]: s["raw_value"] for s in breakdown["sub_scores"]}
assert raw_map["record_count"] == 3
assert raw_map["decay_rate"] == 0.1
assert raw_map["lookback_window"] == 24
assert breakdown["unavailable"] == []
@pytest.mark.asyncio
async def test_formula_contains_decay_info(db_session):
"""Breakdown formula describes the time-decay weighted average."""
records = _make_sentiment_records(2)
with patch(
"app.services.sentiment_service.get_sentiment_scores",
new_callable=AsyncMock,
return_value=records,
), patch(
"app.services.sentiment_service.compute_sentiment_dimension_score",
new_callable=AsyncMock,
return_value=60.0,
):
_, breakdown = await _compute_sentiment_score(db_session, "AAPL")
assert "Time-decay" in breakdown["formula"]
assert "decay_rate" in breakdown["formula"]
assert "24" in breakdown["formula"]
@pytest.mark.asyncio
async def test_sub_scores_have_descriptions(db_session):
"""Each sub-score has a non-empty description."""
records = _make_sentiment_records(1)
with patch(
"app.services.sentiment_service.get_sentiment_scores",
new_callable=AsyncMock,
return_value=records,
), patch(
"app.services.sentiment_service.compute_sentiment_dimension_score",
new_callable=AsyncMock,
return_value=50.0,
):
_, breakdown = await _compute_sentiment_score(db_session, "AAPL")
for sub in breakdown["sub_scores"]:
assert sub["description"], f"Sub-score {sub['name']} missing description"

View File

@@ -0,0 +1,151 @@
"""Unit tests for _compute_technical_score breakdown refactor."""
from __future__ import annotations
from datetime import date
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from app.services.scoring_service import _compute_technical_score
def _make_ohlcv_records(n: int, base_close: float = 100.0) -> list:
"""Create n mock OHLCV records with realistic price data."""
records = []
for i in range(n):
price = base_close + (i * 0.5)
records.append(
SimpleNamespace(
date=date(2024, 1, 1),
open=price - 0.5,
high=price + 1.0,
low=price - 1.0,
close=price,
volume=1000000,
)
)
return records
@pytest.mark.asyncio
async def test_returns_none_tuple_when_no_records(db_session):
"""When no OHLCV data exists, returns (None, None)."""
with patch(
"app.services.price_service.query_ohlcv",
new_callable=AsyncMock,
return_value=[],
):
score, breakdown = await _compute_technical_score(db_session, "AAPL")
assert score is None
assert breakdown is None
@pytest.mark.asyncio
async def test_returns_breakdown_with_all_sub_scores(db_session):
"""With enough data, returns score and breakdown with ADX, EMA, RSI sub-scores."""
records = _make_ohlcv_records(50)
with patch(
"app.services.price_service.query_ohlcv",
new_callable=AsyncMock,
return_value=records,
):
score, breakdown = await _compute_technical_score(db_session, "AAPL")
assert score is not None
assert 0.0 <= score <= 100.0
assert breakdown is not None
assert "sub_scores" in breakdown
assert "formula" in breakdown
assert "unavailable" in breakdown
names = [s["name"] for s in breakdown["sub_scores"]]
assert "ADX" in names
assert "EMA" in names
assert "RSI" in names
# Verify weights
weight_map = {s["name"]: s["weight"] for s in breakdown["sub_scores"]}
assert weight_map["ADX"] == 0.4
assert weight_map["EMA"] == 0.3
assert weight_map["RSI"] == 0.3
# Verify raw_value is present and numeric
for sub in breakdown["sub_scores"]:
assert sub["raw_value"] is not None
assert isinstance(sub["raw_value"], (int, float))
assert sub["description"]
assert breakdown["unavailable"] == []
@pytest.mark.asyncio
async def test_partial_sub_scores_with_insufficient_data(db_session):
"""With limited data (enough for EMA/RSI but not ADX), returns partial breakdown."""
# 22 bars: enough for EMA(20) and RSI(14) but not ADX (needs 28)
records = _make_ohlcv_records(22)
with patch(
"app.services.price_service.query_ohlcv",
new_callable=AsyncMock,
return_value=records,
):
score, breakdown = await _compute_technical_score(db_session, "AAPL")
assert score is not None
assert breakdown is not None
names = [s["name"] for s in breakdown["sub_scores"]]
assert "EMA" in names
assert "RSI" in names
assert "ADX" not in names
# ADX should be in unavailable
unavail_names = [u["name"] for u in breakdown["unavailable"]]
assert "ADX" in unavail_names
assert any(u["reason"] for u in breakdown["unavailable"])
@pytest.mark.asyncio
async def test_all_sub_scores_unavailable(db_session):
"""With very few bars (not enough for any indicator), returns None score with breakdown."""
# 5 bars: not enough for any indicator
records = _make_ohlcv_records(5)
with patch(
"app.services.price_service.query_ohlcv",
new_callable=AsyncMock,
return_value=records,
):
score, breakdown = await _compute_technical_score(db_session, "AAPL")
assert score is None
assert breakdown is not None
assert breakdown["sub_scores"] == []
assert len(breakdown["unavailable"]) == 3
unavail_names = [u["name"] for u in breakdown["unavailable"]]
assert "ADX" in unavail_names
assert "EMA" in unavail_names
assert "RSI" in unavail_names
@pytest.mark.asyncio
async def test_formula_string_present(db_session):
"""Breakdown always includes the formula description."""
records = _make_ohlcv_records(50)
with patch(
"app.services.price_service.query_ohlcv",
new_callable=AsyncMock,
return_value=records,
):
_, breakdown = await _compute_technical_score(db_session, "AAPL")
assert "formula" in breakdown
assert "ADX" in breakdown["formula"]
assert "EMA" in breakdown["formula"]
assert "RSI" in breakdown["formula"]

View File

@@ -0,0 +1,243 @@
"""Unit tests for the S/R levels router — zone integration."""
from datetime import datetime
from unittest.mock import AsyncMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.middleware import register_exception_handlers
from app.routers.sr_levels import router
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakeLevel:
"""Mimics an SRLevel ORM model."""
def __init__(self, id, price_level, type, strength, detection_method):
self.id = id
self.price_level = price_level
self.type = type
self.strength = strength
self.detection_method = detection_method
self.created_at = datetime(2024, 1, 1)
class _FakeOHLCV:
"""Mimics an OHLCVRecord with a close attribute."""
def __init__(self, close: float):
self.close = close
def _make_app() -> FastAPI:
app = FastAPI()
register_exception_handlers(app)
app.include_router(router, prefix="/api/v1")
# Override auth dependency to no-op
from app.dependencies import require_access, get_db
app.dependency_overrides[require_access] = lambda: None
app.dependency_overrides[get_db] = lambda: AsyncMock()
return app
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
SAMPLE_LEVELS = [
_FakeLevel(1, 95.0, "support", 60, "volume_profile"),
_FakeLevel(2, 96.0, "support", 40, "pivot_point"),
_FakeLevel(3, 110.0, "resistance", 80, "merged"),
]
SAMPLE_OHLCV = [_FakeOHLCV(100.0)]
class TestSRLevelsRouterZones:
"""Tests for max_zones parameter and zone inclusion in response."""
@patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
@patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
def test_default_max_zones_returns_zones(self, mock_get_sr, mock_ohlcv):
mock_get_sr.return_value = SAMPLE_LEVELS
mock_ohlcv.return_value = SAMPLE_OHLCV
app = _make_app()
client = TestClient(app)
resp = client.get("/api/v1/sr-levels/AAPL")
assert resp.status_code == 200
body = resp.json()
assert body["status"] == "success"
data = body["data"]
assert "zones" in data
assert isinstance(data["zones"], list)
# With default max_zones=6, we should get zones
assert len(data["zones"]) > 0
@patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
@patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
def test_max_zones_zero_returns_empty_zones(self, mock_get_sr, mock_ohlcv):
mock_get_sr.return_value = SAMPLE_LEVELS
mock_ohlcv.return_value = SAMPLE_OHLCV
app = _make_app()
client = TestClient(app)
resp = client.get("/api/v1/sr-levels/AAPL?max_zones=0")
assert resp.status_code == 200
data = resp.json()["data"]
assert data["zones"] == []
@patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
@patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
def test_max_zones_limits_zone_count(self, mock_get_sr, mock_ohlcv):
mock_get_sr.return_value = SAMPLE_LEVELS
mock_ohlcv.return_value = SAMPLE_OHLCV
app = _make_app()
client = TestClient(app)
resp = client.get("/api/v1/sr-levels/AAPL?max_zones=1")
assert resp.status_code == 200
data = resp.json()["data"]
assert len(data["zones"]) <= 1
@patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
@patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
def test_no_ohlcv_data_returns_empty_zones(self, mock_get_sr, mock_ohlcv):
mock_get_sr.return_value = SAMPLE_LEVELS
mock_ohlcv.return_value = [] # No OHLCV data
app = _make_app()
client = TestClient(app)
resp = client.get("/api/v1/sr-levels/AAPL")
assert resp.status_code == 200
data = resp.json()["data"]
assert data["zones"] == []
# Levels should still be present
assert len(data["levels"]) == 3
@patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
@patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
def test_no_levels_returns_empty_zones(self, mock_get_sr, mock_ohlcv):
mock_get_sr.return_value = []
mock_ohlcv.return_value = SAMPLE_OHLCV
app = _make_app()
client = TestClient(app)
resp = client.get("/api/v1/sr-levels/AAPL")
assert resp.status_code == 200
data = resp.json()["data"]
assert data["zones"] == []
assert data["levels"] == []
assert data["count"] == 0
@patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
@patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
def test_zone_fields_present(self, mock_get_sr, mock_ohlcv):
mock_get_sr.return_value = SAMPLE_LEVELS
mock_ohlcv.return_value = SAMPLE_OHLCV
app = _make_app()
client = TestClient(app)
resp = client.get("/api/v1/sr-levels/AAPL")
data = resp.json()["data"]
for zone in data["zones"]:
assert "low" in zone
assert "high" in zone
assert "midpoint" in zone
assert "strength" in zone
assert "type" in zone
assert "level_count" in zone
assert zone["type"] in ("support", "resistance")
class TestSRLevelsRouterVisibleLevels:
"""Tests for visible_levels filtering in the SR levels response."""
@patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
@patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
def test_visible_levels_present_in_response(self, mock_get_sr, mock_ohlcv):
"""visible_levels field is always present in the API response."""
mock_get_sr.return_value = SAMPLE_LEVELS
mock_ohlcv.return_value = SAMPLE_OHLCV
app = _make_app()
client = TestClient(app)
resp = client.get("/api/v1/sr-levels/AAPL")
assert resp.status_code == 200
data = resp.json()["data"]
assert "visible_levels" in data
assert isinstance(data["visible_levels"], list)
@patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
@patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
def test_visible_levels_within_zone_bounds(self, mock_get_sr, mock_ohlcv):
"""Every visible level has a price within at least one zone's [low, high] range."""
mock_get_sr.return_value = SAMPLE_LEVELS
mock_ohlcv.return_value = SAMPLE_OHLCV
app = _make_app()
client = TestClient(app)
resp = client.get("/api/v1/sr-levels/AAPL")
data = resp.json()["data"]
zones = data["zones"]
visible = data["visible_levels"]
# When zones exist, each visible level must fall within a zone
for lvl in visible:
price = lvl["price_level"]
assert any(
z["low"] <= price <= z["high"] for z in zones
), f"visible level price {price} not within any zone bounds"
# visible_levels must be a subset of levels (by id)
level_ids = {l["id"] for l in data["levels"]}
for lvl in visible:
assert lvl["id"] in level_ids
@patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
@patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
def test_visible_levels_empty_when_no_ohlcv(self, mock_get_sr, mock_ohlcv):
"""visible_levels is empty when no OHLCV data exists (zones are empty)."""
mock_get_sr.return_value = SAMPLE_LEVELS
mock_ohlcv.return_value = []
app = _make_app()
client = TestClient(app)
resp = client.get("/api/v1/sr-levels/AAPL")
data = resp.json()["data"]
assert data["zones"] == []
assert data["visible_levels"] == []
@patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
@patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
def test_visible_levels_empty_when_max_zones_zero(self, mock_get_sr, mock_ohlcv):
"""visible_levels is empty when max_zones=0 (zones are empty)."""
mock_get_sr.return_value = SAMPLE_LEVELS
mock_ohlcv.return_value = SAMPLE_OHLCV
app = _make_app()
client = TestClient(app)
resp = client.get("/api/v1/sr-levels/AAPL?max_zones=0")
data = resp.json()["data"]
assert data["zones"] == []
assert data["visible_levels"] == []