diff --git a/benchmark_utils/base_solver.py b/benchmark_utils/base_solver.py index cd82ebd..577aff3 100644 --- a/benchmark_utils/base_solver.py +++ b/benchmark_utils/base_solver.py @@ -53,6 +53,10 @@ class BaseTSFMSolver(BaseSolver): dtype The data type of both data and model. Default to bfloat16 on CUDA, float32 elsewhere. + + inference_batch_size + Max number of (series, cutoff) windows sent to the model per + forward call. Subclasses may override this to tune memory/throughput. """ supported_tasks: set[TaskType] @@ -77,6 +81,10 @@ def __init__(self, **kwargs: Any) -> None: # Initialize cached model state self._loaded_model = None self.model = None + + # Possibly overridden by kwargs: + self.inference_batch_size = 128 + for key, value in kwargs.items(): setattr(self, key, value) @@ -257,10 +265,12 @@ def forecast( if not inputs: return ForecastOutput(quantiles=[], quantile_levels=quantile_levels) - # TODO We still do this in batches in case data is very large - - # Get a list of model outputs aligned with inputs - raw = self.forecast_batch(inputs, covariates) + # Run in batches so very large datasets do not go through the model + # as a single oversized batch. Maintains ordering. + raw: list[torch.Tensor] = [] + for start in range(0, len(inputs), self.inference_batch_size): + end = start + self.inference_batch_size + raw.extend(self.forecast_batch(inputs[start:end], covariates[start:end])) per_series_preds = [ [None] * n_cutoffs for _, n_cutoffs in per_series_shape