"""
Module: loader.py
Description: Data loading logic to convert Databento MBO records into contiguous numpy arrays for the C++ Engine.
"""
import logging
import numpy as np
import databento as db
import os
from dataclasses import dataclass
from typing import Union, BinaryIO
import io
from urllib.parse import urlparse
import pandas as pd
logger = logging.getLogger(__name__)
@dataclass
class SimulationData:
    """Prepared, contiguous event arrays for the simulation engine.

    Each array field is a dense (C-contiguous) numpy buffer sized to
    ``n_events``, ready for zero-copy hand-off to the C++ engine.
    """

    instrument_id: int  # the single instrument these events belong to
    n_events: int       # length of every array below
    # Contiguous per-event arrays consumed by the C++ engine.
    actions: np.ndarray    # int8
    sides: np.ndarray      # int8
    prices: np.ndarray     # int64
    sizes: np.ndarray      # int64
    order_ids: np.ndarray  # uint64
    ts_recvs: np.ndarray   # uint64
    flags: np.ndarray      # uint8

    def __repr__(self) -> str:
        # Compact summary; the arrays themselves are intentionally omitted.
        return (
            "<SimulationData"
            f" instrument_id={self.instrument_id}"
            f" n_events={self.n_events}>"
        )
class DataLoader:
    """Loads Databento MBO records and prepares contiguous arrays for the C++ engine.

    Supports local DBN files and ``r2://bucket/key`` object-store paths
    (streamed fully into memory via the project's R2/boto3 client).
    Optionally trims events to a trading-day window before hand-off.
    """

    def __init__(
        self,
        data_path: str,
        start_time: str = None,
        end_time: str = None,
        timezone: str = "UTC",
    ):
        """
        Args:
            data_path (str): Local path or ``r2://bucket/key`` URI of a DBN file.
            start_time (str, optional): Informational only — order-book
                reconstruction always starts at the first event in the file;
                see ``load`` for details.
            end_time (str, optional): Wall-clock time (in ``timezone``) at which
                to stop loading events. If None, market close is auto-detected.
            timezone (str): IANA timezone used to interpret start/end times.
        """
        self.data_path = data_path
        self.start_time = start_time
        self.end_time = end_time
        self.timezone = timezone

    def _get_data_source(self) -> Union[str, bytes]:
        """
        Resolves the data source.

        If the path starts with 'r2://', streams the object content directly
        into memory via the R2 (S3-compatible) client. Otherwise, returns the
        local file path unchanged.

        Returns:
            Union[str, bytes]: A local file path string, or a bytes object
            containing the file content if downloaded from object storage.
        """
        if self.data_path.startswith("r2://"):
            from setup.r2 import require_r2

            r2 = require_r2()
            parsed = urlparse(self.data_path)
            bucket = parsed.netloc
            key = parsed.path.lstrip("/")
            logger.info("Downloading from S3 to memory: %s", self.data_path)
            s3 = r2.to_client()
            response = s3.get_object(Bucket=bucket, Key=key)
            return response["Body"].read()
        return self.data_path

    @staticmethod
    def _empty_simulation(instrument_id: int) -> "SimulationData":
        """Builds a zero-event SimulationData (used when no records are read)."""
        return SimulationData(
            instrument_id=instrument_id or 0,
            n_events=0,
            actions=np.array([], dtype=np.int8),
            sides=np.array([], dtype=np.int8),
            prices=np.array([], dtype=np.int64),
            sizes=np.array([], dtype=np.int64),
            order_ids=np.array([], dtype=np.uint64),
            ts_recvs=np.array([], dtype=np.uint64),
            flags=np.array([], dtype=np.uint8),
        )

    def _resolve_end_time(self, first_dt) -> str:
        """Returns the effective end-of-day cutoff as a time string.

        Uses ``self.end_time`` when set; otherwise auto-detects the market
        close for ``first_dt``'s date (exchange inferred from the filename),
        falling back to just-before-midnight if detection fails.
        """
        if self.end_time is not None:
            return self.end_time

        from forecasting.data.calendar_utils import get_market_close

        # Infer exchange from filename if possible.
        fname = self.data_path.lower()
        exchange_code = "NYSE"  # Default US Equity
        if "xnas" in fname:
            exchange_code = "NASDAQ"
        elif "xnys" in fname:
            exchange_code = "NYSE"
        try:
            e_time = get_market_close(first_dt, exchange=exchange_code)
            logger.info(
                " Auto-detected Market Close for %s (%s): %s",
                first_dt.date(),
                exchange_code,
                e_time,
            )
        except Exception as e:
            # Best-effort: log and fall back rather than failing the load.
            # (A duplicate print() to stdout was removed here.)
            logger.warning(
                " Failed to detect market close: %s. Defaulting to 23:59:59.", e
            )
            e_time = "23:59:59.999999"
        return e_time

    def load(self, instrument_id: int = None, n_events: int = None) -> "SimulationData":
        """
        Loads data from file (or S3), filters for instrument, and prepares
        contiguous arrays.

        Note: order-book reconstruction always starts at the FIRST event in
        the file — ``self.start_time`` is NOT applied here (it is only used
        downstream, e.g. when computing r-squared). ``self.end_time`` IS
        applied so the simulation can stop early.

        Args:
            instrument_id (int, optional): Specific ID to filter for.
                If None, defaults to the most active instrument in the file.
            n_events (int, optional): Maximum number of events to load.
                If None, loads all events. Useful for memory-efficient sampling.

        Returns:
            SimulationData: A dataclass containing contiguous (C-style) numpy
            arrays ready for zero-copy access by the C++ engine.
        """
        # Resolve source (path string or in-memory bytes).
        data_source = self._get_data_source()
        logger.info("Loading data...")
        if isinstance(data_source, bytes):
            store = db.DBNStore.from_bytes(data_source)
        else:
            store = db.DBNStore.from_file(data_source)

        # If n_events is specified, use the iterator for memory efficiency.
        if n_events is not None:
            logger.info("Loading first %s events (memory-efficient mode)...", n_events)
            records_list = []
            for i, record in enumerate(store):
                if i >= n_events:
                    break
                records_list.append(record)
            if not records_list:
                return self._empty_simulation(instrument_id)
            # Convert list of records to a structured array.
            all_records = np.array(
                records_list,
                dtype=records_list[0].dtype
                if hasattr(records_list[0], "dtype")
                else None,
            )
            if all_records.dtype.names is None:
                # Fallback: load normally if the iterator didn't return
                # structured records.
                # NOTE(review): the loop above may have consumed `store`;
                # whether to_ndarray() still works here depends on DBNStore
                # semantics — confirm against the databento client docs.
                all_records = store.to_ndarray()[:n_events]
        else:
            all_records = store.to_ndarray()

        # Auto-detect the most active instrument if none was given.
        if instrument_id is None:
            ids, counts = np.unique(all_records["instrument_id"], return_counts=True)
            instrument_id = ids[np.argmax(counts)]
            logger.info("Auto-detected most active instrument: %s", instrument_id)

        logger.info("Filtering for instrument: %s", instrument_id)
        records = all_records[all_records["instrument_id"] == instrument_id]
        n_events = len(records)
        logger.info("Processing %s events...", n_events)

        # Prepare contiguous memory views.
        # Note: .view() on structured arrays creates non-contiguous strided
        # views. np.ascontiguousarray() copies data to a dense buffer, which
        # is safer for C++ interop.
        actions = np.ascontiguousarray(records["action"].view("i1"))
        sides = np.ascontiguousarray(records["side"].view("i1"))
        prices = np.ascontiguousarray(records["price"].astype(np.int64))
        sizes = np.ascontiguousarray(records["size"].astype(np.int64))
        order_ids = np.ascontiguousarray(records["order_id"].astype(np.uint64))
        ts_recvs = np.ascontiguousarray(records["ts_recv"].astype(np.uint64))
        flags = np.ascontiguousarray(records["flags"].view("u1"))

        # NOTE(review): this guard is true in every case EXCEPT when a time is
        # an empty string (falsy but not None); preserved as-is to avoid a
        # behavior change.
        if (
            self.start_time
            or self.end_time
            or (self.start_time is None and self.end_time is None)
        ):
            if len(ts_recvs) > 0:
                # Determine the trading date from the first event.
                # ts_recvs are uint64 nanoseconds since the epoch.
                first_ts_ns = ts_recvs[0]
                first_dt = pd.to_datetime(first_ts_ns, unit="ns", utc=True).tz_convert(
                    self.timezone
                )
                # Start is pinned to beginning of day: simulation always runs
                # from the first event (start_time is used only downstream).
                s_time = "00:00"
                e_time = self._resolve_end_time(first_dt)
                logger.info(
                    "Filtering Time (Loader): Start=%s (Include Pre-Market), End=%s (%s)",
                    s_time,
                    e_time,
                    self.timezone,
                )
                ts_index = pd.to_datetime(ts_recvs, unit="ns", utc=True).tz_convert(
                    self.timezone
                )
                # Positional indices survive between_time() as the values.
                dummy = pd.Series(np.arange(len(ts_index)), index=ts_index)
                valid_indices = dummy.between_time(s_time, e_time).values
                logger.info(" Kept %s/%s events", len(valid_indices), len(ts_index))
                # Apply the time filter to every array in lockstep.
                actions = actions[valid_indices]
                sides = sides[valid_indices]
                prices = prices[valid_indices]
                sizes = sizes[valid_indices]
                order_ids = order_ids[valid_indices]
                ts_recvs = ts_recvs[valid_indices]
                flags = flags[valid_indices]
                n_events = len(actions)
            else:
                logger.warning("No events found to determine date for time filtering.")

        return SimulationData(
            instrument_id=instrument_id,
            n_events=n_events,
            actions=actions,
            sides=sides,
            prices=prices,
            sizes=sizes,
            order_ids=order_ids,
            ts_recvs=ts_recvs,
            flags=flags,
        )