diff --git a/the_well/__init__.py b/the_well/__init__.py
index fe54b95b..c7dec85a 100755
--- a/the_well/__init__.py
+++ b/the_well/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "1.2.0"
+__version__ = "1.2.1"
 
 __all__ = ["__version__"]
 
diff --git a/the_well/benchmark/trainer/training.py b/the_well/benchmark/trainer/training.py
index ffb057ba..7d9e0ad7 100755
--- a/the_well/benchmark/trainer/training.py
+++ b/the_well/benchmark/trainer/training.py
@@ -140,6 +140,7 @@ def __init__(
         self.best_val_loss = None
         self.starting_val_loss = float("inf")
         self.dset_metadata = self.datamodule.train_dataset.metadata
+        self.dset_norm = None
         if self.datamodule.train_dataset.use_normalization:
             self.dset_norm = self.datamodule.train_dataset.norm
         if formatter == "channels_first_default":
@@ -176,38 +177,51 @@ def load_checkpoint(self, checkpoint_path: str):
             checkpoint["epoch"] + 1
         )  # Saves after training loop, so start at next epoch
 
-    def normalize(self, batch):
+    def normalize(self, batch_dict=None, direct_tensor=None):
         if hasattr(self, "dset_norm") and self.dset_norm:
-            batch["input_fields"] = self.dset_norm.normalize_flattened(
-                batch["input_fields"], "variable"
-            )
-            if "constant_fields" in batch:
-                batch["constant_fields"] = self.dset_norm.normalize_flattened(
-                    batch["constant_fields"], "constant"
+            if batch_dict is not None:
+                batch_dict["input_fields"] = self.dset_norm.normalize_flattened(
+                    batch_dict["input_fields"], "variable"
                 )
-        return batch
+                if "constant_fields" in batch_dict:
+                    batch_dict["constant_fields"] = self.dset_norm.normalize_flattened(
+                        batch_dict["constant_fields"], "constant"
+                    )
+            if direct_tensor is not None:
+                if self.is_delta:
+                    direct_tensor = self.dset_norm.normalize_delta_flattened(
+                        direct_tensor, "variable"
+                    )
+                else:
+                    direct_tensor = self.dset_norm.normalize_flattened(
+                        direct_tensor, "variable"
+                    )
+        return batch_dict, direct_tensor
 
-    def denormalize(self, batch, prediction):
+    def denormalize(self, batch_dict=None, direct_tensor=None):
         if hasattr(self, "dset_norm") and self.dset_norm:
-            batch["input_fields"] = self.dset_norm.denormalize_flattened(
-                batch["input_fields"], "variable"
-            )
-            if "constant_fields" in batch:
-                batch["constant_fields"] = self.dset_norm.denormalize_flattened(
-                    batch["constant_fields"], "constant"
-                )
-
-            # Delta denormalization is different than full denormalization
-            if self.is_delta:
-                prediction = self.dset_norm.delta_denormalize_flattened(
-                    prediction, "variable"
-                )
-            else:
-                prediction = self.dset_norm.denormalize_flattened(
-                    prediction, "variable"
+            if batch_dict is not None:
+                batch_dict["input_fields"] = self.dset_norm.denormalize_flattened(
+                    batch_dict["input_fields"], "variable"
                 )
+                if "constant_fields" in batch_dict:
+                    batch_dict["constant_fields"] = (
+                        self.dset_norm.denormalize_flattened(
+                            batch_dict["constant_fields"], "constant"
+                        )
+                    )
+            if direct_tensor is not None:
+                # Delta denormalization is different than full denormalization
+                if self.is_delta:
+                    direct_tensor = self.dset_norm.delta_denormalize_flattened(
+                        direct_tensor, "variable"
+                    )
+                else:
+                    direct_tensor = self.dset_norm.denormalize_flattened(
+                        direct_tensor, "variable"
+                    )
 
-        return batch, prediction
+        return batch_dict, direct_tensor
 
     def rollout_model(self, model, batch, formatter, train=True):
         """Rollout the model for as many steps as we have data for."""
@@ -216,8 +230,12 @@ def rollout_model(
             y_ref.shape[1], self.max_rollout_steps
         )  # Number of timesteps in target
         y_ref = y_ref[:, :rollout_steps]
+        # NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM).
+        if not train:
+            _, y_ref = self.denormalize(None, y_ref)
+
         # Create a moving batch of one step at a time
-        moving_batch = batch
+        moving_batch = dict(batch)
         moving_batch["input_fields"] = moving_batch["input_fields"].to(self.device)
         if "constant_fields" in moving_batch:
             moving_batch["constant_fields"] = moving_batch["constant_fields"].to(
@@ -225,22 +243,24 @@
             )
         y_preds = []
         for i in range(rollout_steps):
-            if not train:
-                moving_batch = self.normalize(moving_batch)
+            # NOTE: This is a quick fix so we can make datamodule behavior consistent.
+            # Including local normalization schemes means there needs to be the option of normalizing each step
+            # and there's currently not a registry of local vs global normalization schemes.
+            if not train and self.datamodule.val_dataset.use_normalization and i > 0:
+                moving_batch, _ = self.normalize(moving_batch)
 
             inputs, _ = formatter.process_input(moving_batch)
             inputs = [x.to(self.device) for x in inputs]
             y_pred = model(*inputs)
             y_pred = formatter.process_output_channel_last(y_pred)
 
-            if not train:
             moving_batch, y_pred = self.denormalize(moving_batch, y_pred)
             if (not train) and self.is_delta:
-                assert {
+                assert (
                     moving_batch["input_fields"][:, -1, ...].shape == y_pred.shape
-                }, f"Mismatching shapes between last input timestep {moving_batch[:, -1, ...].shape}\
+                ), f"Mismatching shapes between last input timestep {moving_batch['input_fields'][:, -1, ...].shape}\
                     and prediction {y_pred.shape}"
                 y_pred = moving_batch["input_fields"][:, -1, ...] + y_pred
 
             y_pred = formatter.process_output_expand_time(y_pred)
 
diff --git a/the_well/data/datamodule.py b/the_well/data/datamodule.py
index 3e037ec2..f3660a55 100755
--- a/the_well/data/datamodule.py
+++ b/the_well/data/datamodule.py
@@ -163,6 +163,8 @@ def __init__(
             well_split_name="valid",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
             storage_options=storage_kwargs,
@@ -181,6 +183,8 @@ def __init__(
             well_split_name="valid",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             max_rollout_steps=max_rollout_steps,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
@@ -201,6 +205,8 @@ def __init__(
             well_split_name="test",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
             storage_options=storage_kwargs,
@@ -219,6 +225,8 @@ def __init__(
             well_split_name="test",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             max_rollout_steps=max_rollout_steps,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,