from typing import Dict, List, Callable, Optional, Union
import polars as pl
import random
from forecasting.simulation.harness_base import BaseHarness
from forecasting.metrics.utils import compute_metric_safe
from forecasting.metrics.context import MetricContext
from forecasting.data.storage import Storage
from common.loggers.timing import log_time, TimingContext
import functools
import os
import json
import time
def _identity_transform(lf: pl.LazyFrame) -> pl.LazyFrame:
    """Pass-through transform used as the 'baseline' pipeline.

    Defined at module level (not as a lambda) so it stays picklable when
    shipped to remote executors.
    """
    return lf
def _compose_transforms(
    lf: pl.LazyFrame, transform_fn: Callable, base_fn: Callable
) -> pl.LazyFrame:
    """Apply ``base_fn`` first, then ``transform_fn`` on top of its result.

    Module-level (rather than a closure) so the composition remains picklable.
    """
    intermediate = base_fn(lf)
    return transform_fn(intermediate)
def _execute_subsample(
    files: List[str],
    strategy: str,
    n: int,
    n_days: int,
    seed: Optional[int],
    step: Optional[int],
    transform_fn_serialized: bytes,
    storage_options: Optional[Dict] = None,
) -> "pl.DataFrame":
    """
    Core subsampling logic using lazy evaluation via pl.scan_parquet.

    Runs inside an executor worker, so the transform arrives pickled and
    cloudpickle is imported locally.

    Args:
        files: List of parquet file paths (already converted to remote paths by caller)
        strategy: 'head', 'gather_every', or 'random_days'
        n: Row limit (used by the 'head' strategy only)
        n_days: Number of days for random_days strategy
        seed: Random seed for reproducible day selection
        step: Step size for gather_every strategy (defaults to 100 when falsy)
        transform_fn_serialized: Pickled transform function
        storage_options: Optional storage options for S3 paths

    Returns:
        Subsampled DataFrame (empty if ``files`` is empty).

    Raises:
        ValueError: If ``strategy`` is not one of the supported strategies.
    """
    import cloudpickle as pickle

    if not files:
        return pl.DataFrame()
    transform_fn = pickle.loads(transform_fn_serialized)
    if strategy == "head":
        lf = pl.scan_parquet(files, storage_options=storage_options)
        return transform_fn(lf).limit(n).collect()
    elif strategy == "gather_every":
        lf = pl.scan_parquet(files, storage_options=storage_options)
        s = step if step else 100
        return transform_fn(lf).gather_every(s).collect()
    elif strategy == "random_days":
        # Sample whole files (days) to preserve intraday time-series structure.
        rng = random.Random(seed)
        count = min(len(files), n_days)
        chosen_files = rng.sample(files, count)
        print(f"Sampling {count} days (out of {len(files)} total)...")
        subset_lf = pl.scan_parquet(chosen_files, storage_options=storage_options)
        return transform_fn(subset_lf).collect()
    # Previously fell through to an empty DataFrame, which silently hid
    # caller bugs; fail loudly instead (callers validate before dispatch).
    raise ValueError(f"Unknown subsample strategy: {strategy}")
def _execute_collect(
    files: List[str],
    transform_fn_serialized: bytes,
    storage_options: Optional[Dict] = None,
) -> "pl.DataFrame":
    """
    Collects transformed data as a DataFrame. Runs inside executor.
    Used for cross-storage transfers (e.g., Modal -> local).
    Paths should already be converted to remote paths by the caller.
    """
    import cloudpickle as pickle

    if not files:
        return pl.DataFrame()
    pipeline = pickle.loads(transform_fn_serialized)
    scan = pl.scan_parquet(files, storage_options=storage_options)
    return pipeline(scan).collect()
def _execute_get_experiment(
    files: List[str],
    transform_fn_serialized: bytes,
    storage_options: Optional[Dict] = None,
) -> bytes:
    """
    Builds pipeline and returns serialized LazyFrame plan.
    Uses Polars native LazyFrame.serialize() for the plan.
    Paths should already be converted to remote paths by the caller.
    """
    import cloudpickle as pickle

    if not files:
        return b""
    pipeline = pickle.loads(transform_fn_serialized)
    scan = pl.scan_parquet(files, storage_options=storage_options)
    # Ship the query plan (not the data) back using Polars' own serialization.
    return pipeline(scan).serialize()
def _execute_metrics(
    files: List[str],
    pipelines_serialized: Dict[str, bytes],
    metrics_serialized: bytes,
    storage_options: Optional[Dict] = None,
) -> "pl.DataFrame":
    """
    Computes metrics for all experiments. Runs inside executor.
    Paths should already be converted to remote paths by the caller.
    """
    import cloudpickle as pickle
    from forecasting.metrics.utils import compute_metric_safe

    if not files:
        return pl.DataFrame()
    metrics = pickle.loads(metrics_serialized)
    # A single shared scan; each experiment pipeline builds its own plan on top.
    raw_lazy = pl.scan_parquet(files, storage_options=storage_options)
    partials = []
    for exp_name, pipeline_blob in pipelines_serialized.items():
        experiment_lf = pickle.loads(pipeline_blob)(raw_lazy)
        partial = compute_metric_safe(exp_name, experiment_lf, metrics)
        if partial is not None:
            partials.append(partial)
    return pl.concat(partials) if partials else pl.DataFrame()
class AnalysisHarness(BaseHarness):
    """
    Orchestrates the post-hoc analysis of backtest results using a declarative pipeline pattern.

    The AnalysisHarness treats research as a series of "Experiments." Instead of generating
    new data files for every hypothesis, you define a 'transform_fn' (logic) that is
    registered in a compute graph.

    This allows for:

    1. **Rapid Prototyping**: Apply logic to a small local sample (subsampling) to verify math.
    2. **Scalable Execution**: Execute that same logic across multi-terabyte datasets via
       distributed executors without code changes.
    3. **Lineage Tracking**: Maintains a clear map of transformations applied to 'baseline' results.

    Attributes:
        pipelines (Dict[str, Callable]): A registry of named experiments, each a
            picklable function that builds the experiment's full compute graph
            (the whole chain from baseline) from a LazyFrame.
        transforms (Dict[str, Callable]): A registry of the actual single-step functions
            used to generate the pipelines (used for re-applying logic during subsampling).
    """

    def __init__(
        self,
        id: str,
        base_source: Union[str, List[str]],
        storage_options: Optional[Dict] = None,
    ):
        """
        Args:
            id: Unique identifier for this analysis session.
            base_source: Glob pattern, S3 path, or list of Parquet file paths (the backtest results).
            storage_options: Cloud storage credentials/configuration for Polars.
        """
        super().__init__(base_source, storage_options, id)
        # transforms[name] = original single-step transform function
        self.transforms: Dict[str, Callable[[pl.LazyFrame], pl.LazyFrame]] = {
            "baseline": _identity_transform
        }
        # pipelines[name] = composed transform function (full chain from baseline)
        self.pipelines: Dict[str, Callable[[pl.LazyFrame], pl.LazyFrame]] = {
            "baseline": _identity_transform
        }
        # Load metadata if available (_load_metadata raises if metadata.json is missing)
        self.metadata = self._load_metadata(base_source)
def _load_metadata(self, source: Union[str, List[str]]) -> Dict:
"""Load metadata.json from the source directory. Raises if not found."""
# Heuristic: If source is a glob or list, look in the parent dir
path_str = source if isinstance(source, str) else source[0]
if "*" in path_str:
base_dir = os.path.dirname(path_str)
elif path_str.endswith(".parquet"):
base_dir = os.path.dirname(path_str)
else:
base_dir = path_str
meta_path = f"{base_dir.rstrip('/')}/metadata.json"
metadata = Storage.load_json(meta_path)
print(f"[AnalysisHarness] Loaded metadata from {meta_path}")
return metadata
@property
def columns(self) -> List[str]:
"""Returns all columns available in the baseline dataset."""
files = self._resolve_paths()
if not files:
return []
return Storage.read_parquet_schema(files[0])
[docs]
def get_targets(self) -> List[str]:
"""Identifies target/prediction columns (e.g., target_1s)."""
return [c for c in self.columns if c.startswith("target_")]
[docs]
def get_features(self) -> List[str]:
"""Identifies all non-target columns."""
return [c for c in self.columns if not c.startswith("target_")]
[docs]
def add_experiment(
self,
name: str,
transform_fn: Callable[[pl.LazyFrame], pl.LazyFrame],
base: str = "baseline",
):
"""
Registers a new hypothesis or data view, optionally chaining from another experiment.
This builds a 'Virtual Dataset'. No computation happens until metrics are requested
or the experiment is 'sinked' to disk.
Args:
name: Unique name for this experiment (e.g., 'filtered_high_vol').
transform_fn: A function taking a LazyFrame and returning a modified LazyFrame.
base: Name of the experiment to build on (default: "baseline").
Example:
harness.add_experiment("filtered", lambda lf: lf.filter(pl.col("volume") > 1000))
harness.add_experiment("aggregated", lambda lf: lf.group_by("symbol").agg(...), base="filtered")
"""
if base not in self.pipelines:
raise ValueError(f"Base experiment '{base}' not found.")
if name in self.pipelines:
print(f"Warning: Overwriting existing experiment '{name}'")
# Store original single-step transform
self.transforms[name] = transform_fn
# Compose with base to create full pipeline using functools.partial (picklable)
base_pipeline = self.pipelines[base]
self.pipelines[name] = functools.partial(
_compose_transforms, transform_fn=transform_fn, base_fn=base_pipeline
)
[docs]
def get_experiment(self, name: str) -> pl.LazyFrame:
"""
Returns the compute graph (LazyFrame) for a specific experiment.
The plan is built via the executor and serialized back using Polars
native LazyFrame.serialize(). Call .collect() to execute the plan.
Args:
name: Experiment name.
Returns:
LazyFrame with the experiment's compute graph.
"""
import cloudpickle as pickle
import io
if name not in self.pipelines:
raise ValueError(f"Experiment '{name}' not found.")
files = self._resolve_paths()
executor = Storage.get_executor(self.base_source)
# Convert r2://dev/ paths to remote paths for executor
remote_files = Storage.get_remote_path(files)
transform_fn_serialized = pickle.dumps(self.pipelines[name])
args = (remote_files, transform_fn_serialized, self.storage_options)
results = executor.map(_execute_get_experiment, [args], starmap=True)
if results and results[0]:
# Deserialize the LazyFrame plan
return pl.LazyFrame.deserialize(io.BytesIO(results[0]))
return pl.LazyFrame()
[docs]
def get_metrics(
self,
metrics: Dict[str, Callable[[pl.LazyFrame], pl.DataFrame]],
experiment_names: Optional[List[str]] = None,
concurrency: int = 100,
) -> pl.DataFrame:
"""
Triggers the execution of the compute graphs across selected experiments.
Args:
metrics: Dict of metric functions to apply to each experiment pipeline.
experiment_names: Names of registered experiments to evaluate. Defaults to all.
concurrency: Max workers if distributed.
"""
import cloudpickle as pickle
if experiment_names is None:
experiment_names = list(self.pipelines.keys())
files = self._resolve_paths()
executor = Storage.get_executor(self.base_source)
# Convert r2://dev/ paths to remote paths for executor
remote_files = Storage.get_remote_path(files)
# Serialize pipelines and metrics for execution
pipelines_serialized = {
name: pickle.dumps(self.pipelines[name])
for name in experiment_names
if name in self.pipelines
}
metrics_serialized = pickle.dumps(metrics)
args = (
remote_files,
pipelines_serialized,
metrics_serialized,
self.storage_options,
)
results = executor.map(_execute_metrics, [args], starmap=True)
return results[0] if results else pl.DataFrame()
[docs]
def sink_experiment(self, name: str, output_path: str):
"""
Materializes the experiment compute graph into a Parquet file.
Collects data via executor, then writes to output_path using Storage
(supports local, S3, or r2://dev/ paths).
Note: This loads all data into memory. For very large datasets,
consider using streaming approaches or chunked writes.
"""
if name not in self.pipelines:
raise ValueError(f"Experiment {name} not found.")
print(f"Sinking {name} to {output_path}...")
import cloudpickle as pickle
files = self._resolve_paths()
transform_fn_serialized = pickle.dumps(self.pipelines[name])
# 1. Collect data via executor
executor = Storage.get_executor(self.base_source)
remote_files = Storage.get_remote_path(files)
args = (remote_files, transform_fn_serialized, self.storage_options)
results = executor.map(_execute_collect, [args], starmap=True)
if not results or results[0] is None or len(results[0]) == 0:
print(" Warning: No data collected")
return
df = results[0]
# 2. Write to output_path (Storage handles local/S3/Modal)
Storage.save_parquet(df.to_pandas(), output_path)
print(f" Saved {len(df)} rows to {output_path}")
[docs]
def get_subsampled_df(
self,
experiment_name: str = "baseline",
n: int = 100_000,
strategy: str = "random_days",
seed: Optional[int] = None,
step: Optional[int] = None,
n_days: int = 5,
) -> pl.DataFrame:
"""
Injects a small data sample into an experiment's logic to return an eager DataFrame.
This is the primary tool for interactive research. It guarantees that the logic
you see in your notebook is identical to the logic run on the cluster.
Args:
experiment_name: Which experiment logic to apply to the sample.
n: Row limit for the final sample.
strategy:
- 'random_days': Picks whole days (preserving HFT microstructure). **Preferred**.
- 'head': Fast, but biased toward early dates.
- 'gather_every': Systematically samples every Nth row.
n_days: How many full days to include in the 'random_days' sample.
"""
import cloudpickle as pickle
if experiment_name not in self.pipelines:
raise ValueError(f"Experiment '{experiment_name}' not found.")
if strategy == "random_rows":
raise ValueError(
"random_rows is disabled to prevent time-series integrity loss. Use 'random_days'."
)
if strategy not in ("head", "gather_every", "random_days"):
raise ValueError(
f"Unknown strategy: {strategy}. Options: 'random_days', 'gather_every', 'head'."
)
files = self._resolve_paths()
transform_fn_serialized = pickle.dumps(self.pipelines[experiment_name])
executor = Storage.get_executor(self.base_source)
# Convert r2://dev/ paths to remote paths for executor
remote_files = Storage.get_remote_path(files)
args = (
remote_files,
strategy,
n,
n_days,
seed,
step,
transform_fn_serialized,
self.storage_options,
)
results = executor.map(_execute_subsample, [args], starmap=True)
return results[0] if results else pl.DataFrame()
    # (Removed stray "[docs]" marker — Sphinx HTML export artifact, not code.)
    def compute_global_metric(
        self,
        executor,
        map_fn: Callable,
        stats_class,
        output_base: str,
        pattern: str = "*.parquet",
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        concurrency_limit: int = 100,
        secrets: Optional[list] = None,
        return_per_day: bool = False,
    ) -> Union[pl.DataFrame, Dict[str, pl.DataFrame]]:
        """
        Orchestrates a robust Map-Reduce metric calculation.

        Standard 'averaging' of daily metrics is often mathematically wrong. This method
        uses 'Sufficient Statistics' (sums, counts, squared sums) to compute a globally
        accurate metric across the entire dataset.

        Args:
            executor: BaseExecutor.
            map_fn: The mapping function (e.g. map_r2_chunk).
            stats_class: The SufficientStats class (e.g. R2SufficientStats).
            output_base: Base path where backtest results are stored.
            pattern: Glob pattern for result files.
            start_time: Optional filter (e.g. "09:30:00"). If None, tries metadata.
            end_time: Optional filter (e.g. "16:00:00"). If None, tries metadata.
            concurrency_limit: Max concurrent worker containers (Modal).
            secrets: Optional executor secrets, forwarded to map() when provided.
            return_per_day: If True, also return per-day metrics alongside aggregated.

        Returns:
            If return_per_day=False: pl.DataFrame with aggregated metrics.
            If return_per_day=True: Dict with keys "aggregated" and "per_day".
        """
        # Resolve Time Filters (precedence: explicit arg > metadata > None)
        s_time = start_time or self.metadata.get("start_time")
        e_time = end_time or self.metadata.get("end_time")
        tz = self.metadata.get("timezone", "US/Eastern")
        # Build a MetricContext only if we have meaningful filters to push down
        context = None
        if s_time or e_time:
            context = MetricContext(start_time=s_time, end_time=e_time, timezone=tz)
            print(f"Applying Filter Context: {s_time} - {e_time} [{tz}]")
        # Bind context to map function
        if context:
            # functools.partial is picklable and cleaner than lambda
            current_map_fn = functools.partial(map_fn, context=context)
        else:
            current_map_fn = map_fn
        # 1. List Files - use Storage.list_files for unified handling of local/S3/Modal volumes
        search_path = (
            f"{output_base}/{pattern}"
            if not output_base.endswith(".parquet")
            else output_base
        )
        _t_list = time.time()
        files = Storage.list_files(search_path)
        log_time(
            "wave2_list_files",
            time.time() - _t_list,
            n_files=len(files) if files else 0,
        )
        if not files:
            print(f"No files found at {search_path}")
            # Preserve the documented return shape even when there is no work.
            if return_per_day:
                return {"aggregated": pl.DataFrame(), "per_day": pl.DataFrame()}
            return pl.DataFrame()
        # Convert r2://dev/ paths to remote paths for executor
        remote_files = Storage.get_remote_path(files)
        # 2. Map Phase
        # Batching logic could be here or we just let executor handle it.
        # ModalExecutor handles generic lists well. But for chunks we usually want explicit batching to reduce overhead.
        BATCH_SIZE = 10
        chunks = [
            remote_files[i : i + BATCH_SIZE]
            for i in range(0, len(remote_files), BATCH_SIZE)
        ]
        print(f"Batching {len(remote_files)} files into {len(chunks)} chunks")
        print(
            f"Submitting {len(chunks)} metric map tasks (Concurrency: {concurrency_limit})..."
        )
        # Modal specific flags.
        # NOTE(review): feature detection via __code__.co_varnames assumes
        # executor.map is a pure-Python callable — confirm for all executors.
        kwargs = {}
        if hasattr(executor, "map"):
            if "sync_cpp" in executor.map.__code__.co_varnames:
                kwargs["sync_cpp"] = (
                    False  # Metrics usually rely on Parquet/Polars, no C++ engine needed unless specified
                )
            if "concurrency_limit" in executor.map.__code__.co_varnames:
                kwargs["concurrency_limit"] = concurrency_limit
        if secrets is not None:
            kwargs["secrets"] = secrets
        # If per-day metrics are needed, bind return_per_file=True into the map function
        _t_map = time.time()
        if return_per_day:
            per_day_map_fn = (
                functools.partial(map_fn, context=context, return_per_file=True)
                if context
                else functools.partial(map_fn, return_per_file=True)
            )
            metric_results = executor.map(per_day_map_fn, chunks, **kwargs)
        else:
            metric_results = executor.map(current_map_fn, chunks, **kwargs)
        log_time("wave2_map", time.time() - _t_map, n_chunks=len(chunks))
        # 3. Reduce Phase: drop failed (None) workers, then fold sufficient stats
        print("Aggregating partial results...")
        _t_reduce = time.time()
        valid_results = [r for r in metric_results if r is not None]
        if not valid_results:
            print("No valid results received.")
            if return_per_day:
                return {"aggregated": pl.DataFrame(), "per_day": pl.DataFrame()}
            return pl.DataFrame()
        if return_per_day:
            # Results are dicts with "batch_stats" and "per_file" keys
            batch_stats = []
            per_file_results = []
            for res in valid_results:
                if isinstance(res, dict):
                    if res.get("batch_stats") is not None:
                        batch_stats.append(res["batch_stats"])
                    if res.get("per_file") is not None:
                        per_file_results.extend(res["per_file"])
                else:
                    # Fallback: a non-dict result is treated as batch stats only
                    batch_stats.append(res)
            # Aggregate batch stats into a single global metric frame
            aggregated_df = pl.DataFrame()
            if batch_stats:
                total_stats = stats_class().accumulate(batch_stats)
                aggregated_df = total_stats.compute()
            # Combine per-file (per-day) results into one frame
            per_day_df = pl.DataFrame()
            if per_file_results:
                per_day_df = pl.concat(per_file_results)
            log_time(
                "wave2_reduce", time.time() - _t_reduce, n_results=len(valid_results)
            )
            return {"aggregated": aggregated_df, "per_day": per_day_df}
        else:
            total_stats = stats_class().accumulate(valid_results)
            result = total_stats.compute()
            log_time(
                "wave2_reduce", time.time() - _t_reduce, n_results=len(valid_results)
            )
            return result