# Source code for forecasting.data.loader

"""
Module: loader.py
Description: Data loading logic to convert Databento MBO records into contiguous numpy arrays for the C++ Engine.
"""

import logging
import numpy as np
import databento as db
import os
from dataclasses import dataclass
from typing import Union, BinaryIO
import io
from urllib.parse import urlparse
import pandas as pd

logger = logging.getLogger(__name__)


@dataclass
class SimulationData:
    """
    Container for prepared, contiguous memory arrays consumed by the
    simulation engine.

    Each array below is expected to be C-contiguous so the C++ engine can
    access it without copying; the dtype of each field is noted inline.
    """

    instrument_id: int  # Databento instrument ID these events belong to
    n_events: int       # number of events (length of every array below)

    # Contiguous per-event arrays handed to the C++ engine.
    actions: np.ndarray    # dtype int8   - event action codes
    sides: np.ndarray      # dtype int8   - bid/ask side codes
    prices: np.ndarray     # dtype int64  - fixed-point prices
    sizes: np.ndarray      # dtype int64  - order sizes
    order_ids: np.ndarray  # dtype uint64 - exchange order IDs
    ts_recvs: np.ndarray   # dtype uint64 - receive timestamps (ns)
    flags: np.ndarray      # dtype uint8  - record flag bitmasks

    def __repr__(self) -> str:
        # Compact one-line summary; the arrays themselves are intentionally
        # omitted to keep logs readable.
        return (
            "<SimulationData "
            f"instrument_id={self.instrument_id} "
            f"n_events={self.n_events}>"
        )
class DataLoader:
    """
    Loads Databento DBN market data (from a local path or R2/S3) and converts
    it into contiguous numpy arrays for zero-copy access by the C++ engine.

    Note on time filtering: order book reconstruction must always begin at the
    first record in the file, so ``start_time`` is stored but deliberately NOT
    used to trim the head here (it is intended for downstream metrics such as
    r-squared windows). Only ``end_time`` truncates the tail, as a perf
    optimization to stop the simulation early.
    """

    def __init__(
        self,
        data_path: str,
        start_time: str = None,
        end_time: str = None,
        timezone: str = "UTC",
    ):
        # data_path may be a local filesystem path or an 'r2://bucket/key' URI.
        self.data_path = data_path
        # Kept for downstream consumers only; never used to trim the head here.
        self.start_time = start_time
        # "HH:MM[:SS]" in `timezone`; when None, market close is auto-detected.
        self.end_time = end_time
        self.timezone = timezone

    def _get_data_source(self) -> Union[str, bytes]:
        """
        Resolve the data source.

        If the path starts with 'r2://', stream the object's content directly
        into memory via the project's R2 (S3-compatible) client. Otherwise
        return the local file path unchanged.

        Returns:
            Union[str, bytes]: A local file path string, or the raw file bytes
            when downloaded from R2.
        """
        if self.data_path.startswith("r2://"):
            from setup.r2 import require_r2

            r2 = require_r2()
            parsed = urlparse(self.data_path)
            bucket = parsed.netloc
            key = parsed.path.lstrip("/")
            logger.info(f"Downloading from S3 to memory: {self.data_path}")
            s3 = r2.to_client()
            response = s3.get_object(Bucket=bucket, Key=key)
            return response["Body"].read()
        return self.data_path

    def _empty_result(self, instrument_id: int) -> "SimulationData":
        """Build a zero-event SimulationData with correctly-typed empty arrays."""
        return SimulationData(
            instrument_id=instrument_id or 0,
            n_events=0,
            actions=np.array([], dtype=np.int8),
            sides=np.array([], dtype=np.int8),
            prices=np.array([], dtype=np.int64),
            sizes=np.array([], dtype=np.int64),
            order_ids=np.array([], dtype=np.uint64),
            ts_recvs=np.array([], dtype=np.uint64),
            flags=np.array([], dtype=np.uint8),
        )

    def _resolve_end_time(self, first_dt) -> str:
        """
        Determine the end-of-filtering time string.

        Uses self.end_time when set; otherwise looks up the market close for
        the session date of `first_dt`, inferring the exchange from the
        filename ('xnas' -> NASDAQ, 'xnys' -> NYSE, NYSE by default). Falls
        back to end-of-day if the lookup fails.

        Args:
            first_dt: Timezone-aware pandas Timestamp of the first event,
                already converted to self.timezone.

        Returns:
            str: A time string usable with pandas `between_time`.
        """
        if self.end_time is not None:
            return self.end_time

        from forecasting.data.calendar_utils import get_market_close

        # Infer exchange from filename if possible.
        fname = self.data_path.lower()
        exchange_code = "NYSE"  # default: US equity
        if "xnas" in fname:
            exchange_code = "NASDAQ"
        elif "xnys" in fname:
            exchange_code = "NYSE"
        try:
            e_time = get_market_close(first_dt, exchange=exchange_code)
            logger.info(
                f" Auto-detected Market Close for {first_dt.date()} ({exchange_code}): {e_time}"
            )
        except Exception as e:
            # BUGFIX: this warning was previously emitted twice (print() AND
            # logger.warning); the logger alone is the right channel.
            logger.warning(
                f" Failed to detect market close: {e}. Defaulting to 23:59:59."
            )
            e_time = "23:59:59.999999"
        return e_time

    def load(self, instrument_id: int = None, n_events: int = None) -> "SimulationData":
        """
        Loads data from file (or S3), filters for instrument, and prepares
        contiguous arrays.

        Args:
            instrument_id (int, optional): Specific ID to filter for. If None,
                defaults to the most active instrument in the file.
            n_events (int, optional): Maximum number of events to read from the
                store. If None, loads all events. Useful for memory-efficient
                sampling.

        Returns:
            SimulationData: A dataclass containing contiguous (C-style) numpy
            arrays ready for zero-copy access by the C++ engine.
        """
        data_source = self._get_data_source()
        # BUGFIX: was an f-string with no placeholders passed to logger.info.
        logger.info("Loading data...")

        if isinstance(data_source, bytes):
            store = db.DBNStore.from_bytes(data_source)
        else:
            store = db.DBNStore.from_file(data_source)

        if n_events is not None:
            # Memory-efficient path: stop iterating the store after n_events.
            logger.info("Loading first %s events (memory-efficient mode)...", n_events)
            records_list = []
            for i, record in enumerate(store):
                if i >= n_events:
                    break
                records_list.append(record)
            if not records_list:
                return self._empty_result(instrument_id)
            all_records = np.array(
                records_list,
                dtype=records_list[0].dtype
                if hasattr(records_list[0], "dtype")
                else None,
            )
            if all_records.dtype.names is None:
                # Iterator did not yield structured records; fall back to a
                # full decode and slice to the requested prefix.
                all_records = store.to_ndarray()[:n_events]
        else:
            all_records = store.to_ndarray()

        # Auto-detect instrument if not provided: pick the most frequent ID.
        if instrument_id is None:
            ids, counts = np.unique(all_records["instrument_id"], return_counts=True)
            # BUGFIX: cast the numpy scalar to a plain int so SimulationData
            # carries a clean Python int (repr/serialization friendly).
            instrument_id = int(ids[np.argmax(counts)])
            logger.info("Auto-detected most active instrument: %s", instrument_id)

        logger.info("Filtering for instrument: %s", instrument_id)
        records = all_records[all_records["instrument_id"] == instrument_id]
        n_events = len(records)
        logger.info("Processing %s events...", n_events)

        # Prepare contiguous memory views.
        # .view() on structured-array fields creates non-contiguous strided
        # views; np.ascontiguousarray() copies into dense buffers, which is
        # safer for C++ interop.
        actions = np.ascontiguousarray(records["action"].view("i1"))
        sides = np.ascontiguousarray(records["side"].view("i1"))
        prices = np.ascontiguousarray(records["price"].astype(np.int64))
        sizes = np.ascontiguousarray(records["size"].astype(np.int64))
        order_ids = np.ascontiguousarray(records["order_id"].astype(np.uint64))
        ts_recvs = np.ascontiguousarray(records["ts_recv"].astype(np.uint64))
        flags = np.ascontiguousarray(records["flags"].view("u1"))

        # BUGFIX: the previous guard
        #   `start_time or end_time or (start_time is None and end_time is None)`
        # was tautological for all str/None inputs, so the time filter always
        # ran whenever events existed. Simplified to the equivalent check.
        # Simulation always starts at the first record (pre-market included);
        # only end_time truncates the tail.
        if len(ts_recvs) > 0:
            # ts_recvs are uint64 nanoseconds; take the session date from the
            # first event.
            first_ts_ns = ts_recvs[0]
            first_dt = pd.to_datetime(first_ts_ns, unit="ns", utc=True).tz_convert(
                self.timezone
            )

            s_time = "00:00"  # effectively beginning of day: include pre-market
            e_time = self._resolve_end_time(first_dt)

            logger.info(
                f"Filtering Time (Loader): Start={s_time} (Include Pre-Market), End={e_time} ({self.timezone})"
            )

            ts_index = pd.to_datetime(ts_recvs, unit="ns", utc=True).tz_convert(
                self.timezone
            )
            # Positional indices survive between_time via the series values.
            dummy = pd.Series(np.arange(len(ts_index)), index=ts_index)
            subset = dummy.between_time(s_time, e_time)
            valid_indices = subset.values
            logger.info(f" Kept {len(valid_indices)}/{len(ts_index)} events")

            # Apply the time filter to every array in lockstep.
            actions = actions[valid_indices]
            sides = sides[valid_indices]
            prices = prices[valid_indices]
            sizes = sizes[valid_indices]
            order_ids = order_ids[valid_indices]
            ts_recvs = ts_recvs[valid_indices]
            flags = flags[valid_indices]
            n_events = len(actions)
        else:
            logger.warning("No events found to determine date for time filtering.")

        return SimulationData(
            instrument_id=instrument_id,
            n_events=n_events,
            actions=actions,
            sides=sides,
            prices=prices,
            sizes=sizes,
            order_ids=order_ids,
            ts_recvs=ts_recvs,
            flags=flags,
        )