Add local backtest snapshot runner
This commit is contained in:
@@ -0,0 +1,162 @@
|
||||
"""Create a minimal local SQLite snapshot for offline backtest research.
|
||||
|
||||
Copies only the data required by app.services.backtest_service.run_backtest:
|
||||
tickers, OHLCV bars, SPY benchmark closes, and activation/recommendation
|
||||
settings. Other system settings are intentionally skipped to avoid copying
|
||||
secrets into local snapshot files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import func, insert, or_, select
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
# Directory one level above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Put it at the front of sys.path (once) — presumably so the ``app.*`` imports
# used later resolve when this file is run directly as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
|
||||
|
||||
|
||||
def _normalize_postgres_url(url: str) -> str:
    """Return *url* rewritten to use the ``postgresql+asyncpg://`` scheme.

    ``postgresql://`` and ``postgres://`` prefixes are swapped for the asyncpg
    one; a URL that already uses it, or any other URL, is returned unchanged.
    """
    asyncpg_scheme = "postgresql+asyncpg://"
    if url.startswith(asyncpg_scheme):
        return url
    for plain_scheme in ("postgresql://", "postgres://"):
        if url.startswith(plain_scheme):
            return asyncpg_scheme + url[len(plain_scheme):]
    return url
|
||||
|
||||
|
||||
def _sqlite_url(path: Path) -> str:
    """Build the ``sqlite+aiosqlite`` URL for the resolved (absolute) form of *path*."""
    absolute = path.resolve()
    return "sqlite+aiosqlite:///" + absolute.as_posix()
|
||||
|
||||
|
||||
def _hide_password(url: str) -> str:
    """Return *url* rendered as a string with its password hidden.

    Lets a connection URL be shown (e.g. printed) without exposing credentials.
    """
    parsed = make_url(url)
    return parsed.render_as_string(hide_password=True)
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for the snapshot creator.

    Returns:
        Namespace with ``database_url`` (``None`` when neither the flag nor the
        ``DATABASE_URL`` environment variable is set), ``output``,
        ``batch_size`` and ``force``.

    Exits via ``parser.error`` (status 2) when ``--batch-size`` is not a
    positive integer.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--database-url",
        default=os.getenv("DATABASE_URL"),
        help="Source Postgres URL. Defaults to DATABASE_URL, then app .env database_url.",
    )
    parser.add_argument(
        "--output",
        default="backtest_snapshots/prod-backtest.sqlite",
        help="SQLite snapshot path to create.",
    )
    parser.add_argument("--batch-size", type=int, default=5000)
    parser.add_argument("--force", action="store_true", help="Overwrite an existing snapshot file.")
    args = parser.parse_args()
    # The batch size is used as the streaming/partition size of the copy loop,
    # so zero or negative values are rejected here instead of failing mid-copy.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    return args
|
||||
|
||||
|
||||
async def _copy_table(
    source: AsyncSession,
    dest: AsyncSession,
    model: type,
    *,
    batch_size: int,
    where=None,
) -> int:
    """Copy the rows of ``model.__table__`` from *source* into *dest*.

    Args:
        source: Session the rows are read from.
        dest: Session the rows are inserted into; committed after every batch.
        model: Mapped class whose ``__table__`` is copied.
        batch_size: Number of rows streamed and inserted per batch.
        where: Optional filter applied to both the count and the row query.

    Returns:
        The number of rows copied (0 when the count query finds none).
    """
    table = model.__table__
    row_query = select(*table.columns)
    total_query = select(func.count()).select_from(table)
    if where is not None:
        total_query = total_query.where(where)
        row_query = row_query.where(where)

    # Order by primary key, when the table has one, before streaming.
    pk_columns = list(table.primary_key.columns)
    if pk_columns:
        row_query = row_query.order_by(*pk_columns)

    # The up-front count feeds the progress display and lets empty tables
    # return before a stream is opened.
    expected = int((await source.execute(total_query)).scalar_one())
    if expected == 0:
        print(f"{table.name}: 0 rows")
        return 0

    copied = 0
    insert_stmt = insert(table)
    result = await source.stream(row_query.execution_options(yield_per=batch_size))
    async for chunk in result.partitions(batch_size):
        payload = [dict(record._mapping) for record in chunk]
        if payload:
            await dest.execute(insert_stmt, payload)
            await dest.commit()
            copied += len(payload)
            # Carriage return keeps the running progress on a single line.
            print(f"{table.name}: {copied}/{expected}", end="\r")

    print(f"{table.name}: {copied} rows")
    return copied
|
||||
|
||||
|
||||
async def _main() -> None:
    """Create the local SQLite snapshot from the source database.

    Steps: parse the CLI options, resolve the source URL (flag/environment
    first, then ``settings.database_url``), refuse to clobber an existing
    snapshot unless ``--force`` was given, create the schema in the new SQLite
    file, then copy tickers, the activation/recommendation system settings,
    benchmark prices and OHLCV records, printing per-table row counts.

    Raises:
        SystemExit: if the output file already exists and ``--force`` is not set.
    """
    args = _parse_args()

    # Project imports are deferred until after argument parsing.
    from app.config import settings
    from app.database import Base
    import app.models  # noqa: F401 - registers all metadata tables
    from app.models.benchmark_price import BenchmarkPrice
    from app.models.ohlcv import OHLCVRecord
    from app.models.settings import SystemSetting
    from app.models.ticker import Ticker

    source_url = _normalize_postgres_url(args.database_url or settings.database_url)
    output = Path(args.output)
    if output.exists():
        if not args.force:
            raise SystemExit(f"{output} already exists. Pass --force to overwrite it.")
        output.unlink()
    output.parent.mkdir(parents=True, exist_ok=True)

    source_engine = create_async_engine(
        source_url,
        pool_pre_ping=True,
        # Ask the server to default every transaction on this connection to
        # read-only, so the snapshot run cannot modify the source database.
        connect_args={"server_settings": {"default_transaction_read_only": "on"}},
    )
    dest_engine = create_async_engine(_sqlite_url(output))
    SourceSession = async_sessionmaker(source_engine, class_=AsyncSession, expire_on_commit=False)
    DestSession = async_sessionmaker(dest_engine, class_=AsyncSession, expire_on_commit=False)

    print(f"Source: {_hide_password(source_url)}")
    print(f"Snapshot: {output}")

    try:
        async with dest_engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        async with SourceSession() as source, DestSession() as dest:
            # Only settings whose key literally starts with one of these
            # prefixes are copied. ``autoescape=True`` matters: in a raw LIKE
            # pattern the trailing "_" is a single-character wildcard, which
            # would also admit keys such as "activationXfoo".
            wanted_settings = or_(
                SystemSetting.key.startswith("activation_", autoescape=True),
                SystemSetting.key.startswith("recommendation_", autoescape=True),
            )
            counts = {
                "tickers": await _copy_table(source, dest, Ticker, batch_size=args.batch_size),
                "system_settings": await _copy_table(
                    source,
                    dest,
                    SystemSetting,
                    batch_size=args.batch_size,
                    where=wanted_settings,
                ),
                "benchmark_prices": await _copy_table(source, dest, BenchmarkPrice, batch_size=args.batch_size),
                "ohlcv_records": await _copy_table(source, dest, OHLCVRecord, batch_size=args.batch_size),
            }
    finally:
        # Dispose both engines whether or not the copy succeeded.
        await source_engine.dispose()
        await dest_engine.dispose()

    print("Done:")
    for name, count in counts.items():
        print(f" {name}: {count}")
|
||||
|
||||
|
||||
# Script entry point: run the async snapshot workflow to completion.
if __name__ == "__main__":
    asyncio.run(_main())
|
||||
@@ -0,0 +1,139 @@
|
||||
"""Run the existing backtest service against a local SQLite snapshot.
|
||||
|
||||
The runner is offline/read-only: it does not refresh benchmark prices and does
|
||||
not cache the report back to any database. It writes a local JSON report.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
# Directory one level above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Put it at the front of sys.path (once) — presumably so the ``app.*`` imports
# used later resolve when this file is run directly as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
|
||||
|
||||
|
||||
def _sqlite_url(path: Path) -> str:
    """Build the ``sqlite+aiosqlite`` URL for the resolved (absolute) form of *path*."""
    absolute = path.resolve()
    return "sqlite+aiosqlite:///" + absolute.as_posix()
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for the offline backtest runner.

    Returns:
        Namespace with the positional ``snapshot`` path plus ``out``,
        ``workers``, ``allow_spawn`` and ``quiet``.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    # (flag names, add_argument keyword options), registered in this order.
    option_specs = (
        (("snapshot",), {"help": "SQLite snapshot created by create_backtest_snapshot.py."}),
        (
            ("--out",),
            {
                "default": None,
                "help": "JSON report path. Defaults to reports/backtest-<timestamp>.json.",
            },
        ),
        (("--workers",), {"type": int, "default": None, "help": "Override backtest worker count."}),
        (
            ("--allow-spawn",),
            {
                "action": "store_true",
                "help": "Allow spawn multiprocessing for offline CLI runs, useful on Windows.",
            },
        ),
        (("--quiet",), {"action": "store_true", "help": "Hide progress output."}),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
|
||||
|
||||
|
||||
def _default_output_path() -> Path:
    """Return ``reports/backtest-<YYYYMMDD-HHMMSS>.json`` stamped with the current local time."""
    now = datetime.now()
    filename = f"backtest-{now:%Y%m%d-%H%M%S}.json"
    return Path("reports", filename)
|
||||
|
||||
|
||||
def _pct(value: Any) -> str:
    """Format *value* as a signed percentage with one decimal, or ``"-"`` when it is None."""
    if value is None:
        return "-"
    number = float(value)
    return f"{number:+.1f}%"
|
||||
|
||||
|
||||
def _r(value: Any) -> str:
    """Format *value* as a signed R-multiple with two decimals, or ``"-"`` when it is None."""
    if value is None:
        return "-"
    number = float(value)
    return f"{number:+.2f}R"
|
||||
|
||||
|
||||
def _print_summary(report: dict) -> None:
    """Print a short console digest of a backtest *report*.

    Reads ``candidates``, ``qualified``, ``overall_all``, ``overall_qualified``,
    the ``time_exit_sweep`` row whose ``hold_days`` is 30, and the
    ``portfolio_sim`` policy named ``"hold"``. Missing or empty sections are
    treated as empty; the hold-policy lines are printed only when that policy
    is present.
    """
    all_stats = report.get("overall_all") or {}
    qualified_stats = report.get("overall_qualified") or {}

    # Index the sweep rows by holding period; a later duplicate replaces an earlier one.
    sweep_by_days = {}
    for entry in report.get("time_exit_sweep") or []:
        sweep_by_days[entry.get("hold_days")] = entry
    thirty_day = sweep_by_days.get(30) or {}

    # Index the portfolio policies by name in the same way.
    portfolio = report.get("portfolio_sim") or {}
    policy_by_name = {}
    for entry in portfolio.get("policies") or []:
        policy_by_name[entry.get("policy")] = entry
    hold = policy_by_name.get("hold") or {}

    print("")
    print("Backtest summary")
    print(f" candidates: {report.get('candidates')}")
    print(f" qualified: {report.get('qualified')}")
    print(f" all setups net avg R: {_r(all_stats.get('net_avg_r'))}")
    print(f" qualified net avg R: {_r(qualified_stats.get('net_avg_r'))}")
    print(f" qualified total R: {_r(qualified_stats.get('total_r'))}")
    print(f" 30d hold net avg R: {_r(thirty_day.get('net_avg_r'))}")
    print(f" 30d hold total R: {_r(thirty_day.get('total_r'))}")

    if not hold:
        return
    print(f" hold CAGR: {_pct(hold.get('cagr_pct'))}")
    print(f" hold max drawdown: {_pct(hold.get('max_drawdown_pct'))}")
    print(f" hold Sharpe: {hold.get('sharpe')}")
    print(f" hold trades: {hold.get('trades')}")
|
||||
|
||||
|
||||
async def _main() -> None:
    """Run the backtest service against a local snapshot and save a JSON report.

    Steps: parse the CLI options, verify the snapshot path exists, set the
    offline (and optionally spawn) environment flags, import the app modules,
    apply the ``--workers`` override, run ``run_backtest`` on a session bound
    to the snapshot, write the report as indented JSON and print a summary.

    Raises:
        SystemExit: if the snapshot path does not exist.
    """
    args = _parse_args()
    snapshot = Path(args.snapshot)
    if not snapshot.exists():
        raise SystemExit(f"Snapshot not found: {snapshot}")

    # These flags are set before the app modules are imported below —
    # presumably so the app can read them at import time; verify against app code.
    os.environ["BACKTEST_SNAPSHOT_OFFLINE"] = "1"
    if args.allow_spawn:
        os.environ["BACKTEST_ALLOW_SPAWN"] = "1"

    from app.config import settings
    from app.services.backtest_service import run_backtest

    if args.workers is not None:
        settings.backtest_workers = args.workers

    if args.out:
        output = Path(args.out)
    else:
        output = _default_output_path()
    output.parent.mkdir(parents=True, exist_ok=True)

    engine = create_async_engine(_sqlite_url(snapshot), pool_pre_ping=True)
    Session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    # Last (done, total) pair shown, so repeated callbacks do not reprint it.
    shown: tuple[int, int] | None = None

    def progress(done: int, total: int, symbol: str) -> None:
        """Progress callback: print ``done/total`` (plus symbol) on one updating line."""
        nonlocal shown
        if args.quiet:
            return
        current = (done, total)
        if current == shown:
            return
        shown = current
        suffix = f" {symbol}" if symbol else ""
        print(f"progress: {done}/{total}{suffix}", end="\r")

    try:
        async with Session() as db:
            report = await run_backtest(db, progress_cb=progress)
    finally:
        await engine.dispose()

    if not args.quiet:
        # Move past the carriage-return progress line before further output.
        print("")
    with output.open("w", encoding="utf-8") as fh:
        json.dump(report, fh, indent=2)
        fh.write("\n")

    print(f"Report written: {output}")
    _print_summary(report)
|
||||
|
||||
|
||||
# Script entry point: run the async backtest workflow to completion.
if __name__ == "__main__":
    asyncio.run(_main())
|
||||
Reference in New Issue
Block a user