Source code for forecasting.simulation.backtest_harness

from typing import List, Dict, Optional, Callable, Union, Tuple
import polars as pl
import dataclasses
import random

from forecasting.simulation.harness_base import BaseHarness
from forecasting.simulation.analysis_harness import AnalysisHarness
from forecasting.simulation.config import BacktestConfig
from forecasting.data.storage import Storage
from common.loggers.timing import TimingContext


[docs] class BacktestHarness(BaseHarness): """ Orchestrates the execution of C++ backtest simulations using a clean, configuration-driven approach. Now delegates core logic to standalone worker functions. """ def __init__( self, config: BacktestConfig, storage_options: Optional[Dict] = None, ): """ Initialize the BacktestHarness. Args: config: BacktestConfig dataclass. storage_options: S3/Cloud storage credentials/options. """ super().__init__(config.base_sources, storage_options, id=config.id) self.config = config
[docs] @classmethod def from_yaml(cls, path: str, **overrides) -> "BacktestHarness": import yaml import os with open(path, "r") as f: raw_cfg = yaml.safe_load(f) # Default ID override if "id" not in raw_cfg: raw_cfg["id"] = os.path.basename(path).replace(".yaml", "") # Parse Paths sources = ( raw_cfg.get("data_paths", []) if "data_paths" in raw_cfg else [raw_cfg.get("data_path")] ) sources = [s for s in sources if s] # Parse Features (supports list of params for parameter sweeps) features: List[Tuple[str, Dict]] = [] for feat in raw_cfg.get("features", []): feat_id = feat["id"] params = feat.get("params", {}) # Support list of params -> expand into multiple feature instances if isinstance(params, list): for p in params: features.append((feat_id, p if p else {})) else: features.append((feat_id, params if params else {})) # Build Base Config config = BacktestConfig( id=raw_cfg["id"], base_sources=sources, target_s=[int(t) for t in raw_cfg.get("target_s", [])], features=features, output_path=raw_cfg.get("output_path"), start_time=raw_cfg.get("start_time"), end_time=raw_cfg.get("end_time"), timezone=raw_cfg.get("timezone", "US/Eastern"), start_day=raw_cfg.get("start_day"), end_day=raw_cfg.get("end_day"), ) # Apply Overrides (Elegant Factory Pattern) if overrides: config = dataclasses.replace(config, **overrides) return cls(config, storage_options=None)
[docs] def to_analysis(self, output_base: Optional[str] = None) -> AnalysisHarness: """ Creates an AnalysisHarness pointing to the output of this backtest. """ path = output_base or self.config.output_path if not path: raise ValueError("No output_path specified in harness or method call.") source_pattern = f"{path}/*.parquet" if path.endswith(".parquet"): source_pattern = path return AnalysisHarness( base_source=source_pattern, id=f"{self.id}_analysis", storage_options=self.storage_options, )
def _resolve_paths(self) -> List[str]: """Resolve paths and apply date filtering. Consolidates glob expansion (via the base class) and date-range filtering into a single call so callers don't need to remember to filter separately. """ paths = super()._resolve_paths() return self._filter_paths_by_date(paths) def _filter_paths_by_date(self, paths: List[str]) -> List[str]: """ Filters data paths based on start_day and end_day configuration. Assumes filenames contain dates in YYYY-MM-DD format. """ if not self.config.start_day and not self.config.end_day: return paths import re from datetime import datetime # Pattern to extract date from filename (YYYY-MM-DD or YYYYMMDD) date_pattern = r"(\d{4})-?(\d{2})-?(\d{2})" filtered_paths = [] for path in paths: # Extract date from filename match = re.search(date_pattern, path) if not match: # If no date found in filename, include it by default filtered_paths.append(path) continue # Parse the date from filename year, month, day = match.groups() file_date_str = f"{year}-{month}-{day}" try: file_date = datetime.strptime(file_date_str, "%Y-%m-%d").date() # Check if within range if self.config.start_day: start_date = datetime.strptime( self.config.start_day, "%Y-%m-%d" ).date() if file_date < start_date: continue if self.config.end_day: end_date = datetime.strptime(self.config.end_day, "%Y-%m-%d").date() if file_date > end_date: continue filtered_paths.append(path) except ValueError: # If date parsing fails, include it by default filtered_paths.append(path) return filtered_paths
[docs] def run_distributed( self, executor, task_fn: Callable = None, benchmark_file: Optional[str] = None, concurrency_limit: int = 100, secrets: Optional[list] = None, ) -> pl.DataFrame: """ Run the backtest across all configured data files using the given executor. Args: concurrency_limit: Max concurrent workers (Modal containers). secrets: List of Modal secrets (e.g. for AWS). If None, uses executor default. """ from forecasting.simulation.worker import run_single_day if not self.config.output_path: raise ValueError("No output_path configured for distributed run.") with TimingContext("wave1_resolve_paths"): data_paths = self._resolve_paths() # Filter out incompatible MBP-10 files (which lack order_id) # S3 glove expansion might pick up 'mbp-10' files if they are in the same bucket tree. data_paths = [p for p in data_paths if "mbp-10" not in p and ".mbp-" not in p] if not data_paths: print("No data paths found.") return pl.DataFrame() # Shuffle to avoid long-tail clustering of heavy days random.shuffle(data_paths) print( f"Submitting {len(data_paths)} tasks to {executor.__class__.__name__} (Concurrency: {concurrency_limit})..." ) # Convert r2://dev/ paths to remote paths for executor remote_paths = Storage.get_remote_path(data_paths) remote_output = Storage.get_remote_path(self.config.output_path) # Prepare Inputs: (path, config, output_dir, i, total, benchmark_file=None) inputs = [ (path, self.config, remote_output, i, len(remote_paths), None) for i, path in enumerate(remote_paths) ] # Executor Options kwargs = self._get_executor_options(executor) kwargs["concurrency_limit"] = concurrency_limit # kwargs['timeout'] = 120 # 2 min aggressive timeout per task (disabled for now) kwargs["retries"] = ( 2 # Automatic retries for transient heartbeat/network failures ) if secrets is not None: kwargs["secrets"] = secrets # Execute results = executor.map(run_single_day, inputs, **kwargs) # Save Metadata with TimingContext("wave1_save_metadata"): self._save_metadata(self.config.output_path) return self._format_results(results, benchmark_file)
[docs] def run_sequential( self, output_dir: Optional[str] = None, benchmark_file: Optional[str] = None ) -> None: """ Runs the backtest sequentially locally. """ import os from datetime import datetime from tqdm import tqdm from forecasting.simulation.worker import run_single_day # Determine final output path timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") final_output = output_dir or self.config.output_path if final_output: if final_output.startswith("r2://"): final_output = f"{final_output.rstrip('/')}/{timestamp}" else: final_output = os.path.join(final_output, timestamp) print(f"Results will be saved to: {final_output}") data_paths = self._resolve_paths() print(f"Starting Sequential Backtest on {len(data_paths)} files...") pbar = tqdm(enumerate(data_paths), total=len(data_paths), unit="day") for i, path in pbar: filename = os.path.basename(path) pbar.set_description(f"Processing {filename}") run_single_day( path, self.config, final_output, i, len(data_paths), benchmark_file ) self._save_metadata(final_output)
def _save_metadata(self, output_dir: Optional[str]): """Saves the BacktestConfig as metadata.json in the output directory.""" if not output_dir: return path = output_dir.rstrip("/") + "/metadata.json" data = dataclasses.asdict(self.config) print(f"Saving metadata to {path}...") try: Storage.save_json(data, path) except Exception as e: print(f"Warning: Failed to save metadata: {e}") @staticmethod def _get_executor_options(executor) -> Dict: """Configures executor specific options (like starmap/sync_cpp) via inspection.""" kwargs = {} if hasattr(executor, "map"): try: co_varnames = executor.map.__code__.co_varnames if "starmap" in co_varnames: kwargs["starmap"] = True if "sync_cpp" in co_varnames: kwargs["sync_cpp"] = True except Exception as e: print(f"Warning: Failed to inspect executor options: {e}") pass return kwargs @staticmethod def _format_results( results: List, benchmark_file: Optional[str] = None ) -> pl.DataFrame: print("Backtest Complete.") cols = [ "timestamp", "filename", "instrument_id", "n_events", "load_time", "engine_time", "save_time", "total_time", "status", "error", ] # Filter valid results valid_results = [ r for r in results if isinstance(r, (list, tuple)) and len(r) == len(cols) ] if not valid_results: print("Warning: No valid results returned.") return pl.DataFrame(schema=cols) # df = pl.from_records(valid_results, schema=cols) # Use DataFrame constructor directly to avoid deprecated from_records warning crash df = pl.DataFrame(valid_results, schema=cols, orient="row") # Save aggregated benchmark log if requested if benchmark_file: try: print(f"Saving aggregated benchmark logs to {benchmark_file}...") df.write_csv(benchmark_file) except Exception as e: print(f"Failed to save benchmark log: {e}") return df
[docs] def get_market_subsampled_df( self, n: int = 100_000, strategy: str = "random_files", seed: Optional[int] = None, n_days: Optional[int] = None, instrument_id: Optional[int] = None, ) -> pl.DataFrame: """ Returns a DataFrame of raw market data (Databento MBO events) for exploration. This is useful for analyzing raw order book events before featurization. Uses the DataLoader to read .dbn.zst files. Args: n: Number of events to return (approximate for some strategies). strategy: - "head": First n events from first file. - "random_files": Picks random files and reads n events total. - "random_days": Picks n_days random files, reads all events. seed: Random seed for file selection. n_days: Number of files/days for "random_days" strategy. instrument_id: Specific instrument to filter for (auto-detects if None). Returns: pl.DataFrame with columns: action, side, price, size, order_id, ts_recv, flags """ from forecasting.data.loader import DataLoader files = [ f for f in self._resolve_paths() if "mbp-10" not in f and ".mbp-" not in f ] if not files: print("No .dbn.zst or .dbn files found in base_source") return pl.DataFrame() # Select files based on strategy if strategy == "head": selected_files = files[:1] elif strategy in ["random_files", "random_days"]: count = ( n_days if strategy == "random_days" and n_days else min(len(files), 5) ) rng = random.Random(seed) selected_files = rng.sample(files, min(len(files), count)) else: raise ValueError(f"Unknown strategy: {strategy}") print(f"Loading market data from {len(selected_files)} files...") # Load and concatenate all_data = [] total_events = 0 # Calculate events per file for efficient loading events_per_file = ( n // len(selected_files) + 1 if strategy != "random_days" else None ) for f in selected_files: loader = DataLoader( f, start_time=self.config.start_time, end_time=self.config.end_time, timezone=self.config.timezone, ) # Use n_events for memory-efficient loading (only for non-random_days) remaining = n - total_events if strategy != "random_days" else None n_to_load = min(events_per_file, remaining) if remaining else None sim_data = loader.load(instrument_id=instrument_id, n_events=n_to_load) if sim_data.n_events == 0: continue # Convert to DataFrame df = pl.DataFrame( { "action": sim_data.actions, "side": sim_data.sides, "price": sim_data.prices, "size": sim_data.sizes, "order_id": sim_data.order_ids, "ts_recv": sim_data.ts_recvs, "flags": sim_data.flags, } ) all_data.append(df) total_events += len(df) # For non-random_days strategies, stop if we have enough if strategy != "random_days" and total_events >= n: break if not all_data: return pl.DataFrame() result = pl.concat(all_data) # Limit rows for non-random_days strategies if strategy != "random_days" and len(result) > n: result = result.head(n) print(f"Loaded {len(result):,} market events") return result