parallelize the backtest across worker processes (true multi-core)
Deploy / lint (push) Successful in 6s
Deploy / test (push) Successful in 38s
Deploy / deploy (push) Successful in 25s

The replay was CPU-bound and single-core: the earlier asyncio.to_thread offload
kept the API responsive but, because of the GIL, ran on one core. Per-ticker
replay is independent, so fan it out across worker processes (which sidestep the
GIL) for real multi-core speedup.

- New `settings.backtest_workers` (default 4), capped to cpu_count-1 so a core
  stays free for the web server.
- Uses a `forkserver` context (workers forked from a clean single-threaded
  server — avoids the fork-with-threads deadlock); falls back to `fork`. On
  spawn-only platforms (Windows) and for 1-ticker runs it uses the thread path,
  so dev/tests are unaffected.
- Worker takes primitive column arrays (cheap to pickle), rebuilds bars, and
  returns (candidates, plain-dict signal series) — both picklable across the
  process boundary. Bars are still fetched in the event loop (ORM-safe).
- Pool creation is guarded: if the pool can't start, the job falls back to the
  sequential thread path instead of failing.

334 backend tests pass (parallel path is POSIX/server-only, so it's covered by
construction + the picklability/worker-count tests; the thread fallback is
exercised by the run_backtest smoke test).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-23 23:20:20 +02:00
parent e71c07e554
commit 7060b9a019
3 changed files with 171 additions and 35 deletions
+6
View File
@@ -73,6 +73,12 @@ class Settings(BaseSettings):
# backtests become computable. ~252 trading days/year.
ohlcv_history_days: int = 1825
# Backtest parallelism: replay tickers across this many worker processes on
# POSIX (forkserver), capped to cpu_count-1 so a core stays free for the web
# server. 1 disables it (sequential). No effect on Windows / spawn-only
# platforms — those fall back to a single worker thread.
backtest_workers: int = 4
# Database Pool
db_pool_size: int = 5
db_pool_timeout: int = 30
+128 -35
View File
@@ -18,15 +18,19 @@ import asyncio
import json
import logging
import math
import multiprocessing
import os
from collections import defaultdict
from collections.abc import Callable
from datetime import datetime, timezone
from concurrent.futures import ProcessPoolExecutor
from datetime import date, datetime, timezone
from types import SimpleNamespace
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models.settings import SystemSetting
from app.models.ticker import Ticker
from app.services.admin_service import get_activation_config, update_setting
@@ -491,18 +495,69 @@ def _signal_evaluation(collected: dict) -> list[dict]:
return rows
def _process_ticker(
def _signal_series(records: list) -> dict:
"""Per-ticker signal/forward-return series as a PLAIN (picklable) nested dict
— no defaultdict/lambda — so it can cross a process boundary."""
tmp: dict = defaultdict(lambda: defaultdict(list))
_accumulate_signal_series(records, tmp)
return {name: dict(weeks) for name, weeks in tmp.items()}
def _replay_and_signals(
symbol: str,
records: list,
columns: tuple,
config: dict,
activation: dict,
collected: dict,
) -> list[dict]:
"""The CPU-bound per-ticker work — replay + signal accumulation — bundled so
run_backtest can hand it to a worker thread. Mutates ``collected``."""
cands = _replay_ticker(symbol, records, config, activation)
_accumulate_signal_series(records, collected)
return cands
) -> tuple[list[dict], dict]:
"""The CPU-bound per-ticker work, as a top-level (picklable) function so it can
run in a worker process. Takes primitive column arrays (cheap to pickle),
rebuilds bar objects, and returns (candidates, signal_series)."""
date_ords, opens, highs, lows, closes, volumes = columns
bars = [
SimpleNamespace(
date=date.fromordinal(o), open=op, high=hi, low=lo, close=cl, volume=vo
)
for o, op, hi, lo, cl, vo in zip(date_ords, opens, highs, lows, closes, volumes)
]
return _replay_ticker(symbol, bars, config, activation), _signal_series(bars)
def _backtest_worker_count() -> int:
"""How many worker processes to replay tickers across. Capped to cpu_count-1
so a core stays free for the web server; 1 means sequential."""
configured = int(getattr(settings, "backtest_workers", 4))
if configured <= 1:
return 1
cpu = os.cpu_count() or 1
return max(1, min(configured, cpu - 1))
def _mp_context():
"""A start method safe to use from the threaded asyncio server: ``forkserver``
(workers forked from a clean, single-threaded server — avoids the
fork-with-threads deadlock) when available, else ``fork``. Returns None on
spawn-only platforms (Windows), where the caller falls back to a thread."""
methods = multiprocessing.get_all_start_methods()
for method in ("forkserver", "fork"):
if method in methods:
return multiprocessing.get_context(method)
return None
async def _fetch_columns(db: AsyncSession, symbol: str) -> tuple | None:
"""Read one ticker's OHLCV and detach it to primitive column arrays in the
event loop (safe ORM access), ready to hand to a worker. None if no data."""
records = await query_ohlcv(db, symbol)
if not records:
return None
return (
[r.date.toordinal() for r in records],
[float(r.open) for r in records],
[float(r.high) for r in records],
[float(r.low) for r in records],
[float(r.close) for r in records],
[int(r.volume) for r in records],
)
def _assign_momentum_percentiles(candidates: list[dict]) -> None:
@@ -549,33 +604,71 @@ async def run_backtest(
candidates: list[dict] = []
# collected[signal_name][iso_week] -> list of (signal_value, forward_return)
collected: dict = defaultdict(lambda: defaultdict(list))
for index, ticker in enumerate(tickers):
if progress_cb is not None:
progress_cb(index, total, ticker.symbol)
def _merge(result: tuple[list[dict], dict]) -> None:
cands, series = result
candidates.extend(cands)
for name, weeks in series.items():
for wk, pairs in weeks.items():
collected[name][wk].extend(pairs)
workers = _backtest_worker_count()
ctx = _mp_context() if (workers > 1 and total > 1) else None
loop = asyncio.get_running_loop()
pool = None
if ctx is not None:
try:
records = await query_ohlcv(db, ticker.symbol)
# Detach the ORM rows to plain objects in the event loop (safe to read
# here), then run the heavy replay in a worker thread. The compute is
# CPU-bound and used to block the event loop — and the API server with
# it — for the whole run; offloading lets CPython hand the GIL back to
# the loop every few ms so health checks / page loads stay responsive.
bars = [
SimpleNamespace(
date=r.date,
open=float(r.open),
high=float(r.high),
low=float(r.low),
close=float(r.close),
volume=int(r.volume),
)
for r in records
]
cands = await asyncio.to_thread(
_process_ticker, ticker.symbol, bars, config, activation, collected
)
candidates.extend(cands)
pool = ProcessPoolExecutor(max_workers=workers, mp_context=ctx)
except Exception:
logger.exception("Backtest replay failed for %s", ticker.symbol)
logger.exception("Backtest process pool unavailable; falling back to sequential")
if pool is not None:
# Parallel: replay tickers across worker processes — true multi-core, since
# the GIL only serializes work within a single process. Bars are fetched in
# the event loop (ORM-safe) and a bounded batch is fanned out to the pool.
logger.info(json.dumps({
"event": "backtest_parallel", "workers": workers,
"start_method": ctx.get_start_method(),
}))
chunk = workers * 2
done = 0
with pool:
for start in range(0, total, chunk):
batch = tickers[start : start + chunk]
futures = []
for ticker in batch:
try:
columns = await _fetch_columns(db, ticker.symbol)
except Exception:
logger.exception("Backtest fetch failed for %s", ticker.symbol)
continue
if columns is not None:
futures.append(loop.run_in_executor(
pool, _replay_and_signals, ticker.symbol, columns, config, activation
))
for result in await asyncio.gather(*futures, return_exceptions=True):
if isinstance(result, Exception):
logger.error(json.dumps({"event": "backtest_worker_error", "message": str(result)}))
else:
_merge(result)
done += len(batch)
if progress_cb is not None:
progress_cb(min(done, total), total, "")
else:
# Sequential fallback (Windows / 1 worker): run each replay in a worker
# thread so the event loop — and the API server — stays responsive.
for index, ticker in enumerate(tickers):
if progress_cb is not None:
progress_cb(index, total, ticker.symbol)
try:
columns = await _fetch_columns(db, ticker.symbol)
if columns is not None:
_merge(await asyncio.to_thread(
_replay_and_signals, ticker.symbol, columns, config, activation
))
except Exception:
logger.exception("Backtest replay failed for %s", ticker.symbol)
if progress_cb is not None and total:
progress_cb(total, total, "")
+37
View File
@@ -2,6 +2,8 @@
from __future__ import annotations
import os
import pickle
import random
from datetime import date, timedelta
from types import SimpleNamespace
@@ -156,3 +158,38 @@ def test_accumulate_signal_series_emits_weekly_pairs():
# ...one per ISO week, with a forward return attached to each pair.
sample = next(iter(collected["mom_12_1"].values()))
assert all(len(pair) == 2 for pair in sample)
# ---------------------------------------------------------------------------
# Parallel-replay plumbing (process pool): plain/picklable results, worker count
# ---------------------------------------------------------------------------
def test_signal_series_is_plain_and_picklable():
from collections import defaultdict
closes = [100.0 * (1.003 ** k) for k in range(400)]
series = bt._signal_series(_records(closes))
# Must be plain dicts (no defaultdict/lambda) so it survives a process boundary.
assert type(series) is dict
assert all(type(weeks) is dict for weeks in series.values())
pickle.dumps(series) # the worker's return is pickled to the parent — must not raise
# ...and equivalent to the in-place accumulator.
acc = defaultdict(lambda: defaultdict(list))
bt._accumulate_signal_series(_records(closes), acc)
assert series == {name: dict(w) for name, w in acc.items()}
def test_worker_count_caps_to_cpu_minus_one(monkeypatch):
monkeypatch.setattr(bt.settings, "backtest_workers", 1000)
assert bt._backtest_worker_count() == max(1, (os.cpu_count() or 1) - 1)
def test_worker_count_one_disables(monkeypatch):
monkeypatch.setattr(bt.settings, "backtest_workers", 1)
assert bt._backtest_worker_count() == 1
def test_mp_context_is_none_or_posix():
ctx = bt._mp_context()
# None on spawn-only platforms (Windows); a safe POSIX context otherwise.
assert ctx is None or ctx.get_start_method() in ("fork", "forkserver")