From 715226c9ecc4d996188cb28fe0397114d7df997f Mon Sep 17 00:00:00 2001 From: Mike McCabe Date: Wed, 10 Dec 2025 13:53:13 -0500 Subject: [PATCH 1/6] Move normalization into dataloader - perform denormalization in file --- the_well/benchmark/trainer/training.py | 57 +++++++++++++++----------- the_well/data/datamodule.py | 8 ++++ 2 files changed, 42 insertions(+), 23 deletions(-) diff --git a/the_well/benchmark/trainer/training.py b/the_well/benchmark/trainer/training.py index ffb057ba..ba5587bc 100755 --- a/the_well/benchmark/trainer/training.py +++ b/the_well/benchmark/trainer/training.py @@ -176,38 +176,45 @@ def load_checkpoint(self, checkpoint_path: str): checkpoint["epoch"] + 1 ) # Saves after training loop, so start at next epoch - def normalize(self, batch): + def normalize(self, batch_dict=None, direct_tensor=None): if hasattr(self, "dset_norm") and self.dset_norm: - batch["input_fields"] = self.dset_norm.normalize_flattened( - batch["input_fields"], "variable" - ) - if "constant_fields" in batch: - batch["constant_fields"] = self.dset_norm.normalize_flattened( - batch["constant_fields"], "constant" + if batch_dict is not None: + batch_dict["input_fields"] = self.dset_norm.normalize_flattened( + batch_dict["input_fields"], "variable" + ) + if "constant_fields" in batch_dict: + batch_dict["constant_fields"] = self.dset_norm.normalize_flattened( + batch_dict["constant_fields"], "constant" + ) + if direct_tensor is not None: + direct_tensor = self.dset_norm.normalize_flattened( + direct_tensor, "variable" ) - return batch + return batch_dict, direct_tensor + return batch_dict, direct_tensor - def denormalize(self, batch, prediction): + def denormalize(self, batch_dict=None, direct_tensor=None): if hasattr(self, "dset_norm") and self.dset_norm: - batch["input_fields"] = self.dset_norm.denormalize_flattened( - batch["input_fields"], "variable" - ) - if "constant_fields" in batch: - batch["constant_fields"] = self.dset_norm.denormalize_flattened( - batch["constant_fields"], "constant" + if batch_dict is not None: + batch_dict["input_fields"] = self.dset_norm.denormalize_flattened( + batch_dict["input_fields"], "variable" ) + if "constant_fields" in batch_dict: + batch_dict["constant_fields"] = self.dset_norm.denormalize_flattened( + batch_dict["constant_fields"], "constant" + ) # Delta denormalization is different than full denormalization if self.is_delta: - prediction = self.dset_norm.delta_denormalize_flattened( - prediction, "variable" + direct_tensor = self.dset_norm.delta_denormalize_flattened( + direct_tensor, "variable" ) else: - prediction = self.dset_norm.denormalize_flattened( - prediction, "variable" + direct_tensor = self.dset_norm.denormalize_flattened( + direct_tensor, "variable" ) - return batch, prediction + return batch_dict, direct_tensor def rollout_model(self, model, batch, formatter, train=True): """Rollout the model for as many steps as we have data for.""" @@ -216,6 +223,10 @@ def rollout_model(self, model, batch, formatter, train=True): y_ref.shape[1], self.max_rollout_steps ) # Number of timesteps in target y_ref = y_ref[:, :rollout_steps] + # NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM). + if not train: + _, y_ref = self.denormalize(None, y_ref) + # Create a moving batch of one step at a time moving_batch = batch moving_batch["input_fields"] = moving_batch["input_fields"].to(self.device) @@ -225,15 +236,15 @@ def rollout_model(self, model, batch, formatter, train=True): ) y_preds = [] for i in range(rollout_steps): - if not train: - moving_batch = self.normalize(moving_batch) + # NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM). + if i > 0 and not train: + moving_batch, _ = self.normalize(moving_batch) inputs, _ = formatter.process_input(moving_batch) inputs = [x.to(self.device) for x in inputs] y_pred = model(*inputs) y_pred = formatter.process_output_channel_last(y_pred) - if not train: moving_batch, y_pred = self.denormalize(moving_batch, y_pred) diff --git a/the_well/data/datamodule.py b/the_well/data/datamodule.py index 3e037ec2..f3660a55 100755 --- a/the_well/data/datamodule.py +++ b/the_well/data/datamodule.py @@ -163,6 +163,8 @@ def __init__( well_split_name="valid", include_filters=include_filters, exclude_filters=exclude_filters, + use_normalization=use_normalization, + normalization_type=normalization_type, n_steps_input=n_steps_input, n_steps_output=n_steps_output, storage_options=storage_kwargs, @@ -181,6 +183,8 @@ def __init__( well_split_name="valid", include_filters=include_filters, exclude_filters=exclude_filters, + use_normalization=use_normalization, + normalization_type=normalization_type, max_rollout_steps=max_rollout_steps, n_steps_input=n_steps_input, n_steps_output=n_steps_output, @@ -201,6 +205,8 @@ def __init__( well_split_name="test", include_filters=include_filters, exclude_filters=exclude_filters, + use_normalization=use_normalization, + normalization_type=normalization_type, n_steps_input=n_steps_input, n_steps_output=n_steps_output, storage_options=storage_kwargs, @@ -219,6 +225,8 @@ def __init__( well_split_name="test", include_filters=include_filters, exclude_filters=exclude_filters, + use_normalization=use_normalization, + normalization_type=normalization_type, max_rollout_steps=max_rollout_steps, n_steps_input=n_steps_input, n_steps_output=n_steps_output, From 1f626ac97622f93b593a40b8bfbedb62a400b693 Mon Sep 17 00:00:00 2001 From: Mike McCabe Date: Wed, 10 Dec 2025 14:02:09 -0500 Subject: [PATCH 2/6] Increment hotfix version --- the_well/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/the_well/__init__.py b/the_well/__init__.py index fe54b95b..c7dec85a 100755 --- a/the_well/__init__.py +++ b/the_well/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.2.0" +__version__ = "1.2.1" __all__ = ["__version__"] From 5e22beec296ef99447bcc901f452d66cf591ec1b Mon Sep 17 00:00:00 2001 From: Mike McCabe Date: Wed, 10 Dec 2025 14:07:15 -0500 Subject: [PATCH 3/6] Finish synchronizing the norm/denorm funcs --- the_well/benchmark/trainer/training.py | 32 +++++++++++++++----------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/the_well/benchmark/trainer/training.py b/the_well/benchmark/trainer/training.py index ba5587bc..11c86543 100755 --- a/the_well/benchmark/trainer/training.py +++ b/the_well/benchmark/trainer/training.py @@ -187,10 +187,14 @@ def normalize(self, batch_dict=None, direct_tensor=None): batch_dict["constant_fields"], "constant" ) if direct_tensor is not None: - direct_tensor = self.dset_norm.normalize_flattened( - direct_tensor, "variable" - ) - return batch_dict, direct_tensor + if self.is_delta: + direct_tensor = self.dset_norm.normalize_delta_flattened( + direct_tensor, "variable" + ) + else: + direct_tensor = self.dset_norm.normalize_flattened( + direct_tensor, "variable" + ) return batch_dict, direct_tensor def denormalize(self, batch_dict=None, direct_tensor=None): @@ -203,16 +207,16 @@ def denormalize(self, batch_dict=None, direct_tensor=None): batch_dict["constant_fields"] = self.dset_norm.denormalize_flattened( batch_dict["constant_fields"], "constant" ) - - # Delta denormalization is different than full denormalization - if self.is_delta: - direct_tensor = self.dset_norm.delta_denormalize_flattened( - direct_tensor, "variable" - ) - else: - direct_tensor = self.dset_norm.denormalize_flattened( - direct_tensor, "variable" - ) + if direct_tensor is not None: + # Delta denormalization is different than full denormalization + if self.is_delta: + direct_tensor = self.dset_norm.delta_denormalize_flattened( + direct_tensor, "variable" + ) + else: + direct_tensor = self.dset_norm.denormalize_flattened( + direct_tensor, "variable" + ) return batch_dict, direct_tensor From 1540ae73d75a8384b1dfbe9ae16edc82e28063b0 Mon Sep 17 00:00:00 2001 From: Mike McCabe Date: Wed, 10 Dec 2025 14:22:05 -0500 Subject: [PATCH 4/6] linter --- the_well/benchmark/trainer/training.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/the_well/benchmark/trainer/training.py b/the_well/benchmark/trainer/training.py index 11c86543..8509c1ab 100755 --- a/the_well/benchmark/trainer/training.py +++ b/the_well/benchmark/trainer/training.py @@ -204,8 +204,10 @@ def denormalize(self, batch_dict=None, direct_tensor=None): batch_dict["input_fields"], "variable" ) if "constant_fields" in batch_dict: - batch_dict["constant_fields"] = self.dset_norm.denormalize_flattened( - batch_dict["constant_fields"], "constant" + batch_dict["constant_fields"] = ( + self.dset_norm.denormalize_flattened( + batch_dict["constant_fields"], "constant" + ) ) if direct_tensor is not None: # Delta denormalization is different than full denormalization @@ -241,7 +243,7 @@ def rollout_model(self, model, batch, formatter, train=True): y_preds = [] for i in range(rollout_steps): # NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM). - if i > 0 and not train: + if i > 0 and not train: moving_batch, _ = self.normalize(moving_batch) inputs, _ = formatter.process_input(moving_batch) From c674b86b747748aab809fde939c25825628e8708 Mon Sep 17 00:00:00 2001 From: Mike McCabe Date: Tue, 24 Feb 2026 19:07:31 -0500 Subject: [PATCH 5/6] Add some more explicit guardrails --- the_well/benchmark/trainer/training.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/the_well/benchmark/trainer/training.py b/the_well/benchmark/trainer/training.py index 8509c1ab..459ef594 100755 --- a/the_well/benchmark/trainer/training.py +++ b/the_well/benchmark/trainer/training.py @@ -140,6 +140,7 @@ def __init__( self.best_val_loss = None self.starting_val_loss = float("inf") self.dset_metadata = self.datamodule.train_dataset.metadata + self.dset_norm = None if self.datamodule.train_dataset.use_normalization: self.dset_norm = self.datamodule.train_dataset.norm if formatter == "channels_first_default": @@ -234,7 +235,7 @@ def rollout_model(self, model, batch, formatter, train=True): _, y_ref = self.denormalize(None, y_ref) # Create a moving batch of one step at a time - moving_batch = batch + moving_batch = dict(batch) moving_batch["input_fields"] = moving_batch["input_fields"].to(self.device) if "constant_fields" in moving_batch: moving_batch["constant_fields"] = moving_batch["constant_fields"].to( @@ -242,8 +243,10 @@ def rollout_model(self, model, batch, formatter, train=True): ) y_preds = [] for i in range(rollout_steps): - # NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM). - if i > 0 and not train: + # NOTE: This is a quick fix so we can make datamodule behavior consistent. + # Including local normalization schemes means there needs to be the option of normalizing each step + # and there's currently not a registry of local vs global normalization schemes. + if self.datamodule.val_dataset.use_normalization and i > 0 and not train: moving_batch, _ = self.normalize(moving_batch) inputs, _ = formatter.process_input(moving_batch) @@ -255,9 +258,9 @@ def rollout_model(self, model, batch, formatter, train=True): moving_batch, y_pred = self.denormalize(moving_batch, y_pred) if (not train) and self.is_delta: - assert { + assert ( moving_batch["input_fields"][:, -1, ...].shape == y_pred.shape - }, f"Mismatching shapes between last input timestep {moving_batch[:, -1, ...].shape}\ + ), f"Mismatching shapes between last input timestep {moving_batch[:, -1, ...].shape}\ and prediction {y_pred.shape}" y_pred = moving_batch["input_fields"][:, -1, ...] + y_pred y_pred = formatter.process_output_expand_time(y_pred) From 477565879adf766dbf82ba36032c15a6bed35452 Mon Sep 17 00:00:00 2001 From: Mike McCabe Date: Tue, 24 Feb 2026 19:17:22 -0500 Subject: [PATCH 6/6] Minor order update to prevent error in debugging scenario linter --- the_well/benchmark/trainer/training.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/the_well/benchmark/trainer/training.py b/the_well/benchmark/trainer/training.py index 459ef594..7d9e0ad7 100755 --- a/the_well/benchmark/trainer/training.py +++ b/the_well/benchmark/trainer/training.py @@ -243,10 +243,10 @@ def rollout_model(self, model, batch, formatter, train=True): ) y_preds = [] for i in range(rollout_steps): - # NOTE: This is a quick fix so we can make datamodule behavior consistent. + # NOTE: This is a quick fix so we can make datamodule behavior consistent. # Including local normalization schemes means there needs to be the option of normalizing each step - # and there's currently not a registry of local vs global normalization schemes. - if self.datamodule.val_dataset.use_normalization and i > 0 and not train: + # and there's currently not a registry of local vs global normalization schemes. + if not train and self.datamodule.val_dataset.use_normalization and i > 0: moving_batch, _ = self.normalize(moving_batch) inputs, _ = formatter.process_input(moving_batch)