Files
signal-platform/scripts/create_backtest_snapshot.py
T
dennisthiessen 66ef0564c1
Deploy / lint (push) Successful in 8s
Deploy / test (push) Successful in 1m22s
Deploy / deploy (push) Successful in 43s
Add local backtest snapshot runner
2026-07-03 18:35:07 +02:00

163 lines
5.4 KiB
Python

"""Create a minimal local SQLite snapshot for offline backtest research.
Copies only the data required by app.services.backtest_service.run_backtest:
tickers, OHLCV bars, SPY benchmark closes, and activation/recommendation
settings. Other system settings are intentionally skipped to avoid copying
secrets into local snapshot files.
"""
from __future__ import annotations
import argparse
import asyncio
import os
import sys
from pathlib import Path
from sqlalchemy import func, insert, or_, select
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
# Put the parent of this script's directory on sys.path so that the
# ``app.*`` imports performed inside _main() resolve when the file is run
# directly as a script.
ROOT = Path(__file__).resolve().parents[1]
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
def _normalize_postgres_url(url: str) -> str:
if url.startswith("postgresql+asyncpg://"):
return url
if url.startswith("postgresql://"):
return "postgresql+asyncpg://" + url[len("postgresql://") :]
if url.startswith("postgres://"):
return "postgresql+asyncpg://" + url[len("postgres://") :]
return url
def _sqlite_url(path: Path) -> str:
return f"sqlite+aiosqlite:///{path.resolve().as_posix()}"
def _hide_password(url: str) -> str:
    """Render ``url`` via SQLAlchemy's URL parser with the password masked."""
    parsed = make_url(url)
    return parsed.render_as_string(hide_password=True)
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--database-url",
default=os.getenv("DATABASE_URL"),
help="Source Postgres URL. Defaults to DATABASE_URL, then app .env database_url.",
)
parser.add_argument(
"--output",
default="backtest_snapshots/prod-backtest.sqlite",
help="SQLite snapshot path to create.",
)
parser.add_argument("--batch-size", type=int, default=5000)
parser.add_argument("--force", action="store_true", help="Overwrite an existing snapshot file.")
return parser.parse_args()
async def _copy_table(
    source: AsyncSession,
    dest: AsyncSession,
    model: type,
    *,
    batch_size: int,
    where=None,
) -> int:
    """Copy the rows of ``model``'s table from ``source`` into ``dest``.

    Rows are read with a streamed select using ``yield_per=batch_size``,
    ordered by primary key when the table has one, and inserted into the same
    table through ``dest`` one partition at a time. ``dest`` is committed
    after each partition.

    Args:
        source: Session on the source database. It is only read from.
        dest: Session on the destination database. It is committed after each batch.
        model: Declarative model class whose ``__table__`` is copied.
        batch_size: Rows per streamed partition and insert batch. Must be >= 1.
        where: Optional SQL expression restricting which rows are copied.

    Returns:
        Number of rows inserted into ``dest``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    table = model.__table__
    columns = list(table.columns)
    count_stmt = select(func.count()).select_from(table)
    stmt = select(*columns)
    if where is not None:
        count_stmt = count_stmt.where(where)
        stmt = stmt.where(where)
    primary_key_columns = list(table.primary_key.columns)
    if primary_key_columns:
        # Insert rows in a stable (primary-key) order.
        stmt = stmt.order_by(*primary_key_columns)

    expected = int((await source.execute(count_stmt)).scalar_one())
    if expected == 0:
        print(f"{table.name}: 0 rows")
        return 0

    copied = 0
    progress = ""
    stream = await source.stream(stmt.execution_options(yield_per=batch_size))
    async for partition in stream.partitions(batch_size):
        rows = [dict(row._mapping) for row in partition]
        if not rows:
            continue
        await dest.execute(insert(table), rows)
        await dest.commit()
        copied += len(rows)
        progress = f"{table.name}: {copied}/{expected}"
        # A trailing "\r" does not trigger a line-buffered flush, so flush
        # explicitly or the progress line only appears when the buffer fills.
        print(progress, end="\r", flush=True)

    # Pad so this line fully overwrites the (possibly longer) progress line.
    print(f"{table.name}: {copied} rows".ljust(len(progress)))
    if copied != expected:
        print(
            f"warning: {table.name}: counted {expected} rows but copied {copied}",
            file=sys.stderr,
        )
    return copied
async def _main() -> None:
    """Build the SQLite backtest snapshot described in the module docstring.

    The source Postgres database is read in one read-only REPEATABLE READ
    session transaction, so every copied table reflects the same point in
    time. The snapshot is written to ``<output>.partial`` first and moved onto
    ``<output>`` only after every table has been copied. A failed run
    therefore neither leaves a truncated snapshot behind nor destroys an
    existing one that ``--force`` was asked to replace.

    Raises:
        SystemExit: If the output file exists and ``--force`` was not given.
    """
    args = _parse_args()

    # Imported after argument parsing, so ``--help`` and argument errors are
    # handled before the app package (and its settings) are loaded.
    from app.config import settings
    from app.database import Base
    import app.models  # noqa: F401 - registers all metadata tables
    from app.models.benchmark_price import BenchmarkPrice
    from app.models.ohlcv import OHLCVRecord
    from app.models.settings import SystemSetting
    from app.models.ticker import Ticker

    source_url = _normalize_postgres_url(args.database_url or settings.database_url)
    output = Path(args.output)
    if output.exists() and not args.force:
        raise SystemExit(f"{output} already exists. Pass --force to overwrite it.")
    output.parent.mkdir(parents=True, exist_ok=True)
    # Build into a sibling file. It replaces ``output`` only on success.
    partial = output.with_name(output.name + ".partial")
    if partial.exists():
        partial.unlink()  # e.g. left over from an interrupted earlier run

    source_engine = create_async_engine(
        source_url,
        pool_pre_ping=True,
        # The source session never commits, so all reads share one
        # transaction. REPEATABLE READ makes that transaction see a single
        # snapshot: counts, streamed rows and every table agree.
        isolation_level="REPEATABLE READ",
        # Transactions default to READ ONLY, so Postgres rejects any write.
        connect_args={"server_settings": {"default_transaction_read_only": "on"}},
    )
    dest_engine = create_async_engine(_sqlite_url(partial))
    SourceSession = async_sessionmaker(source_engine, class_=AsyncSession, expire_on_commit=False)
    DestSession = async_sessionmaker(dest_engine, class_=AsyncSession, expire_on_commit=False)

    print(f"Source: {_hide_password(source_url)}")
    print(f"Snapshot: {output}")

    succeeded = False
    try:
        async with dest_engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        async with SourceSession() as source, DestSession() as dest:
            counts = {
                "tickers": await _copy_table(source, dest, Ticker, batch_size=args.batch_size),
                "system_settings": await _copy_table(
                    source,
                    dest,
                    SystemSetting,
                    batch_size=args.batch_size,
                    # startswith(autoescape=True) matches a literal "_". A raw
                    # LIKE 'activation_%' would treat "_" as a single-char
                    # wildcard and could copy unrelated (possibly secret) keys.
                    where=or_(
                        SystemSetting.key.startswith("activation_", autoescape=True),
                        SystemSetting.key.startswith("recommendation_", autoescape=True),
                    ),
                ),
                "benchmark_prices": await _copy_table(source, dest, BenchmarkPrice, batch_size=args.batch_size),
                "ohlcv_records": await _copy_table(source, dest, OHLCVRecord, batch_size=args.batch_size),
            }
        succeeded = True
    finally:
        await source_engine.dispose()
        # Close the SQLite connections before deleting or renaming the file
        # (an open file cannot be removed or replaced on Windows).
        await dest_engine.dispose()
        if not succeeded and partial.exists():
            partial.unlink()

    partial.replace(output)

    print("Done:")
    for name, count in counts.items():
        print(f"  {name}: {count}")
if __name__ == "__main__":
    # Script entry point: run the async workflow to completion on a fresh event loop.
    asyncio.run(_main())