from typing import List, Dict, Optional, Callable, Union, Tuple
import polars as pl
import dataclasses
import random
from forecasting.simulation.harness_base import BaseHarness
from forecasting.simulation.analysis_harness import AnalysisHarness
from forecasting.simulation.config import BacktestConfig
from forecasting.data.storage import Storage
from common.loggers.timing import TimingContext
# [docs] — Sphinx autodoc link artifact from HTML extraction; not source code.
class BacktestHarness(BaseHarness):
    """
    Orchestrates the execution of C++ backtest simulations using a clean,
    configuration-driven approach.

    Per-day simulation logic is delegated to standalone worker functions
    (``forecasting.simulation.worker.run_single_day``); this class handles
    path resolution and date filtering, task fan-out (distributed or
    sequential), metadata persistence, and result formatting.
    """
def __init__(self, config: BacktestConfig, storage_options: Optional[Dict] = None):
    """
    Build a harness around a fully-specified backtest configuration.

    Args:
        config: BacktestConfig describing sources, features, window and output.
        storage_options: Optional S3/cloud credential mapping for storage access.
    """
    super().__init__(config.base_sources, storage_options, id=config.id)
    self.config = config
# [docs] — Sphinx autodoc link artifact from HTML extraction; not source code.
@classmethod
def from_yaml(cls, path: str, **overrides) -> "BacktestHarness":
    """
    Construct a harness from a YAML configuration file.

    Args:
        path: Path to the YAML config file.
        **overrides: Field overrides applied to the parsed BacktestConfig
            via ``dataclasses.replace`` (e.g. ``output_path=...``).

    Returns:
        A BacktestHarness wrapping the parsed (and overridden) config.
    """
    import os

    import yaml

    with open(path, "r") as f:
        # safe_load returns None for an empty file; normalize to a dict.
        raw_cfg = yaml.safe_load(f) or {}

    # Default the run ID to the config filename without its extension.
    # splitext handles both .yaml and .yml and, unlike str.replace, only
    # strips the final extension rather than every interior occurrence.
    if "id" not in raw_cfg:
        raw_cfg["id"] = os.path.splitext(os.path.basename(path))[0]

    # Parse paths: "data_paths" (list) takes precedence over "data_path".
    if "data_paths" in raw_cfg:
        sources = raw_cfg.get("data_paths") or []
    else:
        sources = [raw_cfg.get("data_path")]
    sources = [s for s in sources if s]

    # Parse features; a list of params expands into one feature instance
    # per entry (parameter sweeps).
    features: List[Tuple[str, Dict]] = []
    for feat in raw_cfg.get("features", []):
        feat_id = feat["id"]
        params = feat.get("params", {})
        if isinstance(params, list):
            features.extend((feat_id, p or {}) for p in params)
        else:
            features.append((feat_id, params or {}))

    # Build the base config from the raw mapping.
    config = BacktestConfig(
        id=raw_cfg["id"],
        base_sources=sources,
        target_s=[int(t) for t in raw_cfg.get("target_s", [])],
        features=features,
        output_path=raw_cfg.get("output_path"),
        start_time=raw_cfg.get("start_time"),
        end_time=raw_cfg.get("end_time"),
        timezone=raw_cfg.get("timezone", "US/Eastern"),
        start_day=raw_cfg.get("start_day"),
        end_day=raw_cfg.get("end_day"),
    )

    # Apply keyword overrides via the dataclass factory pattern.
    if overrides:
        config = dataclasses.replace(config, **overrides)
    return cls(config, storage_options=None)
# [docs] — Sphinx autodoc link artifact from HTML extraction; not source code.
def to_analysis(self, output_base: Optional[str] = None) -> AnalysisHarness:
    """
    Create an AnalysisHarness that reads this backtest's output.

    Args:
        output_base: Optional explicit output location; defaults to the
            configured ``output_path``.

    Raises:
        ValueError: If neither the argument nor the config supplies a path.
    """
    path = output_base or self.config.output_path
    if not path:
        raise ValueError("No output_path specified in harness or method call.")
    # A concrete .parquet file is used as-is; a directory gets a glob pattern.
    source_pattern = path if path.endswith(".parquet") else f"{path}/*.parquet"
    return AnalysisHarness(
        base_source=source_pattern,
        id=f"{self.id}_analysis",
        storage_options=self.storage_options,
    )
def _resolve_paths(self) -> List[str]:
    """
    Resolve configured sources to concrete, date-filtered paths.

    Combines the base class's glob expansion with start/end-day filtering in
    one call so callers never need to remember to filter separately.
    """
    resolved = super()._resolve_paths()
    return self._filter_paths_by_date(resolved)
def _filter_paths_by_date(self, paths: List[str]) -> List[str]:
    """
    Filter data paths by the configured start_day/end_day (inclusive).

    File dates are extracted from filenames in YYYY-MM-DD or YYYYMMDD form.
    Paths whose filename contains no parseable date are kept by default, as
    are all paths when no bounds are configured.

    Args:
        paths: Candidate data paths.

    Returns:
        The subset of *paths* within the configured date range.
    """
    if not self.config.start_day and not self.config.end_day:
        return paths
    import re
    from datetime import datetime

    def _parse_bound(day: Optional[str]):
        # Parse a config bound once, outside the per-path loop. The original
        # re-parsed (and silently swallowed errors for) every path.
        if not day:
            return None
        try:
            return datetime.strptime(day, "%Y-%m-%d").date()
        except ValueError:
            # Malformed bound: warn and disable that bound rather than
            # silently admitting every dated path.
            print(f"Warning: could not parse day bound '{day}'; ignoring it.")
            return None

    start_date = _parse_bound(self.config.start_day)
    end_date = _parse_bound(self.config.end_day)

    # Matches YYYY-MM-DD or YYYYMMDD anywhere in the path; compiled once.
    date_pattern = re.compile(r"(\d{4})-?(\d{2})-?(\d{2})")
    filtered_paths: List[str] = []
    for path in paths:
        match = date_pattern.search(path)
        if not match:
            # No date found in the filename: include by default.
            filtered_paths.append(path)
            continue
        year, month, day = match.groups()
        try:
            file_date = datetime.strptime(f"{year}-{month}-{day}", "%Y-%m-%d").date()
        except ValueError:
            # Unparseable pseudo-date (e.g. month 13): include by default.
            filtered_paths.append(path)
            continue
        if start_date and file_date < start_date:
            continue
        if end_date and file_date > end_date:
            continue
        filtered_paths.append(path)
    return filtered_paths
# [docs] — Sphinx autodoc link artifact from HTML extraction; not source code.
def run_distributed(
    self,
    executor,
    task_fn: Optional[Callable] = None,
    benchmark_file: Optional[str] = None,
    concurrency_limit: int = 100,
    secrets: Optional[list] = None,
) -> pl.DataFrame:
    """
    Run the backtest across all configured data files using the given executor.

    Args:
        executor: Object exposing ``map(fn, inputs, **kwargs)`` (e.g. a Modal
            executor).
        task_fn: Per-day worker function; defaults to
            ``forecasting.simulation.worker.run_single_day``. (Fix: this
            argument was previously accepted but silently ignored.)
        benchmark_file: Optional CSV path for the aggregated benchmark log.
        concurrency_limit: Max concurrent workers (Modal containers).
        secrets: List of Modal secrets (e.g. for AWS). If None, uses the
            executor default.

    Returns:
        A polars DataFrame of per-day results (empty if no paths matched).

    Raises:
        ValueError: If no output_path is configured.
    """
    from forecasting.simulation.worker import run_single_day

    # Honor a caller-supplied worker; default preserves prior behavior.
    worker_fn = task_fn or run_single_day
    if not self.config.output_path:
        raise ValueError("No output_path configured for distributed run.")
    with TimingContext("wave1_resolve_paths"):
        data_paths = self._resolve_paths()
    # Filter out incompatible MBP-10 files (which lack order_id); S3 glob
    # expansion might pick up 'mbp-10' files in the same bucket tree.
    data_paths = [p for p in data_paths if "mbp-10" not in p and ".mbp-" not in p]
    if not data_paths:
        print("No data paths found.")
        return pl.DataFrame()
    # Shuffle to avoid long-tail clustering of heavy days.
    random.shuffle(data_paths)
    print(
        f"Submitting {len(data_paths)} tasks to {executor.__class__.__name__} (Concurrency: {concurrency_limit})..."
    )
    # Convert r2://dev/ paths to remote paths for the executor.
    remote_paths = Storage.get_remote_path(data_paths)
    remote_output = Storage.get_remote_path(self.config.output_path)
    # Worker inputs: (path, config, output_dir, i, total, benchmark_file=None).
    inputs = [
        (path, self.config, remote_output, i, len(remote_paths), None)
        for i, path in enumerate(remote_paths)
    ]
    # Executor options ('starmap'/'sync_cpp' detected via inspection).
    kwargs = self._get_executor_options(executor)
    kwargs["concurrency_limit"] = concurrency_limit
    # kwargs['timeout'] = 120  # 2 min aggressive per-task timeout (disabled for now)
    # Automatic retries for transient heartbeat/network failures.
    kwargs["retries"] = 2
    if secrets is not None:
        kwargs["secrets"] = secrets
    # Execute the fan-out.
    results = executor.map(worker_fn, inputs, **kwargs)
    # Persist the config alongside the outputs for provenance.
    with TimingContext("wave1_save_metadata"):
        self._save_metadata(self.config.output_path)
    return self._format_results(results, benchmark_file)
# [docs] — Sphinx autodoc link artifact from HTML extraction; not source code.
def run_sequential(
    self, output_dir: Optional[str] = None, benchmark_file: Optional[str] = None
) -> None:
    """
    Run the backtest sequentially on the local machine.

    Args:
        output_dir: Base output directory; defaults to the configured
            ``output_path``. A timestamped subdirectory is created beneath it
            so repeated runs do not clobber each other.
        benchmark_file: Optional benchmark log path forwarded to the worker.
    """
    import os
    from datetime import datetime

    from tqdm import tqdm

    from forecasting.simulation.worker import run_single_day

    # Determine the final (timestamped) output path.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    final_output = output_dir or self.config.output_path
    if final_output:
        if final_output.startswith("r2://"):
            # Cloud paths: join with '/' rather than os.path.join.
            final_output = f"{final_output.rstrip('/')}/{timestamp}"
        else:
            final_output = os.path.join(final_output, timestamp)
        print(f"Results will be saved to: {final_output}")
    data_paths = self._resolve_paths()
    print(f"Starting Sequential Backtest on {len(data_paths)} files...")
    pbar = tqdm(enumerate(data_paths), total=len(data_paths), unit="day")
    for i, path in pbar:
        filename = os.path.basename(path)
        # Fix: show the file being processed (the description was previously
        # the literal placeholder "(unknown)" and `filename` went unused).
        pbar.set_description(f"Processing {filename}")
        run_single_day(
            path, self.config, final_output, i, len(data_paths), benchmark_file
        )
    self._save_metadata(final_output)
def _save_metadata(self, output_dir: Optional[str]):
    """Persist the BacktestConfig as metadata.json inside *output_dir*.

    No-op when *output_dir* is falsy; save failures are logged, not raised.
    """
    if not output_dir:
        return
    target = output_dir.rstrip("/") + "/metadata.json"
    payload = dataclasses.asdict(self.config)
    print(f"Saving metadata to {target}...")
    try:
        Storage.save_json(payload, target)
    except Exception as e:
        # Best-effort: a metadata failure must not fail the whole run.
        print(f"Warning: Failed to save metadata: {e}")
@staticmethod
def _get_executor_options(executor) -> Dict:
    """
    Detect executor-specific keyword options (starmap/sync_cpp) via inspection.

    Fix: the previous implementation checked ``map.__code__.co_varnames``,
    which lists *all* local variables of ``map`` — a local named 'starmap'
    would be mistaken for a parameter — and raises on builtins or wrapped
    callables. ``inspect.signature`` inspects actual parameters and handles
    bound/wrapped callables correctly.

    Returns:
        Dict of keyword options to pass to ``executor.map``.
    """
    import inspect

    kwargs: Dict = {}
    if hasattr(executor, "map"):
        try:
            params = inspect.signature(executor.map).parameters
            if "starmap" in params:
                kwargs["starmap"] = True
            if "sync_cpp" in params:
                kwargs["sync_cpp"] = True
        except Exception as e:
            # Best-effort: fall back to no special options.
            print(f"Warning: Failed to inspect executor options: {e}")
    return kwargs
@staticmethod
def _format_results(
    results: List, benchmark_file: Optional[str] = None
) -> pl.DataFrame:
    """
    Assemble raw worker result rows into a polars DataFrame.

    Rows that are not a list/tuple of the expected arity are dropped.
    Optionally writes the aggregated benchmark log to *benchmark_file* as CSV.
    """
    print("Backtest Complete.")
    cols = [
        "timestamp",
        "filename",
        "instrument_id",
        "n_events",
        "load_time",
        "engine_time",
        "save_time",
        "total_time",
        "status",
        "error",
    ]
    expected_arity = len(cols)
    valid_results = [
        row
        for row in results
        if isinstance(row, (list, tuple)) and len(row) == expected_arity
    ]
    if not valid_results:
        print("Warning: No valid results returned.")
        return pl.DataFrame(schema=cols)
    # Row-oriented constructor avoids the deprecated from_records warning path.
    df = pl.DataFrame(valid_results, schema=cols, orient="row")
    if benchmark_file:
        try:
            print(f"Saving aggregated benchmark logs to {benchmark_file}...")
            df.write_csv(benchmark_file)
        except Exception as e:
            print(f"Failed to save benchmark log: {e}")
    return df
# [docs] — Sphinx autodoc link artifact from HTML extraction; not source code.
def get_market_subsampled_df(
    self,
    n: int = 100_000,
    strategy: str = "random_files",
    seed: Optional[int] = None,
    n_days: Optional[int] = None,
    instrument_id: Optional[int] = None,
) -> pl.DataFrame:
    """
    Return a DataFrame of raw market data (Databento MBO events) for exploration.

    Useful for analyzing raw order book events before featurization. Uses the
    DataLoader to read .dbn.zst files.

    Args:
        n: Number of events to return (approximate for some strategies).
        strategy:
            - "head": First n events from the first file.
            - "random_files": Picks random files and reads n events total.
            - "random_days": Picks n_days random files, reads all events.
        seed: Random seed for file selection.
        n_days: Number of files/days for the "random_days" strategy.
        instrument_id: Specific instrument to filter for (auto-detects if None).

    Returns:
        pl.DataFrame with columns: action, side, price, size, order_id,
        ts_recv, flags.

    Raises:
        ValueError: If *strategy* is not one of the supported names.
    """
    from forecasting.data.loader import DataLoader

    # Exclude MBP-10 files, consistent with the distributed run path.
    files = [
        f for f in self._resolve_paths() if "mbp-10" not in f and ".mbp-" not in f
    ]
    if not files:
        print("No .dbn.zst or .dbn files found in base_source")
        return pl.DataFrame()
    # Select files based on strategy.
    if strategy == "head":
        selected_files = files[:1]
    elif strategy in ["random_files", "random_days"]:
        # "random_files" samples at most 5 files; "random_days" uses n_days
        # when provided, otherwise the same 5-file cap.
        count = (
            n_days if strategy == "random_days" and n_days else min(len(files), 5)
        )
        rng = random.Random(seed)
        selected_files = rng.sample(files, min(len(files), count))
    else:
        raise ValueError(f"Unknown strategy: {strategy}")
    print(f"Loading market data from {len(selected_files)} files...")
    # Load and concatenate.
    all_data = []
    total_events = 0
    # Per-file event budget so one file doesn't dominate (None = unbounded
    # for "random_days", which reads whole files).
    events_per_file = (
        n // len(selected_files) + 1 if strategy != "random_days" else None
    )
    for f in selected_files:
        loader = DataLoader(
            f,
            start_time=self.config.start_time,
            end_time=self.config.end_time,
            timezone=self.config.timezone,
        )
        # Use n_events for memory-efficient loading (only for non-random_days).
        # NOTE(review): when remaining == 0 the falsy check yields
        # n_to_load=None (load everything); unreachable in practice because
        # the loop breaks once total_events >= n, but confirm for n == 0.
        remaining = n - total_events if strategy != "random_days" else None
        n_to_load = min(events_per_file, remaining) if remaining else None
        sim_data = loader.load(instrument_id=instrument_id, n_events=n_to_load)
        if sim_data.n_events == 0:
            continue
        # Convert the loader's columnar arrays into a DataFrame.
        df = pl.DataFrame(
            {
                "action": sim_data.actions,
                "side": sim_data.sides,
                "price": sim_data.prices,
                "size": sim_data.sizes,
                "order_id": sim_data.order_ids,
                "ts_recv": sim_data.ts_recvs,
                "flags": sim_data.flags,
            }
        )
        all_data.append(df)
        total_events += len(df)
        # For non-random_days strategies, stop once we have enough events.
        if strategy != "random_days" and total_events >= n:
            break
    if not all_data:
        return pl.DataFrame()
    result = pl.concat(all_data)
    # Trim any overshoot from the last file for capped strategies.
    if strategy != "random_days" and len(result) > n:
        result = result.head(n)
    print(f"Loaded {len(result):,} market events")
    return result