From e81f66dbac8a0de466060486f5bd24f340ea3b2e Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Fri, 6 Feb 2026 15:34:56 +0100 Subject: [PATCH 1/2] remove the and stage = train for the if condition for spoofing in _get_data_windows --- src/weathergen/datasets/multi_stream_data_sampler.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 83d436fb9..86cd25193 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -537,10 +537,12 @@ def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, s # TODO: check that we are not out of bounds when we go back in time rdata = collect_datasources(stream_ds, idx, "source", self.rng) - - if rdata.is_empty() and self._stage == TRAIN: + + if rdata.is_empty(): # work around for https://github.com/pytorch/pytorch/issues/158719 # create non-empty mean data instead of empty tensor + # Also needed to ensure input_tokens has same length as input_data + # (get_tokens_windows skips empty data, causing IndexError later) time_win = self.time_window_handler.window(idx) rdata = spoof( self.healpix_level, @@ -560,9 +562,10 @@ def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, s rdata = collect_datasources(stream_ds, step_forecast_dt, "target", self.rng) - if rdata.is_empty() and self._stage == TRAIN: + if rdata.is_empty(): # work around for https://github.com/pytorch/pytorch/issues/158719 # create non-empty mean data instead of empty tensor + # Also needed to ensure output_tokens has same length as output_data time_win = self.time_window_handler.window(timestep_idx) rdata = spoof( self.healpix_level, From d96aff2fd7c5acdc8b5df0776f6d8b55113cefee Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Fri, 6 Feb 2026 15:39:08 +0100 Subject: [PATCH 2/2] lint --- src/weathergen/datasets/multi_stream_data_sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 86cd25193..a5a575de3 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -33,7 +33,7 @@ from weathergen.readers_extra.registry import get_extra_reader from weathergen.train.utils import get_batch_size_from_config from weathergen.utils.distributed import is_root -from weathergen.utils.train_logger import TRAIN, Stage +from weathergen.utils.train_logger import Stage type AnyDataReader = DataReaderBase | DataReaderAnemoi | DataReaderObs type StreamName = str @@ -537,7 +537,7 @@ def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, s # TODO: check that we are not out of bounds when we go back in time rdata = collect_datasources(stream_ds, idx, "source", self.rng) - + if rdata.is_empty(): # work around for https://github.com/pytorch/pytorch/issues/158719 # create non-empty mean data instead of empty tensor