major update
This commit is contained in:
211
tests/unit/test_cluster_sr_zones.py
Normal file
211
tests/unit/test_cluster_sr_zones.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Unit tests for cluster_sr_zones() in app.services.sr_service."""
|
||||
|
||||
from app.services.sr_service import cluster_sr_zones
|
||||
|
||||
|
||||
def _level(price: float, strength: int = 50) -> dict:
|
||||
"""Helper to build a level dict."""
|
||||
return {"price_level": price, "strength": strength}
|
||||
|
||||
|
||||
class TestClusterSrZonesEmptyAndEdge:
    """Edge cases: empty input, max_zones boundaries."""

    def test_empty_levels_returns_empty(self):
        result = cluster_sr_zones([], current_price=100.0)
        assert result == []

    def test_max_zones_zero_returns_empty(self):
        result = cluster_sr_zones(
            [_level(100.0)], current_price=100.0, max_zones=0
        )
        assert result == []

    def test_max_zones_negative_returns_empty(self):
        result = cluster_sr_zones(
            [_level(100.0)], current_price=100.0, max_zones=-1
        )
        assert result == []

    def test_single_level(self):
        # A lone level becomes a degenerate zone: low == high == midpoint.
        zones = cluster_sr_zones([_level(95.0, 60)], current_price=100.0)
        assert len(zones) == 1
        zone = zones[0]
        assert zone["low"] == 95.0
        assert zone["high"] == 95.0
        assert zone["midpoint"] == 95.0
        assert zone["strength"] == 60
        assert zone["type"] == "support"
        assert zone["level_count"] == 1
class TestClusterSrZonesMerging:
    """Greedy merge behaviour."""

    def test_two_levels_within_tolerance_merge(self):
        # 100 and 101 sit 1% apart, inside the 2% tolerance → one zone
        zones = cluster_sr_zones(
            [_level(100.0, 30), _level(101.0, 40)],
            current_price=200.0,
            tolerance=0.02,
        )
        assert len(zones) == 1
        merged = zones[0]
        assert merged["low"] == 100.0
        assert merged["high"] == 101.0
        assert merged["midpoint"] == 100.5
        assert merged["strength"] == 70
        assert merged["level_count"] == 2

    def test_two_levels_outside_tolerance_stay_separate(self):
        # 100 and 110 sit 10% apart, outside the 2% tolerance → two zones
        zones = cluster_sr_zones(
            [_level(100.0, 30), _level(110.0, 40)],
            current_price=200.0,
            tolerance=0.02,
        )
        assert len(zones) == 2

    def test_all_same_price_merge_into_one(self):
        identical = [_level(50.0, s) for s in (20, 30, 10)]
        zones = cluster_sr_zones(identical, current_price=100.0)
        assert len(zones) == 1
        assert zones[0]["strength"] == 60
        assert zones[0]["level_count"] == 3

    def test_levels_at_tolerance_boundary(self):
        # Cluster seeded at 100 has midpoint 100; 2% of 100 = 2, so a level
        # at exactly 102 lands on the boundary and should still merge.
        zones = cluster_sr_zones(
            [_level(100.0, 25), _level(102.0, 25)],
            current_price=200.0,
            tolerance=0.02,
        )
        assert len(zones) == 1
class TestClusterSrZonesStrength:
    """Strength capping and computation."""

    def test_strength_capped_at_100(self):
        # 80 + 80 would be 160; the zone strength must clamp to 100.
        zones = cluster_sr_zones(
            [_level(100.0, 80), _level(100.5, 80)],
            current_price=200.0,
            tolerance=0.02,
        )
        assert len(zones) == 1
        assert zones[0]["strength"] == 100

    def test_strength_sum_when_under_cap(self):
        # 10 + 20 stays below the cap → plain sum.
        zones = cluster_sr_zones(
            [_level(100.0, 10), _level(100.5, 20)],
            current_price=200.0,
            tolerance=0.02,
        )
        assert zones[0]["strength"] == 30
class TestClusterSrZonesTypeTagging:
    """Support vs resistance tagging."""

    def test_support_when_midpoint_below_current(self):
        zones = cluster_sr_zones([_level(90.0, 50)], current_price=100.0)
        assert zones[0]["type"] == "support"

    def test_resistance_when_midpoint_above_current(self):
        zones = cluster_sr_zones([_level(110.0, 50)], current_price=100.0)
        assert zones[0]["type"] == "resistance"

    def test_resistance_when_midpoint_equals_current(self):
        # Ties go to resistance ("else resistance" per spec).
        zones = cluster_sr_zones([_level(100.0, 50)], current_price=100.0)
        assert zones[0]["type"] == "resistance"
class TestClusterSrZonesSorting:
    """Sorting by strength descending."""

    def test_sorted_by_strength_descending(self):
        spread_out = [_level(50.0, 20), _level(150.0, 80), _level(250.0, 50)]
        zones = cluster_sr_zones(spread_out, current_price=100.0, tolerance=0.001)
        observed = [zone["strength"] for zone in zones]
        assert observed == sorted(observed, reverse=True)
class TestClusterSrZonesMaxZones:
    """max_zones filtering."""

    def test_max_zones_limits_output(self):
        spread_out = [_level(50.0, 20), _level(150.0, 80), _level(250.0, 50)]
        zones = cluster_sr_zones(
            spread_out, current_price=100.0, tolerance=0.001, max_zones=2
        )
        assert len(zones) == 2
        # Balanced selection: 1 support (strength 20) + 1 resistance (strength 80)
        observed_types = {zone["type"] for zone in zones}
        assert "support" in observed_types
        assert "resistance" in observed_types
        assert zones[0]["strength"] == 80
        assert zones[1]["strength"] == 20

    def test_max_zones_none_returns_all(self):
        spread_out = [_level(50.0, 20), _level(150.0, 80), _level(250.0, 50)]
        zones = cluster_sr_zones(
            spread_out, current_price=100.0, tolerance=0.001, max_zones=None
        )
        assert len(zones) == 3

    def test_max_zones_larger_than_count_returns_all(self):
        zones = cluster_sr_zones(
            [_level(50.0, 20)], current_price=100.0, max_zones=10
        )
        assert len(zones) == 1
class TestClusterSrZonesBalancedSelection:
    """Balanced interleave selection behaviour (Requirements 1.1, 1.2, 1.3, 1.5, 1.6)."""

    def test_mixed_input_produces_balanced_output(self):
        """3 support + 3 resistance with max_zones=4 → 2 support + 2 resistance."""
        specs = [
            (80.0, 70),   # support
            (85.0, 50),   # support
            (90.0, 30),   # support
            (110.0, 60),  # resistance
            (115.0, 40),  # resistance
            (120.0, 20),  # resistance
        ]
        zones = cluster_sr_zones(
            [_level(p, s) for p, s in specs],
            current_price=100.0,
            tolerance=0.001,
            max_zones=4,
        )
        assert len(zones) == 4
        tallies = {"support": 0, "resistance": 0}
        for zone in zones:
            tallies[zone["type"]] += 1
        assert tallies["support"] == 2
        assert tallies["resistance"] == 2

    def test_all_support_fills_from_support_only(self):
        """When no resistance levels exist, all slots filled from support."""
        supports = [_level(80.0, 70), _level(85.0, 50), _level(90.0, 30)]
        zones = cluster_sr_zones(
            supports, current_price=200.0, tolerance=0.001, max_zones=2
        )
        assert len(zones) == 2
        assert all(zone["type"] == "support" for zone in zones)

    def test_all_resistance_fills_from_resistance_only(self):
        """When no support levels exist, all slots filled from resistance."""
        resistances = [_level(110.0, 60), _level(115.0, 40), _level(120.0, 20)]
        zones = cluster_sr_zones(
            resistances, current_price=50.0, tolerance=0.001, max_zones=2
        )
        assert len(zones) == 2
        assert all(zone["type"] == "resistance" for zone in zones)

    def test_single_zone_edge_case(self):
        """Only 1 level total → returns exactly 1 zone."""
        zones = cluster_sr_zones(
            [_level(95.0, 45)], current_price=100.0, max_zones=5
        )
        assert len(zones) == 1
        assert zones[0]["strength"] == 45

    def test_both_types_present_when_max_zones_gte_2(self):
        """When both types exist and max_zones >= 2, at least one of each is present."""
        lopsided = [
            _level(70.0, 90),   # support (strongest overall)
            _level(75.0, 80),   # support
            _level(80.0, 70),   # support
            _level(130.0, 10),  # resistance (weakest overall)
        ]
        zones = cluster_sr_zones(
            lopsided, current_price=100.0, tolerance=0.001, max_zones=2
        )
        observed_types = {zone["type"] for zone in zones}
        assert "support" in observed_types
        assert "resistance" in observed_types
156
tests/unit/test_fmp_provider.py
Normal file
156
tests/unit/test_fmp_provider.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Unit tests for FMPFundamentalProvider 402 reason recording."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.providers.fmp import FMPFundamentalProvider
|
||||
|
||||
|
||||
def _mock_response(status_code: int, json_data: object = None) -> httpx.Response:
    """Construct a canned httpx.Response bound to a dummy GET request."""
    payload = {} if json_data is None else json_data
    return httpx.Response(
        status_code=status_code,
        json=payload,
        request=httpx.Request("GET", "https://example.com"),
    )
@pytest.fixture
def provider() -> FMPFundamentalProvider:
    """Fresh FMP provider wired to a dummy API key."""
    return FMPFundamentalProvider(api_key="test-key")
class TestFetchJsonOptional402Tracking:
    """_fetch_json_optional returns (data, was_402) tuple."""

    @pytest.mark.asyncio
    async def test_returns_empty_dict_and_true_on_402(self, provider):
        client = AsyncMock()
        client.get.return_value = _mock_response(402)

        data, hit_paywall = await provider._fetch_json_optional(
            client, "ratios-ttm", {}, "AAPL"
        )

        assert data == {}
        assert hit_paywall is True

    @pytest.mark.asyncio
    async def test_returns_data_and_false_on_200(self, provider):
        client = AsyncMock()
        client.get.return_value = _mock_response(
            200, [{"priceToEarningsRatioTTM": 25.5}]
        )

        data, hit_paywall = await provider._fetch_json_optional(
            client, "ratios-ttm", {}, "AAPL"
        )

        assert data == {"priceToEarningsRatioTTM": 25.5}
        assert hit_paywall is False
class TestFetchFundamentals402Recording:
    """fetch_fundamentals records 402 endpoints in unavailable_fields.

    The mock-client plumbing (URL-routing side_effect + AsyncClient patch)
    was previously copy-pasted into every test; it now lives in _fetch().
    """

    @staticmethod
    async def _fetch(provider, *, profile, ratios, growth, earnings):
        """Patch the HTTP client to route each endpoint to a canned response,
        run fetch_fundamentals("AAPL"), and return its result.

        Args:
            provider: the FMPFundamentalProvider under test.
            profile/ratios/growth/earnings: httpx.Response objects returned
                for the matching endpoint URL substring.
        """
        async def mock_get(url, params=None):
            # Route by URL substring, mirroring the provider's endpoints.
            if "profile" in url:
                return profile
            if "ratios-ttm" in url:
                return ratios
            if "financial-growth" in url:
                return growth
            if "earnings" in url:
                return earnings
            return _mock_response(200, [{}])

        with patch("app.providers.fmp.httpx.AsyncClient") as MockClient:
            instance = AsyncMock()
            instance.get.side_effect = mock_get
            # The provider uses the client as an async context manager.
            instance.__aenter__ = AsyncMock(return_value=instance)
            instance.__aexit__ = AsyncMock(return_value=False)
            MockClient.return_value = instance

            return await provider.fetch_fundamentals("AAPL")

    @pytest.mark.asyncio
    async def test_all_402_records_all_fields(self, provider):
        """When all supplementary endpoints return 402, all three fields are recorded."""
        result = await self._fetch(
            provider,
            profile=_mock_response(200, [{"marketCap": 1_000_000}]),
            ratios=_mock_response(402),
            growth=_mock_response(402),
            earnings=_mock_response(402),
        )

        assert result.unavailable_fields == {
            "pe_ratio": "requires paid plan",
            "revenue_growth": "requires paid plan",
            "earnings_surprise": "requires paid plan",
        }

    @pytest.mark.asyncio
    async def test_mixed_200_402_records_only_402_fields(self, provider):
        """When only ratios-ttm returns 402, only pe_ratio is recorded."""
        result = await self._fetch(
            provider,
            profile=_mock_response(200, [{"marketCap": 2_000_000}]),
            ratios=_mock_response(402),
            growth=_mock_response(200, [{"revenueGrowth": 0.15}]),
            earnings=_mock_response(200, [{"epsActual": 3.0, "epsEstimated": 2.5}]),
        )

        assert result.unavailable_fields == {"pe_ratio": "requires paid plan"}
        assert result.revenue_growth == 0.15
        assert result.earnings_surprise is not None

    @pytest.mark.asyncio
    async def test_no_402_empty_unavailable_fields(self, provider):
        """When all endpoints succeed, unavailable_fields is empty."""
        result = await self._fetch(
            provider,
            profile=_mock_response(200, [{"marketCap": 3_000_000}]),
            ratios=_mock_response(200, [{"priceToEarningsRatioTTM": 20.0}]),
            growth=_mock_response(200, [{"revenueGrowth": 0.10}]),
            earnings=_mock_response(200, [{"epsActual": 2.0, "epsEstimated": 1.8}]),
        )

        assert result.unavailable_fields == {}
        assert result.pe_ratio == 20.0
99
tests/unit/test_fundamental_service.py
Normal file
99
tests/unit/test_fundamental_service.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Unit tests for fundamental_service — unavailable_fields persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.database import Base
|
||||
from app.models.ticker import Ticker
|
||||
from app.services import fundamental_service
|
||||
|
||||
# Use a dedicated in-memory SQLite engine so commit/refresh work without
# conflicting with the conftest transactional session.
_engine = create_async_engine("sqlite+aiosqlite://", echo=False)
_session_factory = async_sessionmaker(
    _engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
@pytest.fixture(autouse=True)
async def _setup_tables():
    """Create all tables before each test and drop them again afterwards."""
    async with _engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    async with _engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def session() -> AsyncSession:
    """Yield a plain session from the module-scoped factory."""
    async with _session_factory() as db:
        yield db
@pytest.fixture
async def ticker(session: AsyncSession) -> Ticker:
    """Persist and return an AAPL test ticker."""
    record = Ticker(symbol="AAPL")
    session.add(record)
    await session.commit()
    await session.refresh(record)
    return record
@pytest.mark.asyncio
async def test_store_fundamental_persists_unavailable_fields(
    session: AsyncSession, ticker: Ticker
):
    """unavailable_fields dict is serialized to JSON and stored."""
    expected = {
        "pe_ratio": "requires paid plan",
        "revenue_growth": "requires paid plan",
    }

    stored = await fundamental_service.store_fundamental(
        session,
        symbol="AAPL",
        pe_ratio=None,
        revenue_growth=None,
        market_cap=1_000_000.0,
        unavailable_fields=expected,
    )

    assert json.loads(stored.unavailable_fields_json) == expected
@pytest.mark.asyncio
async def test_store_fundamental_defaults_to_empty_dict(
    session: AsyncSession, ticker: Ticker
):
    """When unavailable_fields is not provided, column defaults to '{}'."""
    stored = await fundamental_service.store_fundamental(
        session,
        symbol="AAPL",
        pe_ratio=25.0,
    )

    assert json.loads(stored.unavailable_fields_json) == {}
@pytest.mark.asyncio
async def test_store_fundamental_updates_unavailable_fields(
    session: AsyncSession, ticker: Ticker
):
    """Updating an existing record also updates unavailable_fields_json."""
    # Initial store: pe_ratio is paywalled.
    await fundamental_service.store_fundamental(
        session,
        symbol="AAPL",
        pe_ratio=None,
        unavailable_fields={"pe_ratio": "requires paid plan"},
    )

    # Second store: the field has become available.
    updated = await fundamental_service.store_fundamental(
        session,
        symbol="AAPL",
        pe_ratio=25.0,
        unavailable_fields={},
    )

    assert json.loads(updated.unavailable_fields_json) == {}
162
tests/unit/test_openai_sentiment_provider.py
Normal file
162
tests/unit/test_openai_sentiment_provider.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Unit tests for OpenAISentimentProvider reasoning and citations extraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.providers.openai_sentiment import OpenAISentimentProvider
|
||||
|
||||
|
||||
def _make_annotation(ann_type: str, url: str = "", title: str = "") -> SimpleNamespace:
|
||||
return SimpleNamespace(type=ann_type, url=url, title=title)
|
||||
|
||||
|
||||
def _make_content_block(text: str, annotations: list | None = None) -> SimpleNamespace:
|
||||
block = SimpleNamespace(text=text, annotations=annotations or [])
|
||||
return block
|
||||
|
||||
|
||||
def _make_message_item(content_blocks: list) -> SimpleNamespace:
|
||||
return SimpleNamespace(type="message", content=content_blocks)
|
||||
|
||||
|
||||
def _make_web_search_item() -> SimpleNamespace:
|
||||
return SimpleNamespace(type="web_search_call")
|
||||
|
||||
|
||||
def _build_response(json_text: str, annotations: list | None = None) -> SimpleNamespace:
    """Build a mock OpenAI Responses API response: a single message item whose
    one content block carries json_text and the given annotations."""
    message = _make_message_item([_make_content_block(json_text, annotations or [])])
    return SimpleNamespace(output=[message])
def _build_response_with_search(
    json_text: str, annotations: list | None = None
) -> SimpleNamespace:
    """Build a response whose output is a web_search_call item followed by a
    message item carrying json_text and the given annotations."""
    return SimpleNamespace(
        output=[
            _make_web_search_item(),
            _make_message_item([_make_content_block(json_text, annotations or [])]),
        ]
    )
@pytest.fixture
def provider():
    """Create an OpenAISentimentProvider whose OpenAI client is mocked out."""
    with patch("app.providers.openai_sentiment.AsyncOpenAI"):
        return OpenAISentimentProvider(api_key="test-key")
class TestReasoningExtraction:
    """Tests for extracting reasoning from the parsed JSON response."""

    @pytest.mark.asyncio
    async def test_reasoning_extracted_from_json(self, provider):
        payload = '{"classification": "bullish", "confidence": 85, "reasoning": "Strong earnings report"}'
        provider._client.responses.create = AsyncMock(
            return_value=_build_response(payload)
        )

        result = await provider.fetch_sentiment("AAPL")

        assert result.reasoning == "Strong earnings report"

    @pytest.mark.asyncio
    async def test_empty_reasoning_when_field_missing(self, provider):
        # No "reasoning" key at all in the model's JSON.
        payload = '{"classification": "neutral", "confidence": 50}'
        provider._client.responses.create = AsyncMock(
            return_value=_build_response(payload)
        )

        result = await provider.fetch_sentiment("MSFT")

        assert result.reasoning == ""

    @pytest.mark.asyncio
    async def test_empty_reasoning_when_field_is_empty_string(self, provider):
        payload = '{"classification": "bearish", "confidence": 70, "reasoning": ""}'
        provider._client.responses.create = AsyncMock(
            return_value=_build_response(payload)
        )

        result = await provider.fetch_sentiment("TSLA")

        assert result.reasoning == ""
class TestCitationsExtraction:
    """Tests for extracting url_citation annotations from the response."""

    @pytest.mark.asyncio
    async def test_citations_extracted_from_annotations(self, provider):
        payload = '{"classification": "bullish", "confidence": 90, "reasoning": "Good news"}'
        notes = [
            _make_annotation("url_citation", url="https://example.com/1", title="Article 1"),
            _make_annotation("url_citation", url="https://example.com/2", title="Article 2"),
        ]
        provider._client.responses.create = AsyncMock(
            return_value=_build_response(payload, notes)
        )

        result = await provider.fetch_sentiment("AAPL")

        assert len(result.citations) == 2
        first, second = result.citations
        assert first == {"url": "https://example.com/1", "title": "Article 1"}
        assert second == {"url": "https://example.com/2", "title": "Article 2"}

    @pytest.mark.asyncio
    async def test_empty_citations_when_no_annotations(self, provider):
        payload = '{"classification": "neutral", "confidence": 50, "reasoning": "No news"}'
        provider._client.responses.create = AsyncMock(
            return_value=_build_response(payload)
        )

        result = await provider.fetch_sentiment("GOOG")

        assert result.citations == []

    @pytest.mark.asyncio
    async def test_non_url_citation_annotations_ignored(self, provider):
        payload = '{"classification": "bearish", "confidence": 60, "reasoning": "Mixed signals"}'
        # Only the url_citation entry should survive extraction.
        notes = [
            _make_annotation("file_citation", url="https://file.com", title="File"),
            _make_annotation("url_citation", url="https://real.com", title="Real"),
        ]
        provider._client.responses.create = AsyncMock(
            return_value=_build_response(payload, notes)
        )

        result = await provider.fetch_sentiment("META")

        assert len(result.citations) == 1
        assert result.citations[0] == {"url": "https://real.com", "title": "Real"}

    @pytest.mark.asyncio
    async def test_citations_from_response_with_web_search_call(self, provider):
        payload = '{"classification": "bullish", "confidence": 80, "reasoning": "Positive outlook"}'
        notes = [
            _make_annotation("url_citation", url="https://news.com/a", title="News A"),
        ]
        provider._client.responses.create = AsyncMock(
            return_value=_build_response_with_search(payload, notes)
        )

        result = await provider.fetch_sentiment("NVDA")

        assert len(result.citations) == 1
        assert result.citations[0] == {"url": "https://news.com/a", "title": "News A"}

    @pytest.mark.asyncio
    async def test_no_error_when_annotations_attr_missing(self, provider):
        """Content blocks without annotations attribute should not cause errors."""
        payload = '{"classification": "neutral", "confidence": 50, "reasoning": "Quiet day"}'
        # Hand-build a block that lacks the annotations attribute entirely.
        bare_block = SimpleNamespace(text=payload)
        response = SimpleNamespace(
            output=[SimpleNamespace(type="message", content=[bare_block])]
        )
        provider._client.responses.create = AsyncMock(return_value=response)

        result = await provider.fetch_sentiment("AMD")

        assert result.citations == []
        assert result.reasoning == "Quiet day"
273
tests/unit/test_rr_scanner_bug_exploration.py
Normal file
273
tests/unit/test_rr_scanner_bug_exploration.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""Bug-condition exploration tests for R:R scanner target quality.
|
||||
|
||||
These tests confirm the bug described in bugfix.md: the old code always selected
|
||||
the most distant S/R level (highest raw R:R) regardless of strength or proximity.
|
||||
The fix replaces max-R:R selection with quality-score selection.
|
||||
|
||||
Since the code is already fixed, these tests PASS on the current codebase.
|
||||
On the unfixed code they would FAIL, confirming the bug.
|
||||
|
||||
**Validates: Requirements 1.1, 1.3, 1.4, 2.1, 2.3, 2.4**
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, timedelta
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings, HealthCheck, strategies as st
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.ohlcv import OHLCVRecord
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.ticker import Ticker
|
||||
from app.services.rr_scanner_service import scan_ticker
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session fixture that allows scan_ticker to commit
|
||||
# ---------------------------------------------------------------------------
|
||||
# The default db_session fixture wraps in session.begin() which conflicts
|
||||
# with scan_ticker's internal commit(). We use a plain session instead.
|
||||
|
||||
@pytest.fixture
async def scan_session() -> AsyncSession:
    """Provide a DB session compatible with scan_ticker (which commits)."""
    # Imported lazily to avoid importing conftest machinery at module load.
    from tests.conftest import _test_session_factory

    async with _test_session_factory() as db:
        yield db
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_ohlcv_bars(
    ticker_id: int,
    num_bars: int = 20,
    base_close: float = 100.0,
) -> list[OHLCVRecord]:
    """Generate realistic OHLCV bars with small daily variation.

    Produces bars where close ≈ base_close, with enough range for ATR
    computation (needs >= 15 bars). The ATR will be roughly 2.0.
    """
    start = date(2024, 1, 1)

    def _bar(i: int) -> OHLCVRecord:
        # Close oscillates ±0.5 around base_close in a 3-day cycle.
        c = base_close + ((i % 3) - 1) * 0.5
        return OHLCVRecord(
            ticker_id=ticker_id,
            date=start + timedelta(days=i),
            open=c - 0.3,
            high=c + 1.0,
            low=c - 1.0,
            close=c,
            volume=100_000,
        )

    return [_bar(i) for i in range(num_bars)]
# ---------------------------------------------------------------------------
|
||||
# Deterministic test: strong-near vs weak-far (long setup)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
async def test_long_prefers_strong_near_over_weak_far(scan_session: AsyncSession):
    """With a strong nearby resistance and a weak distant resistance, the
    scanner should pick the strong nearby one — NOT the most distant.

    On unfixed code this would fail because max-R:R always picks the
    farthest level.
    """
    ticker = Ticker(symbol="EXPLR")
    scan_session.add(ticker)
    await scan_session.flush()

    # 20 bars closing around 100
    scan_session.add_all(_make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0))

    # With ATR=2.0 and multiplier=1.5, risk=3.0.
    # R:R threshold=1.5 → min reward=4.5 → min target=104.5
    scan_session.add_all([
        # Strong nearby resistance: price=105, strength=90 (R:R≈1.67, quality≈0.66)
        SRLevel(
            ticker_id=ticker.id,
            price_level=105.0,
            type="resistance",
            strength=90,
            detection_method="volume_profile",
        ),
        # Weak distant resistance: price=130, strength=5 (R:R=10, quality≈0.58)
        SRLevel(
            ticker_id=ticker.id,
            price_level=130.0,
            type="resistance",
            strength=5,
            detection_method="volume_profile",
        ),
    ])
    await scan_session.flush()

    setups = await scan_ticker(scan_session, "EXPLR", rr_threshold=1.5)

    long_setups = [s for s in setups if s.direction == "long"]
    assert len(long_setups) == 1, "Expected exactly one long setup"

    selected_target = long_setups[0].target
    # The scanner must NOT pick the most distant level (130)
    assert selected_target != pytest.approx(130.0, abs=0.01), (
        "Bug: scanner picked the weak distant level (130) instead of the "
        "strong nearby level (105)"
    )
    # It should pick the strong nearby level
    assert selected_target == pytest.approx(105.0, abs=0.01)
# ---------------------------------------------------------------------------
|
||||
# Deterministic test: strong-near vs weak-far (short setup)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
async def test_short_prefers_strong_near_over_weak_far(scan_session: AsyncSession):
    """Short-side mirror: strong nearby support should be preferred over
    weak distant support.
    """
    ticker = Ticker(symbol="EXPLS")
    scan_session.add(ticker)
    await scan_session.flush()

    scan_session.add_all(_make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0))

    # With ATR=2.0 and multiplier=1.5, risk=3.0.
    # R:R threshold=1.5 → min reward=4.5 → min target below 95.5
    scan_session.add_all([
        # Strong nearby support: price=95, strength=85 (R:R≈1.67, quality≈0.64)
        SRLevel(
            ticker_id=ticker.id,
            price_level=95.0,
            type="support",
            strength=85,
            detection_method="pivot_point",
        ),
        # Weak distant support: price=70, strength=5 (R:R=10, quality≈0.58)
        SRLevel(
            ticker_id=ticker.id,
            price_level=70.0,
            type="support",
            strength=5,
            detection_method="pivot_point",
        ),
    ])
    await scan_session.flush()

    setups = await scan_ticker(scan_session, "EXPLS", rr_threshold=1.5)

    short_setups = [s for s in setups if s.direction == "short"]
    assert len(short_setups) == 1, "Expected exactly one short setup"

    selected_target = short_setups[0].target
    assert selected_target != pytest.approx(70.0, abs=0.01), (
        "Bug: scanner picked the weak distant level (70) instead of the "
        "strong nearby level (95)"
    )
    assert selected_target == pytest.approx(95.0, abs=0.01)
# ---------------------------------------------------------------------------
|
||||
# Hypothesis property test: selection is NOT always the most distant level
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@st.composite
def strong_near_weak_far_pair(draw: st.DrawFn) -> dict:
    """Generate a (strong-near, weak-far) resistance pair above entry=100.

    Guarantees:
    - near_price < far_price (both above entry)
    - near_strength >> far_strength
    - Both meet the R:R threshold of 1.5 given typical ATR ≈ 2 → risk ≈ 3
    """
    pair: dict = {}
    # Near level: 5–15 above entry (R:R ≈ 1.7–5.0 with risk≈3)
    pair["near_price"] = 100.0 + draw(st.floats(min_value=5.0, max_value=15.0))
    pair["near_strength"] = draw(st.integers(min_value=70, max_value=100))
    # Far level: 25–60 above entry (R:R ≈ 8.3–20 with risk≈3)
    pair["far_price"] = 100.0 + draw(st.floats(min_value=25.0, max_value=60.0))
    pair["far_strength"] = draw(st.integers(min_value=1, max_value=15))
    return pair
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@given(pair=strong_near_weak_far_pair())
@settings(
    max_examples=15,
    deadline=None,
    # scan_session is function-scoped while @given reuses the test body per
    # example, so hypothesis' fixture health check must be suppressed.
    suppress_health_check=[HealthCheck.function_scoped_fixture],
)
async def test_property_scanner_does_not_always_pick_most_distant(
    pair: dict,
    scan_session: AsyncSession,
) -> None:
    """**Validates: Requirements 1.1, 1.3, 1.4, 2.1, 2.3, 2.4**

    Property: when a strong nearby resistance exists alongside a weak distant
    resistance, the scanner does NOT always select the most distant level.

    On unfixed code this would fail for every example because max-R:R always
    picks the farthest level.
    """
    # NOTE(review): the scan_session fixture parameter is not used in the body
    # (a fresh session is built below) — presumably requested only so the
    # suppressed function_scoped_fixture health check applies; confirm.
    from tests.conftest import _test_engine, _test_session_factory

    # Each hypothesis example needs a fresh DB state
    async with _test_engine.begin() as conn:
        from app.database import Base
        await conn.run_sync(Base.metadata.drop_all)
        await conn.run_sync(Base.metadata.create_all)

    async with _test_session_factory() as session:
        ticker = Ticker(symbol="PROP")
        session.add(ticker)
        # flush so ticker.id is populated before building dependent rows
        await session.flush()

        bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
        session.add_all(bars)

        near_level = SRLevel(
            ticker_id=ticker.id,
            price_level=pair["near_price"],
            type="resistance",
            strength=pair["near_strength"],
            detection_method="volume_profile",
        )
        far_level = SRLevel(
            ticker_id=ticker.id,
            price_level=pair["far_price"],
            type="resistance",
            strength=pair["far_strength"],
            detection_method="volume_profile",
        )
        session.add_all([near_level, far_level])
        # commit (not just flush) so scan_ticker's own commit doesn't clash
        await session.commit()

        setups = await scan_ticker(session, "PROP", rr_threshold=1.5)

        long_setups = [s for s in setups if s.direction == "long"]
        assert len(long_setups) == 1, "Expected exactly one long setup"

        selected_target = long_setups[0].target
        most_distant = round(pair["far_price"], 4)

        # The fixed scanner should prefer the strong nearby level, not the
        # most distant weak one.
        assert selected_target != pytest.approx(most_distant, abs=0.01), (
            f"Bug: scanner picked the most distant level ({most_distant}) "
            f"with strength={pair['far_strength']} over the nearby level "
            f"({round(pair['near_price'], 4)}) with strength={pair['near_strength']}"
        )
|
||||
383
tests/unit/test_rr_scanner_fix_check.py
Normal file
383
tests/unit/test_rr_scanner_fix_check.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""Fix-checking tests for R:R scanner quality-score selection.
|
||||
|
||||
Verify that the fixed scan_ticker selects the candidate with the highest
|
||||
quality score among all candidates meeting the R:R threshold, for both
|
||||
long and short setups.
|
||||
|
||||
**Validates: Requirements 2.1, 2.2, 2.3, 2.4**
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, timedelta
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings, HealthCheck, strategies as st
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.ohlcv import OHLCVRecord
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.ticker import Ticker
|
||||
from app.services.rr_scanner_service import scan_ticker, _compute_quality_score
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session fixture (plain session, not wrapped in begin())
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
async def scan_session() -> AsyncSession:
    """Yield a plain session for scan_ticker, which manages its own commits."""
    from tests.conftest import _test_session_factory

    async with _test_session_factory() as plain_session:
        yield plain_session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_ohlcv_bars(
    ticker_id: int,
    num_bars: int = 20,
    base_close: float = 100.0,
) -> list[OHLCVRecord]:
    """Generate OHLCV bars closing around base_close with ATR ≈ 2.0."""
    first_day = date(2024, 1, 1)
    # Closes repeat a -0.5 / 0.0 / +0.5 wobble around base_close.
    closes = [base_close + (day % 3 - 1) * 0.5 for day in range(num_bars)]
    return [
        OHLCVRecord(
            ticker_id=ticker_id,
            date=first_day + timedelta(days=day),
            open=px - 0.3,
            high=px + 1.0,
            low=px - 1.0,
            close=px,
            volume=100_000,
        )
        for day, px in enumerate(closes)
    ]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hypothesis strategy: multiple resistance levels above entry for longs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@st.composite
def long_candidate_levels(draw: st.DrawFn) -> list[dict]:
    """Generate 2-5 resistance levels above entry_price=100.

    All levels meet the R:R threshold of 1.5 given ATR≈2, risk≈3,
    so min reward=4.5, min target=104.5.
    """
    count = draw(st.integers(min_value=2, max_value=5))
    # Each level sits 5-50 above entry (always past the 4.5 reward floor).
    return [
        {
            "price": 100.0 + draw(st.floats(min_value=5.0, max_value=50.0)),
            "strength": draw(st.integers(min_value=0, max_value=100)),
        }
        for _ in range(count)
    ]
|
||||
|
||||
|
||||
@st.composite
def short_candidate_levels(draw: st.DrawFn) -> list[dict]:
    """Generate 2-5 support levels below entry_price=100.

    All levels meet the R:R threshold of 1.5 given ATR≈2, risk≈3,
    so min reward=4.5, max target=95.5.
    """
    count = draw(st.integers(min_value=2, max_value=5))
    # Each level sits 5-50 below entry (always past the 4.5 reward floor).
    return [
        {
            "price": 100.0 - draw(st.floats(min_value=5.0, max_value=50.0)),
            "strength": draw(st.integers(min_value=0, max_value=100)),
        }
        for _ in range(count)
    ]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property test: long setup selects highest quality score candidate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
@given(levels=long_candidate_levels())
@settings(
    max_examples=20,
    deadline=None,
    # scan_session is function-scoped while @given reuses the test body per
    # example, so hypothesis' fixture health check must be suppressed.
    suppress_health_check=[HealthCheck.function_scoped_fixture],
)
async def test_property_long_selects_highest_quality(
    levels: list[dict],
    scan_session: AsyncSession,
) -> None:
    """**Validates: Requirements 2.1, 2.3, 2.4**

    Property: when multiple resistance levels meet the R:R threshold,
    the fixed scan_ticker selects the one with the highest quality score.
    """
    from tests.conftest import _test_engine, _test_session_factory
    from app.database import Base

    # Fresh DB state per hypothesis example
    async with _test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
        await conn.run_sync(Base.metadata.create_all)

    async with _test_session_factory() as session:
        ticker = Ticker(symbol="FIXL")
        session.add(ticker)
        # flush so ticker.id is populated before building dependent rows
        await session.flush()

        bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
        session.add_all(bars)

        sr_levels = []
        for lv in levels:
            sr_levels.append(SRLevel(
                ticker_id=ticker.id,
                price_level=lv["price"],
                type="resistance",
                strength=lv["strength"],
                detection_method="volume_profile",
            ))
        session.add_all(sr_levels)
        await session.commit()

        setups = await scan_ticker(session, "FIXL", rr_threshold=1.5)

        long_setups = [s for s in setups if s.direction == "long"]
        assert len(long_setups) == 1, "Expected exactly one long setup"

        selected_target = long_setups[0].target

        # Compute entry_price and risk from the bars (same logic as scan_ticker)
        # entry_price = last close ≈ 100.0, ATR ≈ 2.0, risk = ATR * 1.5 = 3.0
        entry_price = bars[-1].close
        # Use approximate risk; the exact value comes from ATR computation
        # We reconstruct it from the setup's entry and stop
        risk = long_setups[0].entry_price - long_setups[0].stop_loss

        # Independently re-derive the best-quality candidate and compare:
        # compute quality scores for all candidates that meet threshold
        best_quality = -1.0
        best_target = None
        for lv in levels:
            distance = lv["price"] - entry_price
            if distance > 0:
                rr = distance / risk
                if rr >= 1.5:
                    quality = _compute_quality_score(rr, lv["strength"], distance, entry_price)
                    if quality > best_quality:
                        best_quality = quality
                        best_target = round(lv["price"], 4)

        assert best_target is not None, "At least one candidate should meet threshold"
        assert selected_target == pytest.approx(best_target, abs=0.01), (
            f"Selected target {selected_target} != expected best-quality target "
            f"{best_target} (quality={best_quality:.4f})"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property test: short setup selects highest quality score candidate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
@given(levels=short_candidate_levels())
@settings(
    max_examples=20,
    deadline=None,
    # scan_session is function-scoped while @given reuses the test body per
    # example, so hypothesis' fixture health check must be suppressed.
    suppress_health_check=[HealthCheck.function_scoped_fixture],
)
async def test_property_short_selects_highest_quality(
    levels: list[dict],
    scan_session: AsyncSession,
) -> None:
    """**Validates: Requirements 2.2, 2.3, 2.4**

    Property: when multiple support levels meet the R:R threshold,
    the fixed scan_ticker selects the one with the highest quality score.
    """
    from tests.conftest import _test_engine, _test_session_factory
    from app.database import Base

    # Fresh DB state per hypothesis example
    async with _test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
        await conn.run_sync(Base.metadata.create_all)

    async with _test_session_factory() as session:
        ticker = Ticker(symbol="FIXS")
        session.add(ticker)
        # flush so ticker.id is populated before building dependent rows
        await session.flush()

        bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
        session.add_all(bars)

        sr_levels = []
        for lv in levels:
            sr_levels.append(SRLevel(
                ticker_id=ticker.id,
                price_level=lv["price"],
                type="support",
                strength=lv["strength"],
                detection_method="pivot_point",
            ))
        session.add_all(sr_levels)
        await session.commit()

        setups = await scan_ticker(session, "FIXS", rr_threshold=1.5)

        short_setups = [s for s in setups if s.direction == "short"]
        assert len(short_setups) == 1, "Expected exactly one short setup"

        selected_target = short_setups[0].target

        # Reconstruct entry and risk from the bars/setup (mirror of the long
        # test; for shorts risk = stop above entry).
        entry_price = bars[-1].close
        risk = short_setups[0].stop_loss - short_setups[0].entry_price

        # Compute quality scores for all candidates that meet threshold
        best_quality = -1.0
        best_target = None
        for lv in levels:
            distance = entry_price - lv["price"]
            if distance > 0:
                rr = distance / risk
                if rr >= 1.5:
                    quality = _compute_quality_score(rr, lv["strength"], distance, entry_price)
                    if quality > best_quality:
                        best_quality = quality
                        best_target = round(lv["price"], 4)

        assert best_target is not None, "At least one candidate should meet threshold"
        assert selected_target == pytest.approx(best_target, abs=0.01), (
            f"Selected target {selected_target} != expected best-quality target "
            f"{best_target} (quality={best_quality:.4f})"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deterministic test: 3 levels with known quality scores (long)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
async def test_deterministic_long_three_levels(scan_session: AsyncSession):
    """**Validates: Requirements 2.1, 2.3, 2.4**

    Concrete example with 3 resistance levels of hand-computed quality
    scores (entry=100, ATR≈2, risk≈3):

    A: price=105, strength=90 → rr=5/3≈1.67, dist=5
       quality = 0.35*(1.67/10) + 0.35*(90/100) + 0.30*(1-5/100)
               = 0.0585 + 0.315 + 0.285 = 0.6585
    B: price=112, strength=50 → rr=12/3=4.0, dist=12
       quality = 0.14 + 0.175 + 0.264 = 0.579
    C: price=130, strength=10 → rr=30/3=10.0, dist=30
       quality = 0.35 + 0.035 + 0.21 = 0.595

    Expected winner: Level A (quality=0.6585), even though C has the
    largest R:R.
    """
    ticker = Ticker(symbol="DET3L")
    scan_session.add(ticker)
    await scan_session.flush()

    scan_session.add_all(_make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0))

    resistance_specs = ((105.0, 90), (112.0, 50), (130.0, 10))
    scan_session.add_all([
        SRLevel(
            ticker_id=ticker.id,
            price_level=price,
            type="resistance",
            strength=strength,
            detection_method="volume_profile",
        )
        for price, strength in resistance_specs
    ])
    await scan_session.flush()

    setups = await scan_ticker(scan_session, "DET3L", rr_threshold=1.5)

    long_setups = [s for s in setups if s.direction == "long"]
    assert len(long_setups) == 1, "Expected exactly one long setup"

    # Level A (105, strength=90) should win with highest quality
    assert long_setups[0].target == pytest.approx(105.0, abs=0.01), (
        f"Expected target=105.0 (highest quality), got {long_setups[0].target}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deterministic test: 3 levels with known quality scores (short)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
async def test_deterministic_short_three_levels(scan_session: AsyncSession):
    """**Validates: Requirements 2.2, 2.3, 2.4**

    Concrete example with 3 support levels of hand-computed quality
    scores (entry=100, ATR≈2, risk≈3):

    A: price=95, strength=85 → rr=5/3≈1.67, dist=5
       quality = 0.0585 + 0.2975 + 0.285 = 0.641
    B: price=88, strength=45 → rr=12/3=4.0, dist=12
       quality = 0.14 + 0.1575 + 0.264 = 0.5615
    C: price=70, strength=8 → rr=30/3=10.0, dist=30
       quality = 0.35 + 0.028 + 0.21 = 0.588

    Expected winner: Level A (quality=0.641), even though C has the
    largest R:R.
    """
    ticker = Ticker(symbol="DET3S")
    scan_session.add(ticker)
    await scan_session.flush()

    scan_session.add_all(_make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0))

    support_specs = ((95.0, 85), (88.0, 45), (70.0, 8))
    scan_session.add_all([
        SRLevel(
            ticker_id=ticker.id,
            price_level=price,
            type="support",
            strength=strength,
            detection_method="pivot_point",
        )
        for price, strength in support_specs
    ])
    await scan_session.flush()

    setups = await scan_ticker(scan_session, "DET3S", rr_threshold=1.5)

    short_setups = [s for s in setups if s.direction == "short"]
    assert len(short_setups) == 1, "Expected exactly one short setup"

    # Level A (95, strength=85) should win with highest quality
    assert short_setups[0].target == pytest.approx(95.0, abs=0.01), (
        f"Expected target=95.0 (highest quality), got {short_setups[0].target}"
    )
|
||||
259
tests/unit/test_rr_scanner_integration.py
Normal file
259
tests/unit/test_rr_scanner_integration.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""Integration tests for R:R scanner full flow with quality-based target selection.
|
||||
|
||||
Verifies the complete scan_ticker pipeline: quality-based S/R level selection,
|
||||
correct TradeSetup field population, and database persistence.
|
||||
|
||||
**Validates: Requirements 2.1, 2.2, 2.3, 2.4, 3.4**
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.ohlcv import OHLCVRecord
|
||||
from app.models.score import CompositeScore
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.ticker import Ticker
|
||||
from app.models.trade_setup import TradeSetup
|
||||
from app.services.rr_scanner_service import scan_ticker, _compute_quality_score
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
async def scan_session() -> AsyncSession:
    """Yield an un-wrapped session; scan_ticker commits on its own."""
    from tests.conftest import _test_session_factory

    async with _test_session_factory() as db:
        yield db
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_ohlcv_bars(
    ticker_id: int,
    num_bars: int = 20,
    base_close: float = 100.0,
) -> list[OHLCVRecord]:
    """Generate OHLCV bars closing around base_close with ATR ≈ 2.0."""
    start = date(2024, 1, 1)
    records: list[OHLCVRecord] = []
    for offset in range(num_bars):
        # Close wobbles -0.5 / 0.0 / +0.5 around base_close on a 3-day cycle.
        px = base_close + (offset % 3 - 1) * 0.5
        record = OHLCVRecord(
            ticker_id=ticker_id,
            date=start + timedelta(days=offset),
            open=px - 0.3,
            high=px + 1.0,
            low=px - 1.0,
            close=px,
            volume=100_000,
        )
        records.append(record)
    return records
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 8.1 Integration test: full scan_ticker flow with quality-based selection,
|
||||
# correct TradeSetup fields, and database persistence
|
||||
# ===========================================================================
|
||||
|
||||
@pytest.mark.asyncio
async def test_scan_ticker_full_flow_quality_selection_and_persistence(
    scan_session: AsyncSession,
) -> None:
    """Integration test for the complete scan_ticker pipeline.

    Scenario:
    - Entry ≈ 100, ATR ≈ 2.0, risk ≈ 3.0 (atr_multiplier=1.5)
    - 3 resistance levels above (long candidates):
        A: price=105, strength=90 (strong, near) → highest quality
        B: price=115, strength=40 (medium, mid)
        C: price=135, strength=5  (weak, far)
    - 3 support levels below (short candidates):
        D: price=95, strength=85 (strong, near) → highest quality
        E: price=85, strength=35 (medium, mid)
        F: price=65, strength=8  (weak, far)
    - CompositeScore: 72.5

    Verifies:
    1. Both long and short setups are produced
    2. Long target = Level A (highest quality, not most distant)
    3. Short target = Level D (highest quality, not most distant)
    4. All TradeSetup fields are correct and rounded to 4 decimals
    5. rr_ratio is the actual R:R of the selected level
    6. Old setups are deleted, new ones persisted
    """
    # -- Setup: create ticker --
    ticker = Ticker(symbol="INTEG")
    scan_session.add(ticker)
    # flush so ticker.id is populated before building dependent rows
    await scan_session.flush()

    # -- Setup: OHLCV bars (20 bars, close ≈ 100, ATR ≈ 2.0) --
    bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
    scan_session.add_all(bars)

    # -- Setup: S/R levels --
    sr_levels = [
        # Long candidates (resistance above entry)
        SRLevel(ticker_id=ticker.id, price_level=105.0, type="resistance",
                strength=90, detection_method="volume_profile"),
        SRLevel(ticker_id=ticker.id, price_level=115.0, type="resistance",
                strength=40, detection_method="volume_profile"),
        SRLevel(ticker_id=ticker.id, price_level=135.0, type="resistance",
                strength=5, detection_method="pivot_point"),
        # Short candidates (support below entry)
        SRLevel(ticker_id=ticker.id, price_level=95.0, type="support",
                strength=85, detection_method="volume_profile"),
        SRLevel(ticker_id=ticker.id, price_level=85.0, type="support",
                strength=35, detection_method="pivot_point"),
        SRLevel(ticker_id=ticker.id, price_level=65.0, type="support",
                strength=8, detection_method="volume_profile"),
    ]
    scan_session.add_all(sr_levels)

    # -- Setup: CompositeScore --
    comp = CompositeScore(
        ticker_id=ticker.id,
        score=72.5,
        is_stale=False,
        weights_json="{}",
        computed_at=datetime.now(timezone.utc),
    )
    scan_session.add(comp)

    # -- Setup: dummy old setups that should be deleted --
    old_setup = TradeSetup(
        ticker_id=ticker.id,
        direction="long",
        entry_price=99.0,
        stop_loss=96.0,
        target=120.0,
        rr_ratio=7.0,
        composite_score=50.0,
        detected_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
    )
    scan_session.add(old_setup)
    await scan_session.commit()

    # Verify old setup exists before scan
    pre_result = await scan_session.execute(
        select(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
    )
    pre_setups = list(pre_result.scalars().all())
    assert len(pre_setups) == 1, "Dummy old setup should exist before scan"

    # -- Act: run scan_ticker --
    setups = await scan_ticker(scan_session, "INTEG", rr_threshold=1.5, atr_multiplier=1.5)

    # -- Assert: both directions produced --
    assert len(setups) == 2, f"Expected 2 setups (long + short), got {len(setups)}"

    long_setups = [s for s in setups if s.direction == "long"]
    short_setups = [s for s in setups if s.direction == "short"]
    assert len(long_setups) == 1, f"Expected 1 long setup, got {len(long_setups)}"
    assert len(short_setups) == 1, f"Expected 1 short setup, got {len(short_setups)}"

    long_setup = long_setups[0]
    short_setup = short_setups[0]

    # -- Assert: long target is Level A (highest quality, not most distant) --
    # Level A: price=105 (strong, near) should beat Level C: price=135 (weak, far)
    assert long_setup.target == pytest.approx(105.0, abs=0.01), (
        f"Long target should be 105.0 (highest quality), got {long_setup.target}"
    )

    # -- Assert: short target is Level D (highest quality, not most distant) --
    # Level D: price=95 (strong, near) should beat Level F: price=65 (weak, far)
    assert short_setup.target == pytest.approx(95.0, abs=0.01), (
        f"Short target should be 95.0 (highest quality), got {short_setup.target}"
    )

    # -- Assert: entry_price is the last close (≈ 100) --
    # Last bar: index 19, close = 100 + (19 % 3 - 1) * 0.5 = 100 + 0*0.5 = 100.0
    expected_entry = 100.0
    assert long_setup.entry_price == pytest.approx(expected_entry, abs=0.5)
    assert short_setup.entry_price == pytest.approx(expected_entry, abs=0.5)

    # NOTE(review): 'entry' is assigned but not referenced below — the R:R
    # checks derive risk directly from the setup fields instead.
    entry = long_setup.entry_price  # actual entry for R:R calculations

    # -- Assert: stop_loss values --
    # ATR ≈ 2.0, risk = ATR × 1.5 = 3.0
    # Long stop = entry - risk, Short stop = entry + risk
    risk = long_setup.entry_price - long_setup.stop_loss
    assert risk > 0, "Long risk must be positive"
    assert short_setup.stop_loss > short_setup.entry_price, "Short stop must be above entry"

    # -- Assert: rr_ratio is the actual R:R of the selected level --
    long_reward = long_setup.target - long_setup.entry_price
    long_expected_rr = round(long_reward / risk, 4)
    assert long_setup.rr_ratio == pytest.approx(long_expected_rr, abs=0.01), (
        f"Long rr_ratio should be actual R:R={long_expected_rr}, got {long_setup.rr_ratio}"
    )

    short_risk = short_setup.stop_loss - short_setup.entry_price
    short_reward = short_setup.entry_price - short_setup.target
    short_expected_rr = round(short_reward / short_risk, 4)
    assert short_setup.rr_ratio == pytest.approx(short_expected_rr, abs=0.01), (
        f"Short rr_ratio should be actual R:R={short_expected_rr}, got {short_setup.rr_ratio}"
    )

    # -- Assert: composite_score matches --
    assert long_setup.composite_score == pytest.approx(72.5, abs=0.01)
    assert short_setup.composite_score == pytest.approx(72.5, abs=0.01)

    # -- Assert: ticker_id is correct --
    assert long_setup.ticker_id == ticker.id
    assert short_setup.ticker_id == ticker.id

    # -- Assert: detected_at is set --
    assert long_setup.detected_at is not None
    assert short_setup.detected_at is not None

    # -- Assert: fields are rounded to 4 decimal places --
    for setup in [long_setup, short_setup]:
        for field_name in ("entry_price", "stop_loss", "target", "rr_ratio", "composite_score"):
            val = getattr(setup, field_name)
            rounded = round(val, 4)
            assert val == pytest.approx(rounded, abs=1e-6), (
                f"{setup.direction} {field_name}={val} not rounded to 4 decimals"
            )

    # -- Assert: database persistence --
    # Old dummy setup should be gone, only the 2 new setups should exist
    db_result = await scan_session.execute(
        select(TradeSetup).where(TradeSetup.ticker_id == ticker.id)
    )
    persisted = list(db_result.scalars().all())
    assert len(persisted) == 2, (
        f"Expected 2 persisted setups (old deleted), got {len(persisted)}"
    )

    persisted_directions = sorted(s.direction for s in persisted)
    assert persisted_directions == ["long", "short"], (
        f"Expected ['long', 'short'] persisted, got {persisted_directions}"
    )

    # Verify persisted records match returned setups
    persisted_long = [s for s in persisted if s.direction == "long"][0]
    persisted_short = [s for s in persisted if s.direction == "short"][0]

    assert persisted_long.target == long_setup.target
    assert persisted_long.rr_ratio == long_setup.rr_ratio
    assert persisted_long.entry_price == long_setup.entry_price
    assert persisted_long.stop_loss == long_setup.stop_loss
    assert persisted_long.composite_score == long_setup.composite_score

    assert persisted_short.target == short_setup.target
    assert persisted_short.rr_ratio == short_setup.rr_ratio
    assert persisted_short.entry_price == short_setup.entry_price
    assert persisted_short.stop_loss == short_setup.stop_loss
    assert persisted_short.composite_score == short_setup.composite_score
|
||||
433
tests/unit/test_rr_scanner_preservation.py
Normal file
433
tests/unit/test_rr_scanner_preservation.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Preservation tests for R:R scanner target quality bugfix.
|
||||
|
||||
Verify that the fix does NOT change behavior for zero-candidate and
|
||||
single-candidate scenarios, and that get_trade_setups sorting is unchanged.
|
||||
|
||||
The fix only changes selection logic when MULTIPLE candidates exist.
|
||||
Zero-candidate and single-candidate scenarios must produce identical results.
|
||||
|
||||
**Validates: Requirements 3.1, 3.2, 3.3, 3.5**
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings, HealthCheck, strategies as st
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.ohlcv import OHLCVRecord
|
||||
from app.models.sr_level import SRLevel
|
||||
from app.models.ticker import Ticker
|
||||
from app.models.trade_setup import TradeSetup
|
||||
from app.models.score import CompositeScore
|
||||
from app.services.rr_scanner_service import scan_ticker, get_trade_setups
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
async def scan_session() -> AsyncSession:
    """Hand out a bare session; scan_ticker performs its own commit."""
    from tests.conftest import _test_session_factory

    async with _test_session_factory() as sess:
        yield sess
|
||||
|
||||
|
||||
@pytest.fixture
async def db_session() -> AsyncSession:
    """Provide a transactional DB session for get_trade_setups tests.

    Nothing the test writes is ever committed: the session's implicit
    (autobegin) transaction is rolled back in teardown, even if the test
    fails.

    Bug fixed: the original wrapped the yield in ``session.begin()`` and
    called ``session.rollback()`` *after* that context exited — but the
    ``begin()`` context commits on normal exit, so the rollback ran on an
    already-committed transaction and had no effect.
    """
    from tests.conftest import _test_session_factory

    async with _test_session_factory() as session:
        try:
            yield session
        finally:
            # Discard everything the test did; guarantees isolation.
            await session.rollback()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_ohlcv_bars(
    ticker_id: int,
    num_bars: int = 20,
    base_close: float = 100.0,
) -> list[OHLCVRecord]:
    """Build ``num_bars`` daily OHLCV rows oscillating around ``base_close``.

    Closes cycle through base-0.5, base, base+0.5; highs and lows bracket the
    close by ±1.0, giving an ATR of roughly 2.0 over the series.
    """
    first_day = date(2024, 1, 1)
    return [
        OHLCVRecord(
            ticker_id=ticker_id,
            date=first_day + timedelta(days=offset),
            open=px - 0.3,
            high=px + 1.0,
            low=px - 1.0,
            close=px,
            volume=100_000,
        )
        for offset, px in (
            (i, base_close + (i % 3 - 1) * 0.5) for i in range(num_bars)
        )
    ]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 7.1 [PBT-preservation] Property test: zero-candidate and single-candidate
|
||||
# scenarios produce the same output as the original code.
|
||||
# ===========================================================================
|
||||
|
||||
@st.composite
def zero_candidate_scenario(draw: st.DrawFn) -> dict:
    """Generate a scenario in which scan_ticker finds no candidate targets.

    Two variants are produced:

    - ``no_levels``: the ticker has no S/R levels at all.
    - ``below_threshold``: every level sits so close to entry (96.5–103.5
      band around an entry of ~100 with risk ≈ 3) that R:R stays below 1.5
      in both directions.

    scan_ticker ignores the level ``type`` and only compares ``price_level``
    against entry, so level types are drawn freely here.
    """
    variant = draw(st.sampled_from(["no_levels", "below_threshold"]))

    if variant == "no_levels":
        return {"variant": variant, "levels": []}

    # below_threshold: cluster levels near entry so that neither direction
    # clears the 1.5 R:R bar (risk ≈ 3 → reward must stay under 4.5, i.e.
    # long targets < 104.5 and short targets > 95.5).
    def _draw_level(lo: float, hi: float) -> dict:
        return {
            "price": draw(st.floats(min_value=lo, max_value=hi)),
            "type": draw(st.sampled_from(["resistance", "support"])),
            "strength": draw(st.integers(min_value=10, max_value=100)),
        }

    levels = [
        _draw_level(100.5, 103.5)
        for _ in range(draw(st.integers(min_value=1, max_value=3)))
    ]
    levels += [
        _draw_level(96.5, 99.5)
        for _ in range(draw(st.integers(min_value=0, max_value=2)))
    ]
    return {"variant": variant, "levels": levels}
|
||||
|
||||
|
||||
@st.composite
def single_candidate_scenario(draw: st.DrawFn) -> dict:
    """Generate exactly one S/R level that clears the 1.5 R:R threshold.

    With entry ≈ 100 and risk ≈ 3, a long target needs price ≥ 104.5 and a
    short target needs price ≤ 95.5; the drawn ranges stay safely beyond
    those bounds (≥ 105 for longs, ≤ 95 for shorts).
    """
    direction = draw(st.sampled_from(["long", "short"]))

    if direction == "long":
        lo, hi, level_type = 105.0, 150.0, "resistance"
    else:
        lo, hi, level_type = 50.0, 95.0, "support"

    return {
        "direction": direction,
        "level": {
            "price": draw(st.floats(min_value=lo, max_value=hi)),
            "type": level_type,
            "strength": draw(st.integers(min_value=1, max_value=100)),
        },
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@given(scenario=zero_candidate_scenario())
@settings(
    max_examples=15,
    deadline=None,
    # @given re-invokes the body many times against one fixture instance, so
    # the function-scoped-fixture health check must be suppressed.
    suppress_health_check=[HealthCheck.function_scoped_fixture],
)
async def test_property_zero_candidates_produce_no_setup(
    scenario: dict,
    scan_session: AsyncSession,
):
    """**Validates: Requirements 3.1, 3.2**

    Property: when zero candidate S/R levels exist (no levels, wrong side,
    or below threshold), scan_ticker produces no setup — unchanged from
    original behavior.
    """
    # NOTE(review): scan_session is requested but never used in the body —
    # presumably retained for its setup side effects; confirm it is needed.
    from tests.conftest import _test_engine, _test_session_factory
    from app.database import Base

    # Rebuild the schema per example so rows from earlier hypothesis
    # examples never leak into this one.
    async with _test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
        await conn.run_sync(Base.metadata.create_all)

    async with _test_session_factory() as session:
        ticker = Ticker(symbol="PRSV0")
        session.add(ticker)
        await session.flush()  # materialize ticker.id for the child rows

        bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
        session.add_all(bars)

        for lv_data in scenario.get("levels", []):
            session.add(SRLevel(
                ticker_id=ticker.id,
                price_level=lv_data["price"],
                type=lv_data["type"],
                strength=lv_data["strength"],
                detection_method="volume_profile",
            ))
        await session.commit()

        setups = await scan_ticker(session, "PRSV0", rr_threshold=1.5)

        assert setups == [], (
            f"Expected no setups for zero-candidate scenario "
            f"(variant={scenario.get('variant', 'unknown')}), got {len(setups)}"
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@given(scenario=single_candidate_scenario())
@settings(
    max_examples=15,
    deadline=None,
    # @given re-invokes the body many times against one fixture instance, so
    # the function-scoped-fixture health check must be suppressed.
    suppress_health_check=[HealthCheck.function_scoped_fixture],
)
async def test_property_single_candidate_selected_unchanged(
    scenario: dict,
    scan_session: AsyncSession,
):
    """**Validates: Requirements 3.3**

    Property: when exactly one candidate S/R level meets the R:R threshold,
    scan_ticker selects it — same as the original code would.
    """
    # NOTE(review): scan_session is requested but never used in the body —
    # presumably retained for its setup side effects; confirm it is needed.
    from tests.conftest import _test_engine, _test_session_factory
    from app.database import Base

    # Rebuild the schema per example so rows from earlier hypothesis
    # examples never leak into this one.
    async with _test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
        await conn.run_sync(Base.metadata.create_all)

    async with _test_session_factory() as session:
        ticker = Ticker(symbol="PRSV1")
        session.add(ticker)
        await session.flush()  # materialize ticker.id for the child rows

        bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
        session.add_all(bars)

        lv = scenario["level"]
        session.add(SRLevel(
            ticker_id=ticker.id,
            price_level=lv["price"],
            type=lv["type"],
            strength=lv["strength"],
            detection_method="volume_profile",
        ))
        await session.commit()

        setups = await scan_ticker(session, "PRSV1", rr_threshold=1.5)

        direction = scenario["direction"]
        dir_setups = [s for s in setups if s.direction == direction]
        assert len(dir_setups) == 1, (
            f"Expected exactly one {direction} setup for single candidate, "
            f"got {len(dir_setups)}"
        )

        # The setup's target is rounded, so compare with a small tolerance.
        selected_target = dir_setups[0].target
        expected_target = round(lv["price"], 4)
        assert selected_target == pytest.approx(expected_target, abs=0.01), (
            f"Single candidate: expected target={expected_target}, "
            f"got {selected_target}"
        )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 7.2 Unit test: no S/R levels → no setup produced
|
||||
# ===========================================================================
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_sr_levels_produces_no_setup(scan_session: AsyncSession):
    """**Validates: Requirements 3.1**

    A ticker with price history but zero S/R levels must yield no setups.
    """
    ticker = Ticker(symbol="NOSRL")
    scan_session.add(ticker)
    await scan_session.flush()

    scan_session.add_all(
        _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
    )
    await scan_session.flush()

    setups = await scan_ticker(scan_session, "NOSRL", rr_threshold=1.5)

    assert setups == [], (
        f"Expected no setups when no SR levels exist, got {len(setups)}"
    )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 7.3 Unit test: single candidate meets threshold → selected
|
||||
# ===========================================================================
|
||||
|
||||
@pytest.mark.asyncio
async def test_single_resistance_above_threshold_selected(scan_session: AsyncSession):
    """**Validates: Requirements 3.3**

    When exactly one resistance level above entry meets the R:R threshold,
    it should be selected as the long setup target.

    Entry ≈ 100, ATR ≈ 2, risk ≈ 3. Resistance at 110 → R:R ≈ 3.33 (>= 1.5).
    """
    ticker = Ticker(symbol="SINGL")
    scan_session.add(ticker)
    await scan_session.flush()  # materialize ticker.id for the child rows

    bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
    scan_session.add_all(bars)

    # The sole candidate: a resistance comfortably above the ~104.5
    # break-even price for a 1.5 R:R with risk ≈ 3.
    level = SRLevel(
        ticker_id=ticker.id,
        price_level=110.0,
        type="resistance",
        strength=60,
        detection_method="volume_profile",
    )
    scan_session.add(level)
    await scan_session.flush()

    setups = await scan_ticker(scan_session, "SINGL", rr_threshold=1.5)

    long_setups = [s for s in setups if s.direction == "long"]
    assert len(long_setups) == 1, (
        f"Expected exactly one long setup, got {len(long_setups)}"
    )
    assert long_setups[0].target == pytest.approx(110.0, abs=0.01), (
        f"Expected target=110.0, got {long_setups[0].target}"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_single_support_below_threshold_selected(scan_session: AsyncSession):
    """**Validates: Requirements 3.3**

    When exactly one support level below entry meets the R:R threshold,
    it should be selected as the short setup target.

    Entry ≈ 100, ATR ≈ 2, risk ≈ 3. Support at 90 → R:R ≈ 3.33 (>= 1.5).
    """
    ticker = Ticker(symbol="SINGS")
    scan_session.add(ticker)
    await scan_session.flush()  # materialize ticker.id for the child rows

    bars = _make_ohlcv_bars(ticker.id, num_bars=20, base_close=100.0)
    scan_session.add_all(bars)

    # The sole candidate: a support comfortably below the ~95.5 break-even
    # price for a 1.5 R:R short with risk ≈ 3.
    level = SRLevel(
        ticker_id=ticker.id,
        price_level=90.0,
        type="support",
        strength=55,
        detection_method="pivot_point",
    )
    scan_session.add(level)
    await scan_session.flush()

    setups = await scan_ticker(scan_session, "SINGS", rr_threshold=1.5)

    short_setups = [s for s in setups if s.direction == "short"]
    assert len(short_setups) == 1, (
        f"Expected exactly one short setup, got {len(short_setups)}"
    )
    assert short_setups[0].target == pytest.approx(90.0, abs=0.01), (
        f"Expected target=90.0, got {short_setups[0].target}"
    )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 7.4 Unit test: get_trade_setups sorting is unchanged (R:R desc, composite desc)
|
||||
# ===========================================================================
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_trade_setups_sorting_rr_desc_composite_desc(db_session: AsyncSession):
    """**Validates: Requirements 3.5**

    get_trade_setups must return results sorted by rr_ratio descending,
    with composite_score descending as the secondary sort key.
    """
    now = datetime.now(timezone.utc)

    # Create tickers for each setup
    ticker_a = Ticker(symbol="SORTA")
    ticker_b = Ticker(symbol="SORTB")
    ticker_c = Ticker(symbol="SORTC")
    ticker_d = Ticker(symbol="SORTD")
    db_session.add_all([ticker_a, ticker_b, ticker_c, ticker_d])
    await db_session.flush()  # assign ids before wiring setups to tickers

    # Create setups with different rr_ratio and composite_score values.
    # Expected order: D (rr=5.0), C (rr=3.0, comp=80), B (rr=3.0, comp=50), A (rr=1.5)
    # Note A's high composite_score (90) must NOT promote it past higher-R:R rows.
    setup_a = TradeSetup(
        ticker_id=ticker_a.id, direction="long",
        entry_price=100.0, stop_loss=97.0, target=104.5,
        rr_ratio=1.5, composite_score=90.0, detected_at=now,
    )
    setup_b = TradeSetup(
        ticker_id=ticker_b.id, direction="long",
        entry_price=100.0, stop_loss=97.0, target=109.0,
        rr_ratio=3.0, composite_score=50.0, detected_at=now,
    )
    setup_c = TradeSetup(
        ticker_id=ticker_c.id, direction="short",
        entry_price=100.0, stop_loss=103.0, target=91.0,
        rr_ratio=3.0, composite_score=80.0, detected_at=now,
    )
    setup_d = TradeSetup(
        ticker_id=ticker_d.id, direction="long",
        entry_price=100.0, stop_loss=97.0, target=115.0,
        rr_ratio=5.0, composite_score=30.0, detected_at=now,
    )
    db_session.add_all([setup_a, setup_b, setup_c, setup_d])
    await db_session.flush()

    results = await get_trade_setups(db_session)

    assert len(results) == 4, f"Expected 4 setups, got {len(results)}"

    # Verify ordering: rr_ratio desc, then composite_score desc
    rr_values = [r["rr_ratio"] for r in results]
    assert rr_values == [5.0, 3.0, 3.0, 1.5], (
        f"Expected rr_ratio order [5.0, 3.0, 3.0, 1.5], got {rr_values}"
    )

    # For the two setups with rr_ratio=3.0, composite_score should be desc
    tied_composites = [r["composite_score"] for r in results if r["rr_ratio"] == 3.0]
    assert tied_composites == [80.0, 50.0], (
        f"Expected composite_score order [80.0, 50.0] for tied R:R, "
        f"got {tied_composites}"
    )

    # Verify symbols match expected order
    symbols = [r["symbol"] for r in results]
    assert symbols == ["SORTD", "SORTC", "SORTB", "SORTA"], (
        f"Expected symbol order ['SORTD', 'SORTC', 'SORTB', 'SORTA'], "
        f"got {symbols}"
    )
|
||||
159
tests/unit/test_rr_scanner_quality_score.py
Normal file
159
tests/unit/test_rr_scanner_quality_score.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Unit tests for _compute_quality_score in rr_scanner_service."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.rr_scanner_service import _compute_quality_score
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4.1 — Known inputs with hand-computed expected outputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestKnownInputs:
    """Hand-computed expectations for _compute_quality_score (task 4.1)."""

    def test_typical_candidate(self):
        # Components: rr 5/10 = 0.5, strength 80/100 = 0.8,
        # proximity 1 - 3/100 = 0.97.
        # Weighted: 0.35*0.5 + 0.35*0.8 + 0.30*0.97 = 0.746
        score = _compute_quality_score(rr=5.0, strength=80, distance=3.0, entry_price=100.0)
        assert score == pytest.approx(0.746, abs=1e-9)

    def test_weak_distant_candidate(self):
        # Components: rr 8/10 = 0.8, strength 10/100 = 0.1,
        # proximity 1 - 50/100 = 0.5.
        # Weighted: 0.35*0.8 + 0.35*0.1 + 0.30*0.5 = 0.465
        score = _compute_quality_score(rr=8.0, strength=10, distance=50.0, entry_price=100.0)
        assert score == pytest.approx(0.465, abs=1e-9)

    def test_strong_near_candidate(self):
        # Components: rr 2/10 = 0.2, strength 95/100 = 0.95,
        # proximity 1 - 5/200 = 0.975.
        # Weighted: 0.35*0.2 + 0.35*0.95 + 0.30*0.975 = 0.695
        score = _compute_quality_score(rr=2.0, strength=95, distance=5.0, entry_price=200.0)
        assert score == pytest.approx(0.695, abs=1e-9)

    def test_custom_weights(self):
        # Components: 0.4 / 0.5 / 0.9 with custom weights 0.5/0.3/0.2.
        # Weighted: 0.5*0.4 + 0.3*0.5 + 0.2*0.9 = 0.53
        score = _compute_quality_score(
            rr=4.0, strength=50, distance=10.0, entry_price=100.0,
            w_rr=0.5, w_strength=0.3, w_proximity=0.2,
        )
        assert score == pytest.approx(0.53, abs=1e-9)

    def test_custom_rr_cap(self):
        # With rr_cap=5: rr 3/5 = 0.6, strength 0.6, proximity 1 - 8/100 = 0.92.
        # Weighted: 0.35*0.6 + 0.35*0.6 + 0.30*0.92 = 0.696
        score = _compute_quality_score(
            rr=3.0, strength=60, distance=8.0, entry_price=100.0, rr_cap=5.0,
        )
        assert score == pytest.approx(0.696, abs=1e-9)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4.2 — Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEdgeCases:
    """Boundary behaviour: strength extremes, zero distance, R:R capping."""

    def test_strength_zero(self):
        # strength=0 zeroes the strength term:
        # 0.35*0.5 + 0.35*0.0 + 0.30*0.9 = 0.445
        score = _compute_quality_score(rr=5.0, strength=0, distance=10.0, entry_price=100.0)
        assert score == pytest.approx(0.445, abs=1e-9)

    def test_strength_100(self):
        # strength=100 maxes the strength term:
        # 0.35*0.5 + 0.35*1.0 + 0.30*0.9 = 0.795
        score = _compute_quality_score(rr=5.0, strength=100, distance=10.0, entry_price=100.0)
        assert score == pytest.approx(0.795, abs=1e-9)

    def test_distance_zero(self):
        # distance=0 earns full proximity credit:
        # 0.35*0.5 + 0.35*0.5 + 0.30*1.0 = 0.65
        score = _compute_quality_score(rr=5.0, strength=50, distance=0.0, entry_price=100.0)
        assert score == pytest.approx(0.65, abs=1e-9)

    def test_rr_at_cap(self):
        # rr equal to the cap normalizes to exactly 1.0:
        # 0.35*1.0 + 0.35*0.5 + 0.30*0.9 = 0.795
        score = _compute_quality_score(rr=10.0, strength=50, distance=10.0, entry_price=100.0)
        assert score == pytest.approx(0.795, abs=1e-9)

    def test_rr_above_cap(self):
        # rr past the cap is clamped to 1.0, matching the at-cap case above.
        score = _compute_quality_score(rr=15.0, strength=50, distance=10.0, entry_price=100.0)
        assert score == pytest.approx(0.795, abs=1e-9)

    def test_rr_above_cap_equals_rr_at_cap(self):
        # Capping verified directly: rr=10 and rr=15 must score identically.
        capped = _compute_quality_score(rr=10.0, strength=50, distance=10.0, entry_price=100.0)
        beyond = _compute_quality_score(rr=15.0, strength=50, distance=10.0, entry_price=100.0)
        assert capped == pytest.approx(beyond, abs=1e-9)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4.3 — Normalized components stay in 0–1 range
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNormalizedComponentsRange:
    """Each normalized component, and the total score, must stay within [0, 1]."""

    @pytest.mark.parametrize("rr, rr_cap", [
        (0.0, 10.0), (5.0, 10.0), (10.0, 10.0), (15.0, 10.0),
        (100.0, 10.0), (3.0, 5.0), (7.0, 5.0),
    ])
    def test_norm_rr_in_range(self, rr, rr_cap):
        assert 0.0 <= min(rr / rr_cap, 1.0) <= 1.0

    @pytest.mark.parametrize("strength", [0, 1, 50, 99, 100])
    def test_norm_strength_in_range(self, strength):
        assert 0.0 <= strength / 100.0 <= 1.0

    @pytest.mark.parametrize("distance, entry_price", [
        (0.0, 100.0), (1.0, 100.0), (50.0, 100.0), (100.0, 100.0),
        (200.0, 100.0),  # distance beyond entry_price gets clamped
        (0.5, 50.0),
    ])
    def test_norm_proximity_in_range(self, distance, entry_price):
        assert 0.0 <= 1.0 - min(distance / entry_price, 1.0) <= 1.0

    @pytest.mark.parametrize("rr, strength, distance, entry_price", [
        (0.0, 0, 0.0, 100.0),       # every component at its minimum
        (10.0, 100, 0.0, 100.0),    # every component at its maximum
        (15.0, 100, 200.0, 100.0),  # rr over cap, distance past entry
        (5.0, 50, 50.0, 100.0),     # mid-range values
        (1.0, 1, 99.0, 100.0),      # near-minimum non-zero
    ])
    def test_total_score_in_zero_one(self, rr, strength, distance, entry_price):
        score = _compute_quality_score(rr=rr, strength=strength, distance=distance, entry_price=entry_price)
        assert 0.0 <= score <= 1.0
|
||||
205
tests/unit/test_scoring_service_get_score.py
Normal file
205
tests/unit/test_scoring_service_get_score.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Unit tests for get_score composite breakdown and dimension breakdown wiring."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.database import Base
|
||||
from app.models.ticker import Ticker
|
||||
from app.services.scoring_service import get_score, _DIMENSION_COMPUTERS
|
||||
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite://"
|
||||
|
||||
|
||||
@pytest.fixture
async def fresh_db():
    """Yield a commit-capable session backed by a throwaway in-memory DB.

    The schema is created before the session opens and torn down (with the
    engine disposed) after the test finishes.
    """
    engine = create_async_engine(TEST_DATABASE_URL, echo=False)
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    async with factory() as session:
        yield session

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
    await engine.dispose()
|
||||
|
||||
|
||||
def _make_ohlcv_records(n: int, base_close: float = 100.0) -> list:
|
||||
"""Create n mock OHLCV records with realistic price data."""
|
||||
records = []
|
||||
for i in range(n):
|
||||
price = base_close + (i * 0.5)
|
||||
records.append(
|
||||
SimpleNamespace(
|
||||
date=date(2024, 1, 1),
|
||||
open=price - 0.5,
|
||||
high=price + 1.0,
|
||||
low=price - 1.0,
|
||||
close=price,
|
||||
volume=1000000,
|
||||
)
|
||||
)
|
||||
return records
|
||||
|
||||
|
||||
def _mock_none_computer():
|
||||
"""Return an AsyncMock that returns (None, None) — simulates missing dimension data."""
|
||||
return AsyncMock(return_value=(None, None))
|
||||
|
||||
|
||||
def _mock_score_computer(score: float, breakdown: dict | None = None):
|
||||
"""Return an AsyncMock that returns a fixed (score, breakdown) tuple."""
|
||||
bd = breakdown or {
|
||||
"sub_scores": [{"name": "mock", "score": score, "weight": 1.0, "raw_value": score, "description": "mock"}],
|
||||
"formula": "mock formula",
|
||||
"unavailable": [],
|
||||
}
|
||||
return AsyncMock(return_value=(score, bd))
|
||||
|
||||
|
||||
async def _seed_ticker(session: AsyncSession, symbol: str = "AAPL") -> Ticker:
    """Persist a Ticker row for *symbol* and return it."""
    row = Ticker(symbol=symbol)
    session.add(row)
    await session.commit()
    return row
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_score_returns_composite_breakdown(fresh_db):
    """get_score should include a composite_breakdown dict with weights and re-normalization info."""
    await _seed_ticker(fresh_db, "AAPL")

    # Swap the registry entries for deterministic stand-ins, restoring the
    # originals afterwards.
    saved = dict(_DIMENSION_COMPUTERS)
    try:
        _DIMENSION_COMPUTERS.update({
            "technical": _mock_score_computer(70.0),
            "momentum": _mock_score_computer(60.0),
            "sentiment": _mock_none_computer(),
            "fundamental": _mock_none_computer(),
            "sr_quality": _mock_none_computer(),
        })
        result = await get_score(fresh_db, "AAPL")
    finally:
        _DIMENSION_COMPUTERS.update(saved)

    assert "composite_breakdown" in result
    cb = result["composite_breakdown"]
    assert cb is not None
    for key in ("weights", "available_dimensions", "missing_dimensions",
                "renormalized_weights", "formula"):
        assert key in cb
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_score_composite_breakdown_has_correct_available_missing(fresh_db):
    """Composite breakdown should correctly list available and missing dimensions."""
    await _seed_ticker(fresh_db, "AAPL")

    # Two dimensions score; the other three report no data.
    saved = dict(_DIMENSION_COMPUTERS)
    try:
        _DIMENSION_COMPUTERS.update({
            "technical": _mock_score_computer(70.0),
            "momentum": _mock_score_computer(60.0),
            "sentiment": _mock_none_computer(),
            "fundamental": _mock_none_computer(),
            "sr_quality": _mock_none_computer(),
        })
        result = await get_score(fresh_db, "AAPL")
    finally:
        _DIMENSION_COMPUTERS.update(saved)

    cb = result["composite_breakdown"]
    for dim in ("technical", "momentum"):
        assert dim in cb["available_dimensions"]
    for dim in ("sentiment", "fundamental", "sr_quality"):
        assert dim in cb["missing_dimensions"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_score_renormalized_weights_sum_to_one(fresh_db):
    """Re-normalized weights should sum to 1.0 when at least one dimension is available."""
    await _seed_ticker(fresh_db, "AAPL")

    saved = dict(_DIMENSION_COMPUTERS)
    try:
        _DIMENSION_COMPUTERS.update({
            "technical": _mock_score_computer(70.0),
            "momentum": _mock_score_computer(60.0),
            "sentiment": _mock_none_computer(),
            "fundamental": _mock_none_computer(),
            "sr_quality": _mock_none_computer(),
        })
        result = await get_score(fresh_db, "AAPL")
    finally:
        _DIMENSION_COMPUTERS.update(saved)

    weights = result["composite_breakdown"]["renormalized_weights"]
    assert weights
    assert abs(sum(weights.values()) - 1.0) < 1e-9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_score_dimensions_include_breakdowns(fresh_db):
    """Each available dimension entry should include a breakdown dict."""
    await _seed_ticker(fresh_db, "AAPL")

    # A fully-specified technical breakdown to verify it is passed through.
    tech_breakdown = {
        "sub_scores": [
            {"name": "ADX", "score": 72.0, "weight": 0.4, "raw_value": 72.0, "description": "ADX value"},
            {"name": "EMA", "score": 65.0, "weight": 0.3, "raw_value": 1.5, "description": "EMA diff"},
            {"name": "RSI", "score": 62.0, "weight": 0.3, "raw_value": 62.0, "description": "RSI value"},
        ],
        "formula": "Weighted average: 0.4*ADX + 0.3*EMA + 0.3*RSI",
        "unavailable": [],
    }

    saved = dict(_DIMENSION_COMPUTERS)
    try:
        _DIMENSION_COMPUTERS.update({
            "technical": _mock_score_computer(68.2, tech_breakdown),
            "momentum": _mock_score_computer(55.0),
            "sentiment": _mock_none_computer(),
            "fundamental": _mock_none_computer(),
            "sr_quality": _mock_none_computer(),
        })
        result = await get_score(fresh_db, "AAPL")
    finally:
        _DIMENSION_COMPUTERS.update(saved)

    tech_dim = next((d for d in result["dimensions"] if d["dimension"] == "technical"), None)
    assert tech_dim is not None
    assert "breakdown" in tech_dim
    assert tech_dim["breakdown"] is not None
    sub_scores = tech_dim["breakdown"]["sub_scores"]
    assert len(sub_scores) == 3
    assert {"ADX", "EMA", "RSI"} <= {s["name"] for s in sub_scores}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_score_all_dimensions_missing(fresh_db):
    """When all dimensions return None, composite_breakdown should list all as missing."""
    await _seed_ticker(fresh_db, "AAPL")

    saved = dict(_DIMENSION_COMPUTERS)
    try:
        for name in list(_DIMENSION_COMPUTERS):
            _DIMENSION_COMPUTERS[name] = _mock_none_computer()
        result = await get_score(fresh_db, "AAPL")
    finally:
        _DIMENSION_COMPUTERS.update(saved)

    cb = result["composite_breakdown"]
    assert cb["available_dimensions"] == []
    assert len(cb["missing_dimensions"]) == 5
    assert cb["renormalized_weights"] == {}
    assert result["composite_score"] is None
|
||||
150
tests/unit/test_scoring_service_momentum.py
Normal file
150
tests/unit/test_scoring_service_momentum.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Unit tests for _compute_momentum_score breakdown refactor."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.scoring_service import _compute_momentum_score
|
||||
|
||||
|
||||
def _make_ohlcv_records(n: int, base_close: float = 100.0) -> list:
|
||||
"""Create n mock OHLCV records with incrementing close prices."""
|
||||
records = []
|
||||
for i in range(n):
|
||||
price = base_close + (i * 0.5)
|
||||
records.append(
|
||||
SimpleNamespace(
|
||||
date=date(2024, 1, 1),
|
||||
open=price - 0.5,
|
||||
high=price + 1.0,
|
||||
low=price - 1.0,
|
||||
close=price,
|
||||
volume=1000000,
|
||||
)
|
||||
)
|
||||
return records
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_returns_none_none_when_no_records(db_session):
    """With no OHLCV history at all, the computer reports (None, None)."""
    with patch(
        "app.services.price_service.query_ohlcv",
        new_callable=AsyncMock,
        return_value=[],
    ):
        score, breakdown = await _compute_momentum_score(db_session, "AAPL")

    assert score is None
    assert breakdown is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_returns_none_none_when_fewer_than_6_records(db_session):
    """With fewer than 6 records, returns (None, None)."""
    too_few = _make_ohlcv_records(5)
    patcher = patch(
        "app.services.price_service.query_ohlcv",
        new_callable=AsyncMock,
        return_value=too_few,
    )
    with patcher:
        score, breakdown = await _compute_momentum_score(db_session, "AAPL")
    # 5 bars cannot support even the 5-day ROC, so nothing is computed.
    assert score is None
    assert breakdown is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_returns_only_5day_roc_when_fewer_than_21_records(db_session):
    """With 6-20 records, returns only 5-day ROC sub-score; 20-day ROC is unavailable."""
    bars = _make_ohlcv_records(10)
    patcher = patch(
        "app.services.price_service.query_ohlcv",
        new_callable=AsyncMock,
        return_value=bars,
    )
    with patcher:
        score, breakdown = await _compute_momentum_score(db_session, "AAPL")

    assert score is not None
    assert 0.0 <= score <= 100.0
    assert breakdown is not None

    present = {s["name"] for s in breakdown["sub_scores"]}
    assert "5-day ROC" in present
    assert "20-day ROC" not in present

    # The longer window must be reported as unavailable rather than dropped.
    missing = {u["name"] for u in breakdown["unavailable"]}
    assert "20-day ROC" in missing
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_returns_both_sub_scores_with_enough_data(db_session):
    """With 21+ records, returns both 5-day and 20-day ROC sub-scores."""
    bars = _make_ohlcv_records(30)
    patcher = patch(
        "app.services.price_service.query_ohlcv",
        new_callable=AsyncMock,
        return_value=bars,
    )
    with patcher:
        score, breakdown = await _compute_momentum_score(db_session, "AAPL")

    assert score is not None
    assert 0.0 <= score <= 100.0
    assert breakdown is not None

    subs = {s["name"]: s for s in breakdown["sub_scores"]}
    assert "5-day ROC" in subs
    assert "20-day ROC" in subs

    # The two ROC windows are weighted equally.
    assert subs["5-day ROC"]["weight"] == 0.5
    assert subs["20-day ROC"]["weight"] == 0.5

    # Every sub-score exposes a numeric raw value plus a description.
    for sub in subs.values():
        assert sub["raw_value"] is not None
        assert isinstance(sub["raw_value"], (int, float))
        assert sub["description"]

    assert breakdown["unavailable"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_formula_string_present(db_session):
    """Breakdown always includes the formula description."""
    bars = _make_ohlcv_records(30)
    with patch(
        "app.services.price_service.query_ohlcv",
        new_callable=AsyncMock,
        return_value=bars,
    ):
        _, breakdown = await _compute_momentum_score(db_session, "AAPL")

    assert "formula" in breakdown
    # Both ROC terms must be mentioned in the human-readable formula.
    for token in ("ROC_5", "ROC_20"):
        assert token in breakdown["formula"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_raw_values_are_roc_percentages(db_session):
    """Raw values should be ROC percentages matching the actual price change."""
    bars = _make_ohlcv_records(30, base_close=100.0)
    with patch(
        "app.services.price_service.query_ohlcv",
        new_callable=AsyncMock,
        return_value=bars,
    ):
        _, breakdown = await _compute_momentum_score(db_session, "AAPL")

    # Recompute the expected ROCs from the same synthetic close series.
    closes = [100.0 + i * 0.5 for i in range(30)]
    last = closes[-1]
    want_roc_5 = (last - closes[-6]) / closes[-6] * 100.0
    want_roc_20 = (last - closes[-21]) / closes[-21] * 100.0

    got = {s["name"]: s["raw_value"] for s in breakdown["sub_scores"]}
    assert abs(got["5-day ROC"] - round(want_roc_5, 4)) < 1e-6
    assert abs(got["20-day ROC"] - round(want_roc_20, 4)) < 1e-6
|
||||
138
tests/unit/test_scoring_service_sentiment.py
Normal file
138
tests/unit/test_scoring_service_sentiment.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Unit tests for _compute_sentiment_score breakdown refactor."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.scoring_service import _compute_sentiment_score
|
||||
|
||||
|
||||
def _make_sentiment_records(n: int, classification: str = "bullish", confidence: int = 80) -> list:
|
||||
"""Create n mock sentiment records with recent timestamps."""
|
||||
now = datetime.now(timezone.utc)
|
||||
records = []
|
||||
for i in range(n):
|
||||
records.append(
|
||||
SimpleNamespace(
|
||||
classification=classification,
|
||||
confidence=confidence,
|
||||
source="test",
|
||||
timestamp=now,
|
||||
reasoning="",
|
||||
citations_json="[]",
|
||||
)
|
||||
)
|
||||
return records
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_returns_none_with_breakdown_when_no_records(db_session):
    """When no sentiment records exist, returns (None, breakdown) with unavailable entry."""
    empty = patch(
        "app.services.sentiment_service.get_sentiment_scores",
        new_callable=AsyncMock,
        return_value=[],
    )
    with empty:
        score, breakdown = await _compute_sentiment_score(db_session, "AAPL")

    assert score is None
    assert breakdown is not None
    assert breakdown["sub_scores"] == []
    assert len(breakdown["unavailable"]) == 1
    # The single unavailable entry names the missing input.
    assert breakdown["unavailable"][0]["name"] == "sentiment_records"
    assert "formula" in breakdown
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_returns_none_none_when_get_scores_raises(db_session):
    """When get_sentiment_scores raises, returns (None, None)."""
    boom = patch(
        "app.services.sentiment_service.get_sentiment_scores",
        new_callable=AsyncMock,
        side_effect=Exception("DB error"),
    )
    with boom:
        score, breakdown = await _compute_sentiment_score(db_session, "AAPL")

    # A backend failure degrades to a fully-absent result, not an exception.
    assert score is None
    assert breakdown is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_returns_breakdown_with_sub_scores(db_session):
    """With sentiment records, returns score and breakdown with expected sub-scores."""
    records = _make_sentiment_records(3)

    scores_patch = patch(
        "app.services.sentiment_service.get_sentiment_scores",
        new_callable=AsyncMock,
        return_value=records,
    )
    dim_patch = patch(
        "app.services.sentiment_service.compute_sentiment_dimension_score",
        new_callable=AsyncMock,
        return_value=75.0,
    )
    with scores_patch, dim_patch:
        score, breakdown = await _compute_sentiment_score(db_session, "AAPL")

    assert score == 75.0
    assert breakdown is not None
    for key in ("sub_scores", "formula", "unavailable"):
        assert key in breakdown

    raw = {s["name"]: s["raw_value"] for s in breakdown["sub_scores"]}
    for expected in ("record_count", "decay_rate", "lookback_window"):
        assert expected in raw

    # Raw values reflect the fixture size plus the service's tuning constants.
    assert raw["record_count"] == 3
    assert raw["decay_rate"] == 0.1
    assert raw["lookback_window"] == 24

    assert breakdown["unavailable"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_formula_contains_decay_info(db_session):
    """Breakdown formula describes the time-decay weighted average."""
    records = _make_sentiment_records(2)

    scores_patch = patch(
        "app.services.sentiment_service.get_sentiment_scores",
        new_callable=AsyncMock,
        return_value=records,
    )
    dim_patch = patch(
        "app.services.sentiment_service.compute_sentiment_dimension_score",
        new_callable=AsyncMock,
        return_value=60.0,
    )
    with scores_patch, dim_patch:
        _, breakdown = await _compute_sentiment_score(db_session, "AAPL")

    # The formula text must mention the decay mechanism and the window size.
    for token in ("Time-decay", "decay_rate", "24"):
        assert token in breakdown["formula"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sub_scores_have_descriptions(db_session):
    """Each sub-score has a non-empty description."""
    records = _make_sentiment_records(1)

    scores_patch = patch(
        "app.services.sentiment_service.get_sentiment_scores",
        new_callable=AsyncMock,
        return_value=records,
    )
    dim_patch = patch(
        "app.services.sentiment_service.compute_sentiment_dimension_score",
        new_callable=AsyncMock,
        return_value=50.0,
    )
    with scores_patch, dim_patch:
        _, breakdown = await _compute_sentiment_score(db_session, "AAPL")

    for sub in breakdown["sub_scores"]:
        assert sub["description"], f"Sub-score {sub['name']} missing description"
|
||||
151
tests/unit/test_scoring_service_technical.py
Normal file
151
tests/unit/test_scoring_service_technical.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Unit tests for _compute_technical_score breakdown refactor."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.scoring_service import _compute_technical_score
|
||||
|
||||
|
||||
def _make_ohlcv_records(n: int, base_close: float = 100.0) -> list:
|
||||
"""Create n mock OHLCV records with realistic price data."""
|
||||
records = []
|
||||
for i in range(n):
|
||||
price = base_close + (i * 0.5)
|
||||
records.append(
|
||||
SimpleNamespace(
|
||||
date=date(2024, 1, 1),
|
||||
open=price - 0.5,
|
||||
high=price + 1.0,
|
||||
low=price - 1.0,
|
||||
close=price,
|
||||
volume=1000000,
|
||||
)
|
||||
)
|
||||
return records
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_returns_none_tuple_when_no_records(db_session):
    """When no OHLCV data exists, returns (None, None)."""
    no_data = patch(
        "app.services.price_service.query_ohlcv",
        new_callable=AsyncMock,
        return_value=[],
    )
    with no_data:
        score, breakdown = await _compute_technical_score(db_session, "AAPL")
    # Nothing to score means neither a score nor a breakdown.
    assert score is None
    assert breakdown is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_returns_breakdown_with_all_sub_scores(db_session):
    """With enough data, returns score and breakdown with ADX, EMA, RSI sub-scores."""
    bars = _make_ohlcv_records(50)

    patcher = patch(
        "app.services.price_service.query_ohlcv",
        new_callable=AsyncMock,
        return_value=bars,
    )
    with patcher:
        score, breakdown = await _compute_technical_score(db_session, "AAPL")

    assert score is not None
    assert 0.0 <= score <= 100.0
    assert breakdown is not None
    for key in ("sub_scores", "formula", "unavailable"):
        assert key in breakdown

    subs = {s["name"]: s for s in breakdown["sub_scores"]}
    assert "ADX" in subs
    assert "EMA" in subs
    assert "RSI" in subs

    # Weight split across the three indicators: 40/30/30.
    assert subs["ADX"]["weight"] == 0.4
    assert subs["EMA"]["weight"] == 0.3
    assert subs["RSI"]["weight"] == 0.3

    # Every sub-score exposes a numeric raw value plus a description.
    for sub in subs.values():
        assert sub["raw_value"] is not None
        assert isinstance(sub["raw_value"], (int, float))
        assert sub["description"]

    assert breakdown["unavailable"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_partial_sub_scores_with_insufficient_data(db_session):
    """With limited data (enough for EMA/RSI but not ADX), returns partial breakdown."""
    # 22 bars: enough for EMA(20) and RSI(14) but not ADX (needs 28)
    bars = _make_ohlcv_records(22)

    with patch(
        "app.services.price_service.query_ohlcv",
        new_callable=AsyncMock,
        return_value=bars,
    ):
        score, breakdown = await _compute_technical_score(db_session, "AAPL")

    assert score is not None
    assert breakdown is not None

    present = {s["name"] for s in breakdown["sub_scores"]}
    assert "EMA" in present
    assert "RSI" in present
    assert "ADX" not in present

    # ADX lands in the unavailable list with a human-readable reason.
    missing = [u["name"] for u in breakdown["unavailable"]]
    assert "ADX" in missing
    assert any(u["reason"] for u in breakdown["unavailable"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_all_sub_scores_unavailable(db_session):
    """With very few bars (not enough for any indicator), returns None score with breakdown."""
    # 5 bars: not enough for any indicator
    bars = _make_ohlcv_records(5)

    with patch(
        "app.services.price_service.query_ohlcv",
        new_callable=AsyncMock,
        return_value=bars,
    ):
        score, breakdown = await _compute_technical_score(db_session, "AAPL")

    assert score is None
    assert breakdown is not None
    assert breakdown["sub_scores"] == []
    assert len(breakdown["unavailable"]) == 3

    # All three indicators must be flagged, none silently dropped.
    missing = {u["name"] for u in breakdown["unavailable"]}
    assert missing == {"ADX", "EMA", "RSI"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_formula_string_present(db_session):
    """Breakdown always includes the formula description."""
    bars = _make_ohlcv_records(50)

    with patch(
        "app.services.price_service.query_ohlcv",
        new_callable=AsyncMock,
        return_value=bars,
    ):
        _, breakdown = await _compute_technical_score(db_session, "AAPL")

    assert "formula" in breakdown
    # All three indicator names must appear in the formula text.
    for token in ("ADX", "EMA", "RSI"):
        assert token in breakdown["formula"]
|
||||
243
tests/unit/test_sr_levels_router.py
Normal file
243
tests/unit/test_sr_levels_router.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""Unit tests for the S/R levels router — zone integration."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.middleware import register_exception_handlers
|
||||
from app.routers.sr_levels import router
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeLevel:
|
||||
"""Mimics an SRLevel ORM model."""
|
||||
|
||||
def __init__(self, id, price_level, type, strength, detection_method):
|
||||
self.id = id
|
||||
self.price_level = price_level
|
||||
self.type = type
|
||||
self.strength = strength
|
||||
self.detection_method = detection_method
|
||||
self.created_at = datetime(2024, 1, 1)
|
||||
|
||||
|
||||
class _FakeOHLCV:
    """Mimics an OHLCVRecord with a close attribute."""

    def __init__(self, close: float):
        # Only `close` is modeled: the router appears to read the latest
        # close as the current price for zone clustering — TODO confirm.
        self.close = close
|
||||
|
||||
|
||||
def _make_app() -> FastAPI:
    """Build a minimal FastAPI app exposing the SR-levels router.

    Exception handlers are registered as in production, while the auth
    dependency is replaced with a no-op and the DB session with a mock so
    the tests stay at unit scope.
    """
    from app.dependencies import require_access, get_db

    app = FastAPI()
    register_exception_handlers(app)
    app.include_router(router, prefix="/api/v1")

    app.dependency_overrides[require_access] = lambda: None
    app.dependency_overrides[get_db] = lambda: AsyncMock()
    return app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Shared fixture: two close-together supports (95/96 — clusterable into one
# zone) below a single strong resistance at 110.
SAMPLE_LEVELS = [
    _FakeLevel(1, 95.0, "support", 60, "volume_profile"),
    _FakeLevel(2, 96.0, "support", 40, "pivot_point"),
    _FakeLevel(3, 110.0, "resistance", 80, "merged"),
]


# Single bar; its close (100.0) presumably serves as the current price — sits
# between the support pair and the resistance level above.
SAMPLE_OHLCV = [_FakeOHLCV(100.0)]
|
||||
|
||||
|
||||
class TestSRLevelsRouterZones:
    """Tests for max_zones parameter and zone inclusion in response."""

    @staticmethod
    def _get(mock_get_sr, mock_ohlcv, url, levels=SAMPLE_LEVELS, ohlcv=SAMPLE_OHLCV):
        """Wire the mocks, build a fresh app, and issue a GET against it."""
        mock_get_sr.return_value = levels
        mock_ohlcv.return_value = ohlcv
        return TestClient(_make_app()).get(url)

    @patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
    @patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
    def test_default_max_zones_returns_zones(self, mock_get_sr, mock_ohlcv):
        resp = self._get(mock_get_sr, mock_ohlcv, "/api/v1/sr-levels/AAPL")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "success"
        payload = body["data"]
        assert isinstance(payload.get("zones"), list)
        # With default max_zones=6, we should get zones
        assert payload["zones"]

    @patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
    @patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
    def test_max_zones_zero_returns_empty_zones(self, mock_get_sr, mock_ohlcv):
        resp = self._get(mock_get_sr, mock_ohlcv, "/api/v1/sr-levels/AAPL?max_zones=0")

        assert resp.status_code == 200
        assert resp.json()["data"]["zones"] == []

    @patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
    @patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
    def test_max_zones_limits_zone_count(self, mock_get_sr, mock_ohlcv):
        resp = self._get(mock_get_sr, mock_ohlcv, "/api/v1/sr-levels/AAPL?max_zones=1")

        assert resp.status_code == 200
        assert len(resp.json()["data"]["zones"]) <= 1

    @patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
    @patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
    def test_no_ohlcv_data_returns_empty_zones(self, mock_get_sr, mock_ohlcv):
        # No price bars → no current price → no clustering.
        resp = self._get(mock_get_sr, mock_ohlcv, "/api/v1/sr-levels/AAPL", ohlcv=[])

        assert resp.status_code == 200
        payload = resp.json()["data"]
        assert payload["zones"] == []
        # Levels should still be present
        assert len(payload["levels"]) == 3

    @patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
    @patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
    def test_no_levels_returns_empty_zones(self, mock_get_sr, mock_ohlcv):
        resp = self._get(mock_get_sr, mock_ohlcv, "/api/v1/sr-levels/AAPL", levels=[])

        assert resp.status_code == 200
        payload = resp.json()["data"]
        assert payload["zones"] == []
        assert payload["levels"] == []
        assert payload["count"] == 0

    @patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
    @patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
    def test_zone_fields_present(self, mock_get_sr, mock_ohlcv):
        resp = self._get(mock_get_sr, mock_ohlcv, "/api/v1/sr-levels/AAPL")

        payload = resp.json()["data"]
        required = ("low", "high", "midpoint", "strength", "type", "level_count")
        for zone in payload["zones"]:
            for key in required:
                assert key in zone
            assert zone["type"] in ("support", "resistance")
|
||||
|
||||
|
||||
class TestSRLevelsRouterVisibleLevels:
    """Tests for visible_levels filtering in the SR levels response."""

    @staticmethod
    def _get(mock_get_sr, mock_ohlcv, url, ohlcv=SAMPLE_OHLCV):
        """Wire the mocks, build a fresh app, and issue a GET against it."""
        mock_get_sr.return_value = SAMPLE_LEVELS
        mock_ohlcv.return_value = ohlcv
        return TestClient(_make_app()).get(url)

    @patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
    @patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
    def test_visible_levels_present_in_response(self, mock_get_sr, mock_ohlcv):
        """visible_levels field is always present in the API response."""
        resp = self._get(mock_get_sr, mock_ohlcv, "/api/v1/sr-levels/AAPL")

        assert resp.status_code == 200
        payload = resp.json()["data"]
        assert isinstance(payload.get("visible_levels"), list)

    @patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
    @patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
    def test_visible_levels_within_zone_bounds(self, mock_get_sr, mock_ohlcv):
        """Every visible level has a price within at least one zone's [low, high] range."""
        resp = self._get(mock_get_sr, mock_ohlcv, "/api/v1/sr-levels/AAPL")

        payload = resp.json()["data"]
        zones = payload["zones"]
        visible = payload["visible_levels"]

        # When zones exist, each visible level must fall within a zone
        for lvl in visible:
            price = lvl["price_level"]
            assert any(
                z["low"] <= price <= z["high"] for z in zones
            ), f"visible level price {price} not within any zone bounds"

        # visible_levels must be a subset of levels (by id)
        known_ids = {entry["id"] for entry in payload["levels"]}
        for lvl in visible:
            assert lvl["id"] in known_ids

    @patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
    @patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
    def test_visible_levels_empty_when_no_ohlcv(self, mock_get_sr, mock_ohlcv):
        """visible_levels is empty when no OHLCV data exists (zones are empty)."""
        resp = self._get(mock_get_sr, mock_ohlcv, "/api/v1/sr-levels/AAPL", ohlcv=[])

        payload = resp.json()["data"]
        assert payload["zones"] == []
        assert payload["visible_levels"] == []

    @patch("app.routers.sr_levels.query_ohlcv", new_callable=AsyncMock)
    @patch("app.routers.sr_levels.get_sr_levels", new_callable=AsyncMock)
    def test_visible_levels_empty_when_max_zones_zero(self, mock_get_sr, mock_ohlcv):
        """visible_levels is empty when max_zones=0 (zones are empty)."""
        resp = self._get(mock_get_sr, mock_ohlcv, "/api/v1/sr-levels/AAPL?max_zones=0")

        payload = resp.json()["data"]
        assert payload["zones"] == []
        assert payload["visible_levels"] == []
|
||||
|
||||
Reference in New Issue
Block a user