Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions benchmark_utils/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class BaseTSFMSolver(BaseSolver):
dtype
The data type of both data and model.
Default to bfloat16 on CUDA, float32 elsewhere.

inference_batch_size
Max number of (series, cutoff) windows sent to the model per
forward call. Subclasses may override this to tune memory/throughput.
"""

supported_tasks: set[TaskType]
Expand All @@ -77,6 +81,10 @@ def __init__(self, **kwargs: Any) -> None:
# Initialize cached model state
self._loaded_model = None
self.model = None

# Possibly overridden by kwargs:
self.inference_batch_size = 128

for key, value in kwargs.items():
setattr(self, key, value)

Expand Down Expand Up @@ -257,10 +265,12 @@ def forecast(
if not inputs:
return ForecastOutput(quantiles=[], quantile_levels=quantile_levels)

# TODO We still do this in batches in case data is very large

# Get a list of model outputs aligned with inputs
raw = self.forecast_batch(inputs, covariates)
# Run in batches so very large datasets do not go through the model
# as a single oversized batch. Maintains ordering.
raw: list[torch.Tensor] = []
for start in range(0, len(inputs), self.inference_batch_size):
end = start + self.inference_batch_size
raw.extend(self.forecast_batch(inputs[start:end], covariates[start:end]))

per_series_preds = [
[None] * n_cutoffs for _, n_cutoffs in per_series_shape
Expand Down
Loading