parallelize the backtest across worker processes (true multi-core)
The replay was CPU-bound and single-core: the earlier asyncio.to_thread offload kept the API responsive but, because of the GIL, ran on one core. Per-ticker replay is independent, so fan it out across worker processes (which sidestep the GIL) for real multi-core speedup. - New `settings.backtest_workers` (default 4), capped to cpu_count-1 so a core stays free for the web server. - Uses a `forkserver` context (workers forked from a clean single-threaded server — avoids the fork-with-threads deadlock); falls back to `fork`. On spawn-only platforms (Windows) and for 1-ticker runs it uses the thread path, so dev/tests are unaffected. - Worker takes primitive column arrays (cheap to pickle), rebuilds bars, and returns (candidates, plain-dict signal series) — both picklable across the process boundary. Bars are still fetched in the event loop (ORM-safe). - Pool creation is guarded: if the pool can't start, the job falls back to the sequential thread path instead of failing. 334 backend tests pass (parallel path is POSIX/server-only, so it's covered by construction + the picklability/worker-count tests; the thread fallback is exercised by the run_backtest smoke test). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -18,15 +18,19 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import multiprocessing
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timezone
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from datetime import date, datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.models.settings import SystemSetting
|
||||
from app.models.ticker import Ticker
|
||||
from app.services.admin_service import get_activation_config, update_setting
|
||||
@@ -491,18 +495,69 @@ def _signal_evaluation(collected: dict) -> list[dict]:
|
||||
return rows
|
||||
|
||||
|
||||
def _process_ticker(
|
||||
def _signal_series(records: list) -> dict:
|
||||
"""Per-ticker signal/forward-return series as a PLAIN (picklable) nested dict
|
||||
— no defaultdict/lambda — so it can cross a process boundary."""
|
||||
tmp: dict = defaultdict(lambda: defaultdict(list))
|
||||
_accumulate_signal_series(records, tmp)
|
||||
return {name: dict(weeks) for name, weeks in tmp.items()}
|
||||
|
||||
|
||||
def _replay_and_signals(
|
||||
symbol: str,
|
||||
records: list,
|
||||
columns: tuple,
|
||||
config: dict,
|
||||
activation: dict,
|
||||
collected: dict,
|
||||
) -> list[dict]:
|
||||
"""The CPU-bound per-ticker work — replay + signal accumulation — bundled so
|
||||
run_backtest can hand it to a worker thread. Mutates ``collected``."""
|
||||
cands = _replay_ticker(symbol, records, config, activation)
|
||||
_accumulate_signal_series(records, collected)
|
||||
return cands
|
||||
) -> tuple[list[dict], dict]:
|
||||
"""The CPU-bound per-ticker work, as a top-level (picklable) function so it can
|
||||
run in a worker process. Takes primitive column arrays (cheap to pickle),
|
||||
rebuilds bar objects, and returns (candidates, signal_series)."""
|
||||
date_ords, opens, highs, lows, closes, volumes = columns
|
||||
bars = [
|
||||
SimpleNamespace(
|
||||
date=date.fromordinal(o), open=op, high=hi, low=lo, close=cl, volume=vo
|
||||
)
|
||||
for o, op, hi, lo, cl, vo in zip(date_ords, opens, highs, lows, closes, volumes)
|
||||
]
|
||||
return _replay_ticker(symbol, bars, config, activation), _signal_series(bars)
|
||||
|
||||
|
||||
def _backtest_worker_count() -> int:
|
||||
"""How many worker processes to replay tickers across. Capped to cpu_count-1
|
||||
so a core stays free for the web server; 1 means sequential."""
|
||||
configured = int(getattr(settings, "backtest_workers", 4))
|
||||
if configured <= 1:
|
||||
return 1
|
||||
cpu = os.cpu_count() or 1
|
||||
return max(1, min(configured, cpu - 1))
|
||||
|
||||
|
||||
def _mp_context():
|
||||
"""A start method safe to use from the threaded asyncio server: ``forkserver``
|
||||
(workers forked from a clean, single-threaded server — avoids the
|
||||
fork-with-threads deadlock) when available, else ``fork``. Returns None on
|
||||
spawn-only platforms (Windows), where the caller falls back to a thread."""
|
||||
methods = multiprocessing.get_all_start_methods()
|
||||
for method in ("forkserver", "fork"):
|
||||
if method in methods:
|
||||
return multiprocessing.get_context(method)
|
||||
return None
|
||||
|
||||
|
||||
async def _fetch_columns(db: AsyncSession, symbol: str) -> tuple | None:
|
||||
"""Read one ticker's OHLCV and detach it to primitive column arrays in the
|
||||
event loop (safe ORM access), ready to hand to a worker. None if no data."""
|
||||
records = await query_ohlcv(db, symbol)
|
||||
if not records:
|
||||
return None
|
||||
return (
|
||||
[r.date.toordinal() for r in records],
|
||||
[float(r.open) for r in records],
|
||||
[float(r.high) for r in records],
|
||||
[float(r.low) for r in records],
|
||||
[float(r.close) for r in records],
|
||||
[int(r.volume) for r in records],
|
||||
)
|
||||
|
||||
|
||||
def _assign_momentum_percentiles(candidates: list[dict]) -> None:
|
||||
@@ -549,33 +604,71 @@ async def run_backtest(
|
||||
candidates: list[dict] = []
|
||||
# collected[signal_name][iso_week] -> list of (signal_value, forward_return)
|
||||
collected: dict = defaultdict(lambda: defaultdict(list))
|
||||
for index, ticker in enumerate(tickers):
|
||||
if progress_cb is not None:
|
||||
progress_cb(index, total, ticker.symbol)
|
||||
|
||||
def _merge(result: tuple[list[dict], dict]) -> None:
|
||||
cands, series = result
|
||||
candidates.extend(cands)
|
||||
for name, weeks in series.items():
|
||||
for wk, pairs in weeks.items():
|
||||
collected[name][wk].extend(pairs)
|
||||
|
||||
workers = _backtest_worker_count()
|
||||
ctx = _mp_context() if (workers > 1 and total > 1) else None
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
pool = None
|
||||
if ctx is not None:
|
||||
try:
|
||||
records = await query_ohlcv(db, ticker.symbol)
|
||||
# Detach the ORM rows to plain objects in the event loop (safe to read
|
||||
# here), then run the heavy replay in a worker thread. The compute is
|
||||
# CPU-bound and used to block the event loop — and the API server with
|
||||
# it — for the whole run; offloading lets CPython hand the GIL back to
|
||||
# the loop every few ms so health checks / page loads stay responsive.
|
||||
bars = [
|
||||
SimpleNamespace(
|
||||
date=r.date,
|
||||
open=float(r.open),
|
||||
high=float(r.high),
|
||||
low=float(r.low),
|
||||
close=float(r.close),
|
||||
volume=int(r.volume),
|
||||
)
|
||||
for r in records
|
||||
]
|
||||
cands = await asyncio.to_thread(
|
||||
_process_ticker, ticker.symbol, bars, config, activation, collected
|
||||
)
|
||||
candidates.extend(cands)
|
||||
pool = ProcessPoolExecutor(max_workers=workers, mp_context=ctx)
|
||||
except Exception:
|
||||
logger.exception("Backtest replay failed for %s", ticker.symbol)
|
||||
logger.exception("Backtest process pool unavailable; falling back to sequential")
|
||||
|
||||
if pool is not None:
|
||||
# Parallel: replay tickers across worker processes — true multi-core, since
|
||||
# the GIL only serializes work within a single process. Bars are fetched in
|
||||
# the event loop (ORM-safe) and a bounded batch is fanned out to the pool.
|
||||
logger.info(json.dumps({
|
||||
"event": "backtest_parallel", "workers": workers,
|
||||
"start_method": ctx.get_start_method(),
|
||||
}))
|
||||
chunk = workers * 2
|
||||
done = 0
|
||||
with pool:
|
||||
for start in range(0, total, chunk):
|
||||
batch = tickers[start : start + chunk]
|
||||
futures = []
|
||||
for ticker in batch:
|
||||
try:
|
||||
columns = await _fetch_columns(db, ticker.symbol)
|
||||
except Exception:
|
||||
logger.exception("Backtest fetch failed for %s", ticker.symbol)
|
||||
continue
|
||||
if columns is not None:
|
||||
futures.append(loop.run_in_executor(
|
||||
pool, _replay_and_signals, ticker.symbol, columns, config, activation
|
||||
))
|
||||
for result in await asyncio.gather(*futures, return_exceptions=True):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(json.dumps({"event": "backtest_worker_error", "message": str(result)}))
|
||||
else:
|
||||
_merge(result)
|
||||
done += len(batch)
|
||||
if progress_cb is not None:
|
||||
progress_cb(min(done, total), total, "")
|
||||
else:
|
||||
# Sequential fallback (Windows / 1 worker): run each replay in a worker
|
||||
# thread so the event loop — and the API server — stays responsive.
|
||||
for index, ticker in enumerate(tickers):
|
||||
if progress_cb is not None:
|
||||
progress_cb(index, total, ticker.symbol)
|
||||
try:
|
||||
columns = await _fetch_columns(db, ticker.symbol)
|
||||
if columns is not None:
|
||||
_merge(await asyncio.to_thread(
|
||||
_replay_and_signals, ticker.symbol, columns, config, activation
|
||||
))
|
||||
except Exception:
|
||||
logger.exception("Backtest replay failed for %s", ticker.symbol)
|
||||
|
||||
if progress_cb is not None and total:
|
||||
progress_cb(total, total, "")
|
||||
|
||||
Reference in New Issue
Block a user