From e5be74179d7fda6c4046c1a49a18ca608dfd9ad1 Mon Sep 17 00:00:00 2001 From: Felix Divo Date: Sat, 13 Jun 2026 21:20:18 +0000 Subject: [PATCH 1/2] FEAT Forecasting: batch large datasets in BaseTSFMSolver.forecast Split the flattened (series, cutoff) windows into chunks of inference_batch_size before calling forecast_batch, instead of sending the whole dataset as one oversized batch. Reconstruction is unchanged since the per-batch outputs stay aligned with the flat input list. Closes #43 Co-Authored-By: Claude Opus 4.8 --- benchmark_utils/base_solver.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/benchmark_utils/base_solver.py b/benchmark_utils/base_solver.py index cd82ebd..3801228 100644 --- a/benchmark_utils/base_solver.py +++ b/benchmark_utils/base_solver.py @@ -53,6 +53,10 @@ class BaseTSFMSolver(BaseSolver): dtype The data type of both data and model. Default to bfloat16 on CUDA, float32 elsewhere. + + inference_batch_size + Max number of (series, cutoff) windows sent to the model per + forecast_batch call. Subclasses may override to tune memory/throughput. """ supported_tasks: set[TaskType] @@ -77,6 +81,9 @@ def __init__(self, **kwargs: Any) -> None: # Initialize cached model state self._loaded_model = None self.model = None + # Max number of (series, cutoff) windows sent to the model per call. + # Subclasses may override to tune the memory/throughput trade-off. + self.inference_batch_size = 128 for key, value in kwargs.items(): setattr(self, key, value) @@ -257,10 +264,13 @@ def forecast( if not inputs: return ForecastOutput(quantiles=[], quantile_levels=quantile_levels) - # TODO We still do this in batches in case data is very large - - # Get a list of model outputs aligned with inputs - raw = self.forecast_batch(inputs, covariates) + # Run in batches so very large datasets do not go through the model + # as a single oversized batch. raw stays a flat list aligned with + # inputs, so the reconstruction below is unaffected. + raw: list[torch.Tensor] = [] + for start in range(0, len(inputs), self.inference_batch_size): + end = start + self.inference_batch_size + raw.extend(self.forecast_batch(inputs[start:end], covariates[start:end])) per_series_preds = [ [None] * n_cutoffs for _, n_cutoffs in per_series_shape From af6e0f0630d29d139a864d44b049b89f63e99bef Mon Sep 17 00:00:00 2001 From: Felix Divo Date: Sat, 13 Jun 2026 21:24:18 +0000 Subject: [PATCH 2/2] Cleanup after AI --- benchmark_utils/base_solver.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/benchmark_utils/base_solver.py b/benchmark_utils/base_solver.py index 3801228..577aff3 100644 --- a/benchmark_utils/base_solver.py +++ b/benchmark_utils/base_solver.py @@ -56,7 +56,7 @@ class BaseTSFMSolver(BaseSolver): inference_batch_size Max number of (series, cutoff) windows sent to the model per - forecast_batch call. Subclasses may override to tune memory/throughput. + forward call. Subclasses may override this to tune memory/throughput. """ supported_tasks: set[TaskType] @@ -81,9 +81,10 @@ def __init__(self, **kwargs: Any) -> None: # Initialize cached model state self._loaded_model = None self.model = None - # Max number of (series, cutoff) windows sent to the model per call. - # Subclasses may override to tune the memory/throughput trade-off. + + # Possibly overridden by kwargs: self.inference_batch_size = 128 + for key, value in kwargs.items(): setattr(self, key, value) @@ -265,8 +266,7 @@ def forecast( return ForecastOutput(quantiles=[], quantile_levels=quantile_levels) # Run in batches so very large datasets do not go through the model - # as a single oversized batch. raw stays a flat list aligned with - # inputs, so the reconstruction below is unaffected. + # as a single oversized batch. Maintains ordering. raw: list[torch.Tensor] = [] for start in range(0, len(inputs), self.inference_batch_size): end = start + self.inference_batch_size