diff --git a/app/config.py b/app/config.py
index 9441acd..57f9b00 100644
--- a/app/config.py
+++ b/app/config.py
@@ -73,6 +73,12 @@ class Settings(BaseSettings):
     # backtests become computable. ~252 trading days/year.
     ohlcv_history_days: int = 1825
 
+    # Backtest parallelism: replay tickers across this many worker processes on
+    # POSIX (forkserver), capped to cpu_count-1 so a core stays free for the web
+    # server. 1 disables it (sequential). No effect on Windows / spawn-only
+    # platforms — those fall back to a single worker thread.
+    backtest_workers: int = 4
+
     # Database Pool
     db_pool_size: int = 5
     db_pool_timeout: int = 30
diff --git a/app/services/backtest_service.py b/app/services/backtest_service.py
index c5e61a6..f0d9bad 100644
--- a/app/services/backtest_service.py
+++ b/app/services/backtest_service.py
@@ -18,15 +18,19 @@ import asyncio
 import json
 import logging
 import math
+import multiprocessing
+import os
 from collections import defaultdict
 from collections.abc import Callable
-from datetime import datetime, timezone
+from concurrent.futures import ProcessPoolExecutor
+from datetime import date, datetime, timezone
 from types import SimpleNamespace
 from typing import Any
 
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from app.config import settings
 from app.models.settings import SystemSetting
 from app.models.ticker import Ticker
 from app.services.admin_service import get_activation_config, update_setting
@@ -491,18 +495,69 @@ def _signal_evaluation(collected: dict) -> list[dict]:
     return rows
 
 
-def _process_ticker(
+def _signal_series(records: list) -> dict:
+    """Per-ticker signal/forward-return series as a PLAIN (picklable) nested dict
+    — no defaultdict/lambda — so it can cross a process boundary."""
+    tmp: dict = defaultdict(lambda: defaultdict(list))
+    _accumulate_signal_series(records, tmp)
+    return {name: dict(weeks) for name, weeks in tmp.items()}
+
+
+def _replay_and_signals(
     symbol: str,
-    records: list,
+    columns: tuple,
     config: dict,
     activation: dict,
-    collected: dict,
-) -> list[dict]:
-    """The CPU-bound per-ticker work — replay + signal accumulation — bundled so
-    run_backtest can hand it to a worker thread. Mutates ``collected``."""
-    cands = _replay_ticker(symbol, records, config, activation)
-    _accumulate_signal_series(records, collected)
-    return cands
+) -> tuple[list[dict], dict]:
+    """The CPU-bound per-ticker work, as a top-level (picklable) function so it can
+    run in a worker process. Takes primitive column arrays (cheap to pickle),
+    rebuilds bar objects, and returns (candidates, signal_series)."""
+    date_ords, opens, highs, lows, closes, volumes = columns
+    bars = [
+        SimpleNamespace(
+            date=date.fromordinal(o), open=op, high=hi, low=lo, close=cl, volume=vo
+        )
+        for o, op, hi, lo, cl, vo in zip(date_ords, opens, highs, lows, closes, volumes)
+    ]
+    return _replay_ticker(symbol, bars, config, activation), _signal_series(bars)
+
+
+def _backtest_worker_count() -> int:
+    """How many worker processes to replay tickers across. Capped to cpu_count-1
+    so a core stays free for the web server; 1 means sequential."""
+    configured = int(getattr(settings, "backtest_workers", 4))
+    if configured <= 1:
+        return 1
+    cpu = os.cpu_count() or 1
+    return max(1, min(configured, cpu - 1))
+
+
+def _mp_context():
+    """Pick the start method for worker processes under the threaded asyncio
+    server: ``forkserver`` (workers forked from a clean, single-threaded server —
+    avoids the fork-with-threads deadlock) when available, else plain ``fork``,
+    which carries that risk. None on spawn-only platforms (thread fallback)."""
+    methods = multiprocessing.get_all_start_methods()
+    for method in ("forkserver", "fork"):
+        if method in methods:
+            return multiprocessing.get_context(method)
+    return None
+
+
+async def _fetch_columns(db: AsyncSession, symbol: str) -> tuple | None:
+    """Read one ticker's OHLCV and detach it to primitive column arrays in the
+    event loop (safe ORM access), ready to hand to a worker. None if no data."""
+    records = await query_ohlcv(db, symbol)
+    if not records:
+        return None
+    return (
+        [r.date.toordinal() for r in records],
+        [float(r.open) for r in records],
+        [float(r.high) for r in records],
+        [float(r.low) for r in records],
+        [float(r.close) for r in records],
+        [int(r.volume) for r in records],
+    )
 
 
 def _assign_momentum_percentiles(candidates: list[dict]) -> None:
@@ -549,33 +604,89 @@ async def run_backtest(
     candidates: list[dict] = []
     # collected[signal_name][iso_week] -> list of (signal_value, forward_return)
     collected: dict = defaultdict(lambda: defaultdict(list))
-    for index, ticker in enumerate(tickers):
-        if progress_cb is not None:
-            progress_cb(index, total, ticker.symbol)
+
+    def _merge(result: tuple[list[dict], dict]) -> None:
+        cands, series = result
+        candidates.extend(cands)
+        for name, weeks in series.items():
+            for wk, pairs in weeks.items():
+                collected[name][wk].extend(pairs)
+
+    workers = _backtest_worker_count()
+    ctx = _mp_context() if (workers > 1 and total > 1) else None
+    loop = asyncio.get_running_loop()
+
+    pool = None
+    if ctx is not None:
         try:
-            records = await query_ohlcv(db, ticker.symbol)
-            # Detach the ORM rows to plain objects in the event loop (safe to read
-            # here), then run the heavy replay in a worker thread. The compute is
-            # CPU-bound and used to block the event loop — and the API server with
-            # it — for the whole run; offloading lets CPython hand the GIL back to
-            # the loop every few ms so health checks / page loads stay responsive.
-            bars = [
-                SimpleNamespace(
-                    date=r.date,
-                    open=float(r.open),
-                    high=float(r.high),
-                    low=float(r.low),
-                    close=float(r.close),
-                    volume=int(r.volume),
-                )
-                for r in records
-            ]
-            cands = await asyncio.to_thread(
-                _process_ticker, ticker.symbol, bars, config, activation, collected
-            )
-            candidates.extend(cands)
+            pool = ProcessPoolExecutor(max_workers=workers, mp_context=ctx)
         except Exception:
-            logger.exception("Backtest replay failed for %s", ticker.symbol)
+            logger.exception("Backtest process pool unavailable; falling back to sequential")
+
+    if pool is not None:
+        # Parallel: replay tickers across worker processes — true multi-core, since
+        # the GIL only serializes work within a single process. Bars are fetched in
+        # the event loop (ORM-safe) and a bounded batch is fanned out to the pool.
+        logger.info(json.dumps({
+            "event": "backtest_parallel", "workers": workers,
+            "start_method": ctx.get_start_method(),
+        }))
+        chunk = workers * 2
+        done = 0
+        with pool:
+            for start in range(0, total, chunk):
+                batch = tickers[start : start + chunk]
+                symbols = []
+                futures = []
+                for ticker in batch:
+                    try:
+                        columns = await _fetch_columns(db, ticker.symbol)
+                    except Exception:
+                        logger.exception("Backtest fetch failed for %s", ticker.symbol)
+                        continue
+                    if columns is None:
+                        continue
+                    try:
+                        # submit() runs synchronously inside run_in_executor and
+                        # raises once the pool is broken (e.g. a worker was
+                        # killed); skip this ticker instead of aborting the run.
+                        future = loop.run_in_executor(
+                            pool, _replay_and_signals, ticker.symbol, columns, config, activation
+                        )
+                    except Exception:
+                        logger.exception("Backtest replay failed for %s", ticker.symbol)
+                        continue
+                    symbols.append(ticker.symbol)
+                    futures.append(future)
+                results = await asyncio.gather(*futures, return_exceptions=True)
+                for symbol, result in zip(symbols, results):
+                    # BaseException, not Exception: a cancelled future comes
+                    # back as CancelledError, which must not reach _merge().
+                    if isinstance(result, BaseException):
+                        logger.error(json.dumps({
+                            "event": "backtest_worker_error",
+                            "symbol": symbol,
+                            "message": str(result),
+                        }))
+                    else:
+                        _merge(result)
+                done += len(batch)
+                if progress_cb is not None:
+                    progress_cb(min(done, total), total, "")
+    else:
+        # Sequential fallback (Windows / 1 worker): run each replay in a worker
+        # thread so the event loop — and the API server — stays responsive.
+        for index, ticker in enumerate(tickers):
+            if progress_cb is not None:
+                progress_cb(index, total, ticker.symbol)
+            try:
+                columns = await _fetch_columns(db, ticker.symbol)
+                if columns is not None:
+                    _merge(await asyncio.to_thread(
+                        _replay_and_signals, ticker.symbol, columns, config, activation
+                    ))
+            except Exception:
+                logger.exception("Backtest replay failed for %s", ticker.symbol)
 
     if progress_cb is not None and total:
         progress_cb(total, total, "")
diff --git a/tests/unit/test_signal_eval.py b/tests/unit/test_signal_eval.py
index 78f1194..cbc85a0 100644
--- a/tests/unit/test_signal_eval.py
+++ b/tests/unit/test_signal_eval.py
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+import os
+import pickle
 import random
 from datetime import date, timedelta
 from types import SimpleNamespace
@@ -156,3 +158,38 @@ def test_accumulate_signal_series_emits_weekly_pairs():
     # ...one per ISO week, with a forward return attached to each pair.
     sample = next(iter(collected["mom_12_1"].values()))
     assert all(len(pair) == 2 for pair in sample)
+
+
+# ---------------------------------------------------------------------------
+# Parallel-replay plumbing (process pool): plain/picklable results, worker count
+# ---------------------------------------------------------------------------
+
+def test_signal_series_is_plain_and_picklable():
+    from collections import defaultdict
+
+    closes = [100.0 * (1.003 ** k) for k in range(400)]
+    series = bt._signal_series(_records(closes))
+    # Must be plain dicts (no defaultdict/lambda) so it survives a process boundary.
+    assert type(series) is dict
+    assert all(type(weeks) is dict for weeks in series.values())
+    pickle.dumps(series)  # the worker's return is pickled to the parent — must not raise
+    # ...and equivalent to the in-place accumulator.
+    acc = defaultdict(lambda: defaultdict(list))
+    bt._accumulate_signal_series(_records(closes), acc)
+    assert series == {name: dict(w) for name, w in acc.items()}
+
+
+def test_worker_count_caps_to_cpu_minus_one(monkeypatch):
+    monkeypatch.setattr(bt.settings, "backtest_workers", 1000)
+    assert bt._backtest_worker_count() == max(1, (os.cpu_count() or 1) - 1)
+
+
+def test_worker_count_one_disables(monkeypatch):
+    monkeypatch.setattr(bt.settings, "backtest_workers", 1)
+    assert bt._backtest_worker_count() == 1
+
+
+def test_mp_context_is_none_or_posix():
+    ctx = bt._mp_context()
+    # None on spawn-only platforms (Windows); a safe POSIX context otherwise.
+    assert ctx is None or ctx.get_start_method() in ("fork", "forkserver")