diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 71a133da9..90c20fd04 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -33,7 +33,7 @@ from weathergen.readers_extra.registry import get_extra_reader from weathergen.train.utils import get_batch_size_from_config from weathergen.utils.distributed import is_root -from weathergen.utils.train_logger import TRAIN, Stage +from weathergen.utils.train_logger import Stage type AnyDataReader = DataReaderBase | DataReaderAnemoi | DataReaderObs type StreamName = str @@ -538,9 +538,11 @@ def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, s rdata = collect_datasources(stream_ds, idx, "source", self.rng) - if rdata.is_empty() and self._stage == TRAIN: + if rdata.is_empty(): # work around for https://github.com/pytorch/pytorch/issues/158719 # create non-empty mean data instead of empty tensor + # Also needed to ensure input_tokens has same length as input_data + # (get_tokens_windows skips empty data, causing IndexError later) time_win = self.time_window_handler.window(idx) rdata = spoof( self.healpix_level, @@ -560,9 +562,10 @@ def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, s rdata = collect_datasources(stream_ds, step_forecast_dt, "target", self.rng) - if rdata.is_empty() and self._stage == TRAIN: + if rdata.is_empty(): # work around for https://github.com/pytorch/pytorch/issues/158719 # create non-empty mean data instead of empty tensor + # Also needed to ensure output_tokens has same length as output_data time_win = self.time_window_handler.window(timestep_idx) rdata = spoof( self.healpix_level,