Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/weathergen/datasets/multi_stream_data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from weathergen.readers_extra.registry import get_extra_reader
from weathergen.train.utils import get_batch_size_from_config
from weathergen.utils.distributed import is_root
from weathergen.utils.train_logger import TRAIN, Stage
from weathergen.utils.train_logger import Stage

type AnyDataReader = DataReaderBase | DataReaderAnemoi | DataReaderObs
type StreamName = str
Expand Down Expand Up @@ -538,9 +538,11 @@ def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, s

rdata = collect_datasources(stream_ds, idx, "source", self.rng)

if rdata.is_empty() and self._stage == TRAIN:
if rdata.is_empty():
# Workaround for https://github.com/pytorch/pytorch/issues/158719:
# create non-empty mean data instead of an empty tensor.
# Also needed to ensure input_tokens has the same length as input_data
# (get_tokens_windows skips empty data, causing an IndexError later).
time_win = self.time_window_handler.window(idx)
rdata = spoof(
self.healpix_level,
Expand All @@ -560,9 +562,10 @@ def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, s

rdata = collect_datasources(stream_ds, step_forecast_dt, "target", self.rng)

if rdata.is_empty() and self._stage == TRAIN:
if rdata.is_empty():
# Workaround for https://github.com/pytorch/pytorch/issues/158719:
# create non-empty mean data instead of an empty tensor.
# Also needed to ensure output_tokens has the same length as output_data.
time_win = self.time_window_handler.window(timestep_idx)
rdata = spoof(
self.healpix_level,
Expand Down