"""Create a minimal local SQLite snapshot for offline backtest research. Copies only the data required by app.services.backtest_service.run_backtest: tickers, OHLCV bars, SPY benchmark closes, and activation/recommendation settings. Other system settings are intentionally skipped to avoid copying secrets into local snapshot files. """ from __future__ import annotations import argparse import asyncio import os import sys from pathlib import Path from sqlalchemy import func, insert, or_, select from sqlalchemy.engine import make_url from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) def _normalize_postgres_url(url: str) -> str: if url.startswith("postgresql+asyncpg://"): return url if url.startswith("postgresql://"): return "postgresql+asyncpg://" + url[len("postgresql://") :] if url.startswith("postgres://"): return "postgresql+asyncpg://" + url[len("postgres://") :] return url def _sqlite_url(path: Path) -> str: return f"sqlite+aiosqlite:///{path.resolve().as_posix()}" def _hide_password(url: str) -> str: return make_url(url).render_as_string(hide_password=True) def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--database-url", default=os.getenv("DATABASE_URL"), help="Source Postgres URL. Defaults to DATABASE_URL, then app .env database_url.", ) parser.add_argument( "--output", default="backtest_snapshots/prod-backtest.sqlite", help="SQLite snapshot path to create.", ) parser.add_argument("--batch-size", type=int, default=5000) parser.add_argument("--force", action="store_true", help="Overwrite an existing snapshot file.") return parser.parse_args() async def _copy_table( source: AsyncSession, dest: AsyncSession, model: type, *, batch_size: int, where=None, ) -> int: table = model.__table__ columns = list(table.columns) count_stmt = select(func.count()).select_from(table) stmt = select(*columns) if where is not None: count_stmt = count_stmt.where(where) stmt = stmt.where(where) primary_key_columns = list(table.primary_key.columns) if primary_key_columns: stmt = stmt.order_by(*primary_key_columns) expected = int((await source.execute(count_stmt)).scalar_one()) if expected == 0: print(f"{table.name}: 0 rows") return 0 copied = 0 stream = await source.stream(stmt.execution_options(yield_per=batch_size)) async for partition in stream.partitions(batch_size): rows = [dict(row._mapping) for row in partition] if not rows: continue await dest.execute(insert(table), rows) await dest.commit() copied += len(rows) print(f"{table.name}: {copied}/{expected}", end="\r") print(f"{table.name}: {copied} rows") return copied async def _main() -> None: args = _parse_args() from app.config import settings from app.database import Base import app.models # noqa: F401 - registers all metadata tables from app.models.benchmark_price import BenchmarkPrice from app.models.ohlcv import OHLCVRecord from app.models.settings import SystemSetting from app.models.ticker import Ticker source_url = _normalize_postgres_url(args.database_url or settings.database_url) output = Path(args.output) if output.exists(): if not args.force: raise SystemExit(f"{output} already exists. Pass --force to overwrite it.") output.unlink() output.parent.mkdir(parents=True, exist_ok=True) source_engine = create_async_engine( source_url, pool_pre_ping=True, connect_args={"server_settings": {"default_transaction_read_only": "on"}}, ) dest_engine = create_async_engine(_sqlite_url(output)) SourceSession = async_sessionmaker(source_engine, class_=AsyncSession, expire_on_commit=False) DestSession = async_sessionmaker(dest_engine, class_=AsyncSession, expire_on_commit=False) print(f"Source: {_hide_password(source_url)}") print(f"Snapshot: {output}") try: async with dest_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) async with SourceSession() as source, DestSession() as dest: counts = { "tickers": await _copy_table(source, dest, Ticker, batch_size=args.batch_size), "system_settings": await _copy_table( source, dest, SystemSetting, batch_size=args.batch_size, where=or_( SystemSetting.key.like("activation_%"), SystemSetting.key.like("recommendation_%"), ), ), "benchmark_prices": await _copy_table(source, dest, BenchmarkPrice, batch_size=args.batch_size), "ohlcv_records": await _copy_table(source, dest, OHLCVRecord, batch_size=args.batch_size), } finally: await source_engine.dispose() await dest_engine.dispose() print("Done:") for name, count in counts.items(): print(f" {name}: {count}") if __name__ == "__main__": asyncio.run(_main())