from typing import Any, Dict, List, Callable, Optional, Union
import polars as pl
import random
from forecasting.simulation.harness_base import BaseHarness
from forecasting.metrics.utils import compute_metric_safe
from forecasting.metrics.context import MetricContext
from forecasting.metrics.core.stats import MetricConfig
from forecasting.data.storage import Storage
from common.loggers.timing import log_time, TimingContext
import functools
import os
import json
import time
def _identity_transform(lf: pl.LazyFrame) -> pl.LazyFrame:
    """Return the input LazyFrame untouched.

    Acts as the no-op pipeline step backing the "baseline" experiment.
    Declared at module level so it survives pickling for remote executors.
    """
    return lf
def _compose_transforms(
    lf: pl.LazyFrame, transform_fn: Callable, base_fn: Callable
) -> pl.LazyFrame:
    """Chain two pipeline steps: run ``base_fn`` first, then ``transform_fn``.

    Lives at module level (instead of a closure/lambda) so the composed
    callable stays picklable for distributed execution.
    """
    base_result = base_fn(lf)
    return transform_fn(base_result)
def _execute_subsample(
    files: List[str],
    strategy: str,
    n: int,
    n_days: int,
    seed: Optional[int],
    step: Optional[int],
    transform_fn_serialized: bytes,
    storage_options: Optional[Dict] = None,
) -> "pl.DataFrame":
    """
    Core subsampling logic using lazy evaluation via pl.scan_parquet.

    Args:
        files: List of parquet file paths (already converted to remote paths by caller)
        strategy: 'head', 'gather_every', or 'random_days'
        n: Row limit for head strategy
        n_days: Number of days for random_days strategy
        seed: Random seed (only used by 'random_days')
        step: Step size for gather_every strategy (falls back to 100 when falsy)
        transform_fn_serialized: Pickled transform function
        storage_options: Optional storage options for S3 paths

    Returns:
        Subsampled DataFrame (empty when ``files`` is empty).

    Raises:
        ValueError: If ``strategy`` is not one of the supported options.
    """
    # Imported lazily: this function executes inside remote executor workers.
    import cloudpickle as pickle

    if not files:
        return pl.DataFrame()
    transform_fn = pickle.loads(transform_fn_serialized)

    if strategy == "head":
        lf = pl.scan_parquet(files, storage_options=storage_options)
        return transform_fn(lf).limit(n).collect()
    elif strategy == "gather_every":
        lf = pl.scan_parquet(files, storage_options=storage_options)
        s = step if step else 100
        return transform_fn(lf).gather_every(s).collect()
    elif strategy == "random_days":
        # Sample whole files (one per day) to preserve intraday structure.
        rng = random.Random(seed)
        count = min(len(files), n_days)
        chosen_files = rng.sample(files, count)
        print(f"Sampling {count} days (out of {len(files)} total)...")
        subset_lf = pl.scan_parquet(chosen_files, storage_options=storage_options)
        return transform_fn(subset_lf).collect()
    # Previously an unrecognized strategy silently returned an empty frame,
    # which masks caller bugs; fail loudly instead (mirrors the harness-side
    # validation message in get_subsampled_df).
    raise ValueError(
        f"Unknown strategy: {strategy}. Options: 'random_days', 'gather_every', 'head'."
    )
def _execute_collect(
    files: List[str],
    transform_fn_serialized: bytes,
    storage_options: Optional[Dict] = None,
) -> "pl.DataFrame":
    """
    Materialize the transformed dataset as an eager DataFrame.

    Runs inside an executor worker and is used for cross-storage transfers
    (e.g., Modal -> local). ``files`` must already be remote paths — the
    caller is responsible for the conversion.
    """
    # Imported lazily: this function executes inside remote executor workers.
    import cloudpickle as pickle

    if not files:
        return pl.DataFrame()
    pipeline = pickle.loads(transform_fn_serialized)
    scan = pl.scan_parquet(files, storage_options=storage_options)
    return pipeline(scan).collect()
def _execute_get_experiment(
    files: List[str],
    transform_fn_serialized: bytes,
    storage_options: Optional[Dict] = None,
) -> bytes:
    """
    Build the experiment pipeline and return its serialized LazyFrame plan.

    Uses Polars' native ``LazyFrame.serialize()`` so only the query plan —
    not the data — travels back from the worker. ``files`` must already be
    remote paths (converted by the caller).
    """
    # Imported lazily: this function executes inside remote executor workers.
    import cloudpickle as pickle

    if not files:
        return b""
    pipeline = pickle.loads(transform_fn_serialized)
    plan = pipeline(pl.scan_parquet(files, storage_options=storage_options))
    # Native Polars serialization of the (unexecuted) query plan.
    return plan.serialize()
def _execute_metrics(
    files: List[str],
    pipelines_serialized: Dict[str, bytes],
    metrics_serialized: bytes,
    storage_options: Optional[Dict] = None,
) -> "pl.DataFrame":
    """
    Compute metrics for every experiment pipeline. Runs inside an executor.

    ``files`` must already be remote paths (converted by the caller).
    Experiments whose metric computation yields ``None`` are skipped.
    """
    # Imported lazily so the worker resolves them in its own environment.
    import cloudpickle as pickle
    from forecasting.metrics.utils import compute_metric_safe

    if not files:
        return pl.DataFrame()

    metric_fns = pickle.loads(metrics_serialized)
    raw_lazy = pl.scan_parquet(files, storage_options=storage_options)

    frames = []
    for exp_name, blob in pipelines_serialized.items():
        pipeline = pickle.loads(blob)
        metric_df = compute_metric_safe(exp_name, pipeline(raw_lazy), metric_fns)
        if metric_df is not None:
            frames.append(metric_df)

    if not frames:
        return pl.DataFrame()
    return pl.concat(frames)
class AnalysisHarness(BaseHarness):
    """
    Orchestrates the post-hoc analysis of backtest results using a declarative pipeline pattern.

    The AnalysisHarness treats research as a series of "Experiments." Instead of generating
    new data files for every hypothesis, you define a 'transform_fn' (logic) that is
    registered in a compute graph.

    This allows for:

    1. **Rapid Prototyping**: Apply logic to a small local sample (subsampling) to verify math.
    2. **Scalable Execution**: Execute that same logic across multi-terabyte datasets via
       distributed executors without code changes.
    3. **Lineage Tracking**: Maintains a clear map of transformations applied to 'baseline' results.

    Attributes:
        pipelines (Dict[str, Callable]): A registry of named experiments, each stored as a
            picklable composed transform (the full chain from 'baseline') that is only
            executed when metrics are requested or the experiment is sinked.
        transforms (Dict[str, Callable]): A registry of the original single-step functions
            used to build the pipelines (for re-applying logic during subsampling).
    """

    def __init__(
        self,
        id: str,
        base_source: Union[str, List[str]],
        storage_options: Optional[Dict] = None,
    ):
        """
        Args:
            id: Unique identifier for this analysis session.
            base_source: Glob pattern, S3 path, or list of Parquet file paths (the backtest results).
            storage_options: Cloud storage credentials/configuration for Polars.
        """
        super().__init__(base_source, storage_options, id)
        # Original single-step transform per experiment name.
        self.transforms: Dict[str, Callable[[pl.LazyFrame], pl.LazyFrame]] = {
            "baseline": _identity_transform
        }
        # Fully composed chain (from baseline) per experiment name.
        self.pipelines: Dict[str, Callable[[pl.LazyFrame], pl.LazyFrame]] = {
            "baseline": _identity_transform
        }
        # Session metadata (e.g. start/end times, timezone) stored next to the results.
        self.metadata = self._load_metadata(base_source)
def _load_metadata(self, source: Union[str, List[str]]) -> Dict:
"""Load metadata.json from the source directory.
Returns an empty dict if the file is not found so that downstream
code can fall back to defaults without crashing the scoring pipeline.
"""
path_str = source if isinstance(source, str) else source[0]
if "*" in path_str:
base_dir = os.path.dirname(path_str)
elif path_str.endswith(".parquet"):
base_dir = os.path.dirname(path_str)
else:
base_dir = path_str
meta_path = f"{base_dir.rstrip('/')}/metadata.json"
metadata = Storage.load_json(meta_path)
print(f"[AnalysisHarness] Loaded metadata from {meta_path}")
return metadata
@property
def columns(self) -> List[str]:
"""Returns all columns available in the baseline dataset."""
files = self._resolve_paths()
if not files:
return []
return Storage.read_parquet_schema(files[0])
[docs]
def get_targets(self) -> List[str]:
"""Identifies target/prediction columns (e.g., target_1s)."""
return [c for c in self.columns if c.startswith("target_")]
[docs]
def get_features(self) -> List[str]:
"""Identifies all non-target columns."""
return [c for c in self.columns if not c.startswith("target_")]
[docs]
def add_experiment(
self,
name: str,
transform_fn: Callable[[pl.LazyFrame], pl.LazyFrame],
base: str = "baseline",
):
"""
Registers a new hypothesis or data view, optionally chaining from another experiment.
This builds a 'Virtual Dataset'. No computation happens until metrics are requested
or the experiment is 'sinked' to disk.
Args:
name: Unique name for this experiment (e.g., 'filtered_high_vol').
transform_fn: A function taking a LazyFrame and returning a modified LazyFrame.
base: Name of the experiment to build on (default: "baseline").
Example:
harness.add_experiment("filtered", lambda lf: lf.filter(pl.col("volume") > 1000))
harness.add_experiment("aggregated", lambda lf: lf.group_by("symbol").agg(...), base="filtered")
"""
if base not in self.pipelines:
raise ValueError(f"Base experiment '{base}' not found.")
if name in self.pipelines:
print(f"Warning: Overwriting existing experiment '{name}'")
# Store original single-step transform
self.transforms[name] = transform_fn
# Compose with base to create full pipeline using functools.partial (picklable)
base_pipeline = self.pipelines[base]
self.pipelines[name] = functools.partial(
_compose_transforms, transform_fn=transform_fn, base_fn=base_pipeline
)
[docs]
def get_experiment(self, name: str) -> pl.LazyFrame:
"""
Returns the compute graph (LazyFrame) for a specific experiment.
The plan is built via the executor and serialized back using Polars
native LazyFrame.serialize(). Call .collect() to execute the plan.
Args:
name: Experiment name.
Returns:
LazyFrame with the experiment's compute graph.
"""
import cloudpickle as pickle
import io
if name not in self.pipelines:
raise ValueError(f"Experiment '{name}' not found.")
files = self._resolve_paths()
executor = Storage.get_executor(self.base_source)
# Convert r2://dev/ paths to remote paths for executor
remote_files = Storage.get_remote_path(files)
transform_fn_serialized = pickle.dumps(self.pipelines[name])
args = (remote_files, transform_fn_serialized, self.storage_options)
results = executor.map(_execute_get_experiment, [args], starmap=True)
if results and results[0]:
# Deserialize the LazyFrame plan
return pl.LazyFrame.deserialize(io.BytesIO(results[0]))
return pl.LazyFrame()
[docs]
def get_metrics(
self,
metrics: Dict[str, Callable[[pl.LazyFrame], pl.DataFrame]],
experiment_names: Optional[List[str]] = None,
concurrency: int = 100,
) -> pl.DataFrame:
"""
Triggers the execution of the compute graphs across selected experiments.
Args:
metrics: Dict of metric functions to apply to each experiment pipeline.
experiment_names: Names of registered experiments to evaluate. Defaults to all.
concurrency: Max workers if distributed.
"""
import cloudpickle as pickle
if experiment_names is None:
experiment_names = list(self.pipelines.keys())
files = self._resolve_paths()
executor = Storage.get_executor(self.base_source)
# Convert r2://dev/ paths to remote paths for executor
remote_files = Storage.get_remote_path(files)
# Serialize pipelines and metrics for execution
pipelines_serialized = {
name: pickle.dumps(self.pipelines[name])
for name in experiment_names
if name in self.pipelines
}
metrics_serialized = pickle.dumps(metrics)
args = (
remote_files,
pipelines_serialized,
metrics_serialized,
self.storage_options,
)
results = executor.map(_execute_metrics, [args], starmap=True)
return results[0] if results else pl.DataFrame()
[docs]
def sink_experiment(self, name: str, output_path: str):
"""
Materializes the experiment compute graph into a Parquet file.
Collects data via executor, then writes to output_path using Storage
(supports local, S3, or r2://dev/ paths).
Note: This loads all data into memory. For very large datasets,
consider using streaming approaches or chunked writes.
"""
if name not in self.pipelines:
raise ValueError(f"Experiment {name} not found.")
print(f"Sinking {name} to {output_path}...")
import cloudpickle as pickle
files = self._resolve_paths()
transform_fn_serialized = pickle.dumps(self.pipelines[name])
# 1. Collect data via executor
executor = Storage.get_executor(self.base_source)
remote_files = Storage.get_remote_path(files)
args = (remote_files, transform_fn_serialized, self.storage_options)
results = executor.map(_execute_collect, [args], starmap=True)
if not results or results[0] is None or len(results[0]) == 0:
print(" Warning: No data collected")
return
df = results[0]
# 2. Write to output_path (Storage handles local/S3/Modal)
Storage.save_parquet(df.to_pandas(), output_path)
print(f" Saved {len(df)} rows to {output_path}")
[docs]
def get_subsampled_df(
self,
experiment_name: str = "baseline",
n: int = 100_000,
strategy: str = "random_days",
seed: Optional[int] = None,
step: Optional[int] = None,
n_days: int = 5,
) -> pl.DataFrame:
"""
Injects a small data sample into an experiment's logic to return an eager DataFrame.
This is the primary tool for interactive research. It guarantees that the logic
you see in your notebook is identical to the logic run on the cluster.
Args:
experiment_name: Which experiment logic to apply to the sample.
n: Row limit for the final sample.
strategy:
- 'random_days': Picks whole days (preserving HFT microstructure). **Preferred**.
- 'head': Fast, but biased toward early dates.
- 'gather_every': Systematically samples every Nth row.
n_days: How many full days to include in the 'random_days' sample.
"""
import cloudpickle as pickle
if experiment_name not in self.pipelines:
raise ValueError(f"Experiment '{experiment_name}' not found.")
if strategy == "random_rows":
raise ValueError(
"random_rows is disabled to prevent time-series integrity loss. Use 'random_days'."
)
if strategy not in ("head", "gather_every", "random_days"):
raise ValueError(
f"Unknown strategy: {strategy}. Options: 'random_days', 'gather_every', 'head'."
)
files = self._resolve_paths()
transform_fn_serialized = pickle.dumps(self.pipelines[experiment_name])
executor = Storage.get_executor(self.base_source)
# Convert r2://dev/ paths to remote paths for executor
remote_files = Storage.get_remote_path(files)
args = (
remote_files,
strategy,
n,
n_days,
seed,
step,
transform_fn_serialized,
self.storage_options,
)
results = executor.map(_execute_subsample, [args], starmap=True)
return results[0] if results else pl.DataFrame()
[docs]
def compute_global_metric(
self,
executor,
map_fn: Callable,
stats_class,
output_base: str,
pattern: str = "*.parquet",
start_time: Optional[str] = None,
end_time: Optional[str] = None,
concurrency_limit: int = 100,
secrets: Optional[list] = None,
metric_config: Optional[MetricConfig] = None,
) -> Union[pl.DataFrame, Dict[str, pl.DataFrame]]:
"""
Orchestrates a metric-agnostic Map-Reduce calculation.
The method only reads the scope fields (``per_day``, ``aggregated``)
from *metric_config*. The full config object is passed through
opaquely to the *map_fn* (as ``metric_config``) and *stats_class*
(as ``config``), so metric-specific settings live entirely in the
subclass (e.g. :class:`R2Config`).
Args:
executor: BaseExecutor.
map_fn: The mapping function (e.g. map_r2_chunk).
stats_class: The SufficientStats class (e.g. R2SufficientStats).
output_base: Base path where backtest results are stored.
pattern: Glob pattern for result files.
start_time: Optional filter (e.g. "09:30:00"). If None, tries metadata.
end_time: Optional filter (e.g. "16:00:00"). If None, tries metadata.
concurrency_limit: Max concurrent worker containers (Modal).
secrets: Optional Modal secrets list.
metric_config: Metric configuration (scope + metric-specific settings).
Passed through to *map_fn* and *stats_class*.
Returns:
When per_day is disabled: pl.DataFrame with aggregated metrics.
When per_day is enabled: Dict with keys "aggregated" and/or "per_day".
"""
want_per_day = metric_config.per_day if metric_config else False
want_aggregated = metric_config.aggregated if metric_config else True
# --- Resolve Time Filters ------------------------------------------
s_time = start_time or self.metadata.get("start_time")
e_time = end_time or self.metadata.get("end_time")
tz = self.metadata.get("timezone", "US/Eastern")
context = None
if s_time or e_time:
context = MetricContext(start_time=s_time, end_time=e_time, timezone=tz)
print(f"Applying Filter Context: {s_time} - {e_time} [{tz}]")
# --- Bind args for map function ------------------------------------
partial_kwargs: Dict[str, Any] = {}
if metric_config is not None:
partial_kwargs["metric_config"] = metric_config
if context:
partial_kwargs["context"] = context
# --- 1. List Files -------------------------------------------------
search_path = (
f"{output_base}/{pattern}"
if not output_base.endswith(".parquet")
else output_base
)
_t_list = time.time()
with Storage.cached():
files = Storage.list_files(search_path)
log_time(
"wave2_list_files",
time.time() - _t_list,
n_files=len(files) if files else 0,
)
if not files:
print(f"No files found at {search_path}")
if want_per_day:
return {"aggregated": pl.DataFrame(), "per_day": pl.DataFrame()}
return pl.DataFrame()
remote_files = Storage.get_remote_path(files)
# --- 2. Map Phase --------------------------------------------------
BATCH_SIZE = 10
chunks = [
remote_files[i : i + BATCH_SIZE]
for i in range(0, len(remote_files), BATCH_SIZE)
]
print(f"Batching {len(remote_files)} files into {len(chunks)} chunks")
print(
f"Submitting {len(chunks)} metric map tasks (Concurrency: {concurrency_limit})..."
)
kwargs: Dict[str, Any] = {}
if hasattr(executor, "map"):
if "sync_cpp" in executor.map.__code__.co_varnames:
kwargs["sync_cpp"] = False
if "concurrency_limit" in executor.map.__code__.co_varnames:
kwargs["concurrency_limit"] = concurrency_limit
if secrets is not None:
kwargs["secrets"] = secrets
# If we need per-day metrics, add return_per_file=True to the map function
_t_map = time.time()
if want_per_day:
per_day_kwargs = dict(partial_kwargs)
per_day_kwargs["return_per_file"] = True
bound_map_fn = functools.partial(map_fn, **per_day_kwargs)
else:
bound_map_fn = (
functools.partial(map_fn, **partial_kwargs)
if partial_kwargs
else map_fn
)
metric_results = executor.map(bound_map_fn, chunks, **kwargs)
log_time("wave2_map", time.time() - _t_map, n_chunks=len(chunks))
# 3. Reduce Phase
print("Aggregating partial results...")
_t_reduce = time.time()
valid_results = [r for r in metric_results if r is not None]
if not valid_results:
print("No valid results received.")
if want_per_day:
return {"aggregated": pl.DataFrame(), "per_day": pl.DataFrame()}
return pl.DataFrame()
stats_init_kwargs: Dict[str, Any] = {}
if metric_config is not None:
stats_init_kwargs["config"] = metric_config
if want_per_day:
batch_stats = []
per_file_results = []
for res in valid_results:
if isinstance(res, dict):
if res.get("batch_stats") is not None:
batch_stats.append(res["batch_stats"])
if res.get("per_file") is not None:
per_file_results.extend(res["per_file"])
else:
batch_stats.append(res)
result: Dict[str, pl.DataFrame] = {}
if want_aggregated and batch_stats:
total_stats = stats_class(**stats_init_kwargs).accumulate(batch_stats)
aggregated_df = total_stats.compute()
elif want_aggregated:
aggregated_df = pl.DataFrame()
per_day_df = pl.DataFrame()
if per_file_results:
per_day_df = pl.concat(per_file_results)
log_time(
"wave2_reduce", time.time() - _t_reduce, n_results=len(valid_results)
)
return {"aggregated": aggregated_df, "per_day": per_day_df}
else:
total_stats = stats_class(**stats_init_kwargs).accumulate(valid_results)
result = total_stats.compute()
log_time(
"wave2_reduce", time.time() - _t_reduce, n_results=len(valid_results)
)
return result