From 11d1de0d0ab6cf5a732703b79cbd0eaf48dfbed9 Mon Sep 17 00:00:00 2001 From: LimitingFactor Date: Sat, 23 Sep 2023 14:25:24 +0100 Subject: [PATCH 1/4] improved the inference plotting, fixed the checkpoint loading --- .dockerignore | 3 + examples/cfd/vortex_shedding_mgn/constants.py | 44 ++- examples/cfd/vortex_shedding_mgn/inference.py | 195 ++++++++--- examples/cfd/vortex_shedding_mgn/train.py | 38 ++- .../cfd/vortex_shedding_mgn/wandb_train.py | 312 ++++++++++++++++++ modulus/launch/logging/wandb.py | 4 +- modulus/launch/utils/checkpoint.py | 3 +- setup.py | 4 +- 8 files changed, 539 insertions(+), 64 deletions(-) create mode 100644 .dockerignore create mode 100644 examples/cfd/vortex_shedding_mgn/wandb_train.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..99fd87a --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +examples/cfd/vortex_shedding_mgn/checkpoints* +examples/cfd/vortex_shedding_mgn/raw_dataset* +examples/cfd/vortex_shedding_mgn/wandb* \ No newline at end of file diff --git a/examples/cfd/vortex_shedding_mgn/constants.py b/examples/cfd/vortex_shedding_mgn/constants.py index 0dc81ee..b8005a3 100644 --- a/examples/cfd/vortex_shedding_mgn/constants.py +++ b/examples/cfd/vortex_shedding_mgn/constants.py @@ -21,33 +21,57 @@ class Constants(BaseModel): """vortex shedding constants""" + # Model name + model_name: str = "test_2" + # data configs - data_dir: str = "./raw_dataset/cylinder_flow/cylinder_flow" + data_dir: str = "/home/swifta/modulus/datasets/cylinder_flow/cylinder_flow" # training configs batch_size: int = 1 epochs: int = 25 - num_training_samples: int = 400 - num_training_time_steps: int = 300 + num_training_samples: int = 1000 + num_training_time_steps: int = 600 + training_noise_std: float = 0.02 + + num_valid_samples: int = 100 + num_valid_time_steps: int = 600 + lr: float = 0.0001 lr_decay_rate: float = 0.9999991 + ckpt_path: str = "checkpoints_test_3" + ckpt_name: str = "test_3.pt" + + # Mesh Graph Net Setup num_input_features: int = 6 - num_output_features: int = 3 num_edge_features: int = 3 - ckpt_path: str = "checkpoints" - ckpt_name: str = "model.pt" + num_output_features: int = 3 + processor_size: int = 15 + num_layers_node_processor: int = 2 + num_layers_edge_processor: int = 2 + hidden_dim_processor: int = 128 + hidden_dim_node_encoder: int = 128 + num_layers_node_encoder: int = 2 + hidden_dim_edge_encoder: int = 128 + num_layers_edge_encoder: int = 2 + hidden_dim_node_decoder: int = 128 + num_layers_node_decoder: int = 2 + aggregation: str = "sum" + do_concat_trick: bool = False + num_processor_checkpoint_segments: int = 0 + activation_fn: str = "silu" # performance configs amp: bool = False jit: bool = False # test & visualization configs - num_test_samples: int = 10 - num_test_time_steps: int = 300 + num_test_samples: int = 100 + num_test_time_steps: int = 600 viz_vars: Tuple[str, ...] = ("u", "v", "p") frame_skip: int = 10 frame_interval: int = 1 # wb configs - wandb_mode: str = "disabled" - watch_model: bool = False + wandb_mode: str = "online" + watch_model: bool = True diff --git a/examples/cfd/vortex_shedding_mgn/inference.py b/examples/cfd/vortex_shedding_mgn/inference.py index 5373d80..8ee4cb9 100644 --- a/examples/cfd/vortex_shedding_mgn/inference.py +++ b/examples/cfd/vortex_shedding_mgn/inference.py @@ -12,28 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch, dgl -from dgl.dataloading import GraphDataLoader -import torch +import os + import matplotlib.pyplot as plt -import numpy as np +import torch +from dgl.dataloading import GraphDataLoader from matplotlib import animation from matplotlib import tri as mtri -import os from matplotlib.patches import Rectangle - -from modulus.models.meshgraphnet import MeshGraphNet from modulus.datapipes.gnn.vortex_shedding_dataset import VortexSheddingDataset +from modulus.models.meshgraphnet import MeshGraphNet + +from constants import Constants from modulus.launch.logging import PythonLogger from modulus.launch.utils import load_checkpoint -from constants import Constants # Instantiate constants C = Constants() class MGNRollout: - def __init__(self, logger): + def __init__(self, logger, config): + self.config = config # set device self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Using {self.device} device") @@ -41,10 +41,10 @@ def __init__(self, logger): # instantiate dataset self.dataset = VortexSheddingDataset( name="vortex_shedding_test", - data_dir=C.data_dir, + data_dir=config.data_dir, split="test", - num_samples=C.num_test_samples, - num_steps=C.num_test_time_steps, + num_samples=config.num_test_samples, + num_steps=config.num_test_time_steps, ) # instantiate dataloader @@ -57,19 +57,38 @@ def __init__(self, logger): # instantiate the model self.model = MeshGraphNet( - C.num_input_features, C.num_edge_features, C.num_output_features + input_dim_nodes=config.num_input_features, + input_dim_edges=config.num_edge_features, + output_dim=config.num_output_features, + processor_size=config.processor_size, + num_layers_node_processor=config.num_layers_node_processor, + num_layers_edge_processor=config.num_layers_edge_processor, + hidden_dim_processor=config.hidden_dim_processor, + hidden_dim_node_encoder=config.hidden_dim_node_encoder, + num_layers_node_encoder=config.num_layers_node_encoder, + hidden_dim_edge_encoder=config.hidden_dim_edge_encoder, + num_layers_edge_encoder=config.num_layers_edge_encoder, + hidden_dim_node_decoder=config.hidden_dim_node_decoder, + num_layers_node_decoder=config.num_layers_node_decoder, + aggregation=config.aggregation, + do_concat_trick=config.do_concat_trick, + num_processor_checkpoint_segments=config.num_processor_checkpoint_segments, + activation_fn=config.activation_fn, ) - if C.jit: + if config.jit: self.model = torch.jit.script(self.model).to(self.device) else: self.model = self.model.to(self.device) + # instantiate loss + self.criterion = torch.nn.MSELoss() + # enable train mode self.model.eval() # load checkpoint _ = load_checkpoint( - os.path.join(C.ckpt_path, C.ckpt_name), + path=os.path.join(config.ckpt_path, config.ckpt_name), models=self.model, device=self.device, ) @@ -77,12 +96,13 @@ def __init__(self, logger): self.var_identifier = {"u": 0, "v": 1, "p": 2} def predict(self): - self.pred, self.exact, self.faces, self.graphs = [], [], [], [] + self.pred, self.exact, self.faces, self.graphs, self.loss = [], [], [], [], [] stats = { key: value.to(self.device) for key, value in self.dataset.node_stats.items() } for i, (graph, cells, mask) in enumerate(self.dataloader): graph = graph.to(self.device) + # denormalize data graph.ndata["x"][:, 0:2] = self.dataset.denormalize( graph.ndata["x"][:, 0:2], stats["velocity_mean"], stats["velocity_std"] @@ -101,12 +121,14 @@ def predict(self): # inference step invar = graph.ndata["x"].clone() - if i % (C.num_test_time_steps - 1) != 0: + if i % (self.config.num_test_time_steps - 1) != 
0: invar[:, 0:2] = self.pred[i - 1][:, 0:2].clone() i += 1 invar[:, 0:2] = self.dataset.normalize_node( invar[:, 0:2], stats["velocity_mean"], stats["velocity_std"] ) + + # Get the predition pred_i = self.model(invar, graph.edata["x"], graph).detach() # predict # denormalize prediction @@ -116,97 +138,182 @@ def predict(self): pred_i[:, 2] = self.dataset.denormalize( pred_i[:, 2], stats["pressure_mean"], stats["pressure_std"] ) + + loss = self.criterion(pred_i, graph.ndata["y"]) + self.loss.append(loss.cpu().detach()) + invar[:, 0:2] = self.dataset.denormalize( invar[:, 0:2], stats["velocity_mean"], stats["velocity_std"] ) # do not update the "wall_boundary" & "outflow" nodes mask = torch.cat((mask, mask), dim=-1).to(self.device) - pred_i[:, 0:2] = torch.where( - mask, pred_i[:, 0:2], torch.zeros_like(pred_i[:, 0:2]) - ) + pred_i[:, 0:2] = torch.where(mask, pred_i[:, 0:2], torch.zeros_like(pred_i[:, 0:2])) # integration - self.pred.append( - torch.cat( - ((pred_i[:, 0:2] + invar[:, 0:2]), pred_i[:, [2]]), dim=-1 - ).cpu() - ) + self.pred.append(torch.cat(((pred_i[:, 0:2] + invar[:, 0:2]), pred_i[:, [2]]), dim=-1).cpu()) self.exact.append( - torch.cat( - ( - (graph.ndata["y"][:, 0:2] + graph.ndata["x"][:, 0:2]), - graph.ndata["y"][:, [2]], - ), - dim=-1, - ).cpu() - ) + torch.cat(((graph.ndata["y"][:, 0:2] + graph.ndata["x"][:, 0:2]), graph.ndata["y"][:, [2]],), + dim=-1, ).cpu()) self.faces.append(torch.squeeze(cells).numpy()) self.graphs.append(graph.cpu()) def init_animation(self, idx): + self.animation_variable = C.viz_vars[idx] self.pred_i = [var[:, idx] for var in self.pred] self.exact_i = [var[:, idx] for var in self.exact] # fig configs plt.rcParams["image.cmap"] = "inferno" - self.fig, self.ax = plt.subplots(2, 1, figsize=(16, 9)) + self.fig, self.ax = plt.subplots(3, 1, figsize=(16, (9 / 2) * 3)) # Set background color to black self.fig.set_facecolor("black") self.ax[0].set_facecolor("black") self.ax[1].set_facecolor("black") + self.ax[2].set_facecolor("black") + self.first_call = True # make animations dir if not os.path.exists("./animations"): os.makedirs("./animations") def animate(self, num): - num *= C.frame_skip + if self.animation_variable == "u": + min_var = -1.0 + max_var = 4.5 + min_delta_var = -0.25 + max_delta_var = 0.25 + elif self.animation_variable == "v": + min_var = -2.0 + max_var = 2.0 + min_delta_var = -0.25 + max_delta_var = 0.25 + elif self.animation_variable == "p": + min_var = -6.0 + max_var = 6.0 + min_delta_var = -0.25 + max_delta_var = 0.25 + + num *= self.config.frame_skip graph = self.graphs[num] y_star = self.pred_i[num].numpy() y_exact = self.exact_i[num].numpy() + y_error = y_star - y_exact triang = mtri.Triangulation( graph.ndata["mesh_pos"][:, 0].numpy(), graph.ndata["mesh_pos"][:, 1].numpy(), self.faces[num], ) + + # Prediction plotting self.ax[0].cla() self.ax[0].set_aspect("equal") self.ax[0].set_axis_off() navy_box = Rectangle((0, 0), 1.4, 0.4, facecolor="navy") self.ax[0].add_patch(navy_box) # Add a navy box to the first subplot - self.ax[0].tripcolor(triang, y_star, vmin=np.min(y_star), vmax=np.max(y_star)) + ans = self.ax[0].tripcolor(triang, y_star, vmin=min_var, vmax=max_var) self.ax[0].triplot(triang, "ko-", ms=0.5, lw=0.3) self.ax[0].set_title("Modulus MeshGraphNet Prediction", color="white") + if num == 0 and self.first_call: + cb_ax = self.fig.add_axes([.9525, .69, .01, .26]) + cb = self.fig.colorbar(ans, orientation='vertical', cax=cb_ax) + # cb = self.fig.colorbar(ans, ax=self.ax[0], location="right") + # COLORBAR + # set colorbar 
label plus label color + cb.set_label(self.animation_variable, color="white") + + # set colorbar tick color + cb.ax.yaxis.set_tick_params(color="white") + + # set colorbar edgecolor + cb.outline.set_edgecolor("white") + + # set colorbar ticklabels + plt.setp(plt.getp(cb.ax.axes, 'yticklabels'), color="white") + + # Truth plotting self.ax[1].cla() self.ax[1].set_aspect("equal") self.ax[1].set_axis_off() navy_box = Rectangle((0, 0), 1.4, 0.4, facecolor="navy") self.ax[1].add_patch(navy_box) # Add a navy box to the second subplot - self.ax[1].tripcolor( - triang, y_exact, vmin=np.min(y_exact), vmax=np.max(y_exact) - ) + ans = self.ax[1].tripcolor(triang, y_exact, vmin=min_var, vmax=max_var) self.ax[1].triplot(triang, "ko-", ms=0.5, lw=0.3) self.ax[1].set_title("Ground Truth", color="white") + if num == 0 and self.first_call: + cb_ax = self.fig.add_axes([.9525, .37, .01, .26]) + cb = self.fig.colorbar(ans, orientation='vertical', cax=cb_ax) + # cb = self.fig.colorbar(ans, ax=self.ax[0], location="right") + # COLORBAR + # set colorbar label plus label color + cb.set_label(self.animation_variable, color="white") + + # set colorbar tick color + cb.ax.yaxis.set_tick_params(color="white") + + # set colorbar edgecolor + cb.outline.set_edgecolor("white") + + # set colorbar ticklabels + plt.setp(plt.getp(cb.ax.axes, 'yticklabels'), color="white") + + # Error plotting + self.ax[2].cla() + self.ax[2].set_aspect("equal") + self.ax[2].set_axis_off() + navy_box = Rectangle((0, 0), 1.4, 0.4, facecolor="navy") + self.ax[2].add_patch(navy_box) # Add a navy box to the second subplot + ans = self.ax[2].tripcolor(triang, y_error, vmin=min_delta_var, vmax=max_delta_var, cmap="coolwarm") + self.ax[2].triplot(triang, "ko-", ms=0.5, lw=0.3) + self.ax[2].set_title("Absolute Error (Prediction - Ground Truth)", color="white") + if num == 0 and self.first_call: + cb_ax = self.fig.add_axes([.9525, .055, .01, .26]) + cb = self.fig.colorbar(ans, orientation='vertical', cax=cb_ax) + # cb = self.fig.colorbar(ans, ax=self.ax[0], location="right") + # COLORBAR + # set colorbar label plus label color + cb.set_label(self.animation_variable, color="white") + + # set colorbar tick color + cb.ax.yaxis.set_tick_params(color="white") + + # set colorbar edgecolor + cb.outline.set_edgecolor("white") + + # set colorbar ticklabels + plt.setp(plt.getp(cb.ax.axes, 'yticklabels'), color="white") # Adjust subplots to minimize empty space self.ax[0].set_aspect("auto", adjustable="box") - self.ax[1].set_aspect("auto", adjustable="box") self.ax[0].autoscale(enable=True, tight=True) + + self.ax[1].set_aspect("auto", adjustable="box") self.ax[1].autoscale(enable=True, tight=True) + + self.ax[2].set_aspect("auto", adjustable="box") + self.ax[2].autoscale(enable=True, tight=True) + self.fig.subplots_adjust( left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.1, hspace=0.2 ) return self.fig +def setup_config(wandb_config={}): + constant = Constants(**wandb_config) + + return constant + + if __name__ == "__main__": + C = setup_config() + logger = PythonLogger("main") # General python logger logger.file_logging() logger.info("Rollout started...") - rollout = MGNRollout(logger) + rollout = MGNRollout(logger, config=C) idx = [rollout.var_identifier[k] for k in C.viz_vars] rollout.predict() for i in idx: @@ -215,7 +322,11 @@ def animate(self, num): rollout.fig, rollout.animate, frames=len(rollout.graphs) // C.frame_skip, - interval=C.frame_interval, + interval=C.frame_interval ) ani.save("animations/animation_" + C.viz_vars[i] + ".gif") 
logger.info(f"Created animation for {C.viz_vars[i]}") + + fig, ax = plt.subplots(1, 1, figsize=(16, 4.5)) + ax.plot(rollout.loss) + plt.savefig("animations/loss.png") diff --git a/examples/cfd/vortex_shedding_mgn/train.py b/examples/cfd/vortex_shedding_mgn/train.py index 3470937..ca5414f 100644 --- a/examples/cfd/vortex_shedding_mgn/train.py +++ b/examples/cfd/vortex_shedding_mgn/train.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import time import torch +import torch.nn as nn from dgl.dataloading import GraphDataLoader from torch.cuda.amp import autocast, GradScaler from torch.nn.parallel import DistributedDataParallel -import time, os + import wandb as wb try: @@ -65,7 +68,23 @@ def __init__(self, wb, dist, rank_zero_logger): # instantiate the model self.model = MeshGraphNet( - C.num_input_features, C.num_edge_features, C.num_output_features + input_dim_nodes=C.num_input_features, + input_dim_edges=C.num_edge_features, + output_dim=C.num_output_features, + processor_size=C.processor_size, + num_layers_node_processor=C.num_layers_node_processor, + num_layers_edge_processor=C.num_layers_edge_processor, + hidden_dim_processor=C.hidden_dim_processor, + hidden_dim_node_encoder=C.hidden_dim_node_encoder, + num_layers_node_encoder=C.num_layers_node_encoder, + hidden_dim_edge_encoder=C.hidden_dim_edge_encoder, + num_layers_edge_encoder=C.num_layers_edge_encoder, + hidden_dim_node_decoder=C.hidden_dim_node_decoder, + num_layers_node_decoder=C.num_layers_node_decoder, + aggregation=C.aggregation, + do_concat_trick=C.do_concat_trick, + num_processor_checkpoint_segments=C.num_processor_checkpoint_segments, + activation_fn=C.activation_fn, ) if C.jit: self.model = torch.jit.script(self.model).to(dist.device) @@ -95,7 +114,7 @@ def __init__(self, wb, dist, rank_zero_logger): except: self.optimizer = torch.optim.Adam(self.model.parameters(), lr=C.lr) self.scheduler = torch.optim.lr_scheduler.LambdaLR( - self.optimizer, lr_lambda=lambda epoch: C.lr_decay_rate**epoch + self.optimizer, lr_lambda=lambda epoch: C.lr_decay_rate ** epoch ) self.scaler = GradScaler() @@ -146,17 +165,18 @@ def backward(self, loss): if dist.rank == 0: os.makedirs(C.ckpt_path, exist_ok=True) with open( - os.path.join(C.ckpt_path, C.ckpt_name.replace(".pt", ".json")), "w" + os.path.join(C.ckpt_path, C.ckpt_name.replace(".pt", ".json")), "w" ) as json_file: json_file.write(C.json(indent=4)) # initialize loggers initialize_wandb( - project="Modulus-Launch", - entity="Modulus", - name="Vortex_Shedding-Training", - group="Vortex_Shedding-DDP-Group", + project="modulus_gnn", + entity="limitingfactor", + name="Vortex_Shedding-Training_2", + group=None, mode=C.wandb_mode, + config=C.__dict__ ) # Wandb logger logger = PythonLogger("main") # General python logger rank_zero_logger = RankZeroLoggingWrapper(logger, dist) # Rank 0 logger @@ -169,7 +189,7 @@ def backward(self, loss): for graph in trainer.dataloader: loss = trainer.train(graph) rank_zero_logger.info( - f"epoch: {epoch}, loss: {loss:10.3e}, time per epoch: {(time.time()-start):10.3e}" + f"epoch: {epoch}, loss: {loss:10.3e}, time per epoch: {(time.time() - start):10.3e}" ) wb.log({"loss": loss.detach().cpu()}) diff --git a/examples/cfd/vortex_shedding_mgn/wandb_train.py b/examples/cfd/vortex_shedding_mgn/wandb_train.py new file mode 100644 index 0000000..028c0b9 --- /dev/null +++ b/examples/cfd/vortex_shedding_mgn/wandb_train.py @@ -0,0 +1,312 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import time +from typing import Optional, Any + +import torch +from dgl.dataloading import GraphDataLoader +from torch.cuda.amp import autocast, GradScaler +from torch.nn.parallel import DistributedDataParallel + +import wandb as wb + +try: + import apex +except: + pass + +from modulus.models.meshgraphnet import MeshGraphNet +from modulus.datapipes.gnn.vortex_shedding_dataset import VortexSheddingDataset +from modulus.distributed.manager import DistributedManager + +from modulus.launch.logging import ( + PythonLogger, + initialize_wandb, + RankZeroLoggingWrapper, +) +from modulus.launch.utils import load_checkpoint, save_checkpoint +from constants import Constants + + +class MGNTrainer: + def __init__(self, wb, dist, rank_zero_logger, config): + self.dist = dist + self.config = config + + # instantiate dataset + rank_zero_logger.info("Loading the training dataset...") + dataset = VortexSheddingDataset( + name="vortex_shedding_train", + data_dir=config.data_dir, + split="train", + num_samples=config.num_training_samples, + num_steps=config.num_training_time_steps, + noise_std=config.noise_std, + force_reload=False, + verbose=False, + ) + + # instantiate dataloader + self.dataloader = GraphDataLoader( + dataset, + batch_size=config.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True, + use_ddp=dist.world_size > 1, + ) + + # instantiate validation dataset + rank_zero_logger.info("Loading the validation dataset...") + valid_dataset = VortexSheddingDataset( + name="vortex_shedding_valid", + data_dir=config.data_dir, + split="valid", + num_samples=config.num_valid_samples, + num_steps=config.num_valid_time_steps, + noise_std=config.noise_std, + force_reload=False, + verbose=False, + ) + + # instantiate validation dataloader + self.valid_dataloader = GraphDataLoader( + valid_dataset, + batch_size=config.batch_size, + shuffle=False, + drop_last=True, + pin_memory=True, + use_ddp=False, + ) + + # instantiate test dataset + rank_zero_logger.info("Loading the test dataset...") + test_dataset = VortexSheddingDataset( + name="vortex_shedding_test", + data_dir=config.data_dir, + split="test", + num_samples=config.num_test_samples, + num_steps=config.num_test_time_steps, + noise_std=config.noise_std, + force_reload=False, + verbose=False, + ) + + # instantiate test dataloader + self.test_dataloader = GraphDataLoader( + test_dataset, + batch_size=config.batch_size, + shuffle=False, + drop_last=True, + pin_memory=True, + use_ddp=False, + ) + + # instantiate the model + self.model = MeshGraphNet( + input_dim_nodes=config.num_input_features, + input_dim_edges=config.num_edge_features, + output_dim=config.num_output_features, + processor_size=config.processor_size, + num_layers_node_processor=config.num_layers_node_processor, + num_layers_edge_processor=config.num_layers_edge_processor, + hidden_dim_processor=config.hidden_dim_processor, + 
hidden_dim_node_encoder=config.hidden_dim_node_encoder, + num_layers_node_encoder=config.num_layers_node_encoder, + hidden_dim_edge_encoder=config.hidden_dim_edge_encoder, + num_layers_edge_encoder=config.num_layers_edge_encoder, + hidden_dim_node_decoder=config.hidden_dim_node_decoder, + num_layers_node_decoder=config.num_layers_node_decoder, + aggregation=config.aggregation, + do_concat_trick=config.do_concat_trick, + num_processor_checkpoint_segments=config.num_processor_checkpoint_segments, + activation_fn=config.activation_fn, + ) + if config.jit: + self.model = torch.jit.script(self.model).to(dist.device) + else: + self.model = self.model.to(dist.device) + if config.watch_model and not config.jit and dist.rank == 0: + wb.watch(self.model) + + # distributed data parallel for multi-node training + if dist.world_size > 1: + self.model = DistributedDataParallel( + self.model, + device_ids=[dist.local_rank], + output_device=dist.device, + broadcast_buffers=dist.broadcast_buffers, + find_unused_parameters=dist.find_unused_parameters, + ) + + # instantiate loss, optimizer, and scheduler + self.criterion = torch.nn.MSELoss() + try: + self.optimizer = apex.optimizers.FusedAdam(self.model.parameters(), lr=config.lr) + rank_zero_logger.info("Using FusedAdam optimizer") + except: + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=config.lr) + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=lambda epoch: config.lr_decay_rate ** epoch + ) + self.scaler = GradScaler() + + # load checkpoint + if dist.world_size > 1: + torch.distributed.barrier() + self.epoch_init = load_checkpoint( + os.path.join(config.ckpt_path, config.ckpt_name), + models=self.model, + optimizer=self.optimizer, + scheduler=self.scheduler, + scaler=self.scaler, + device=dist.device, + ) + + def train(self, graph): + # enable train mode + self.model.train() + + graph = graph.to(self.dist.device) + self.optimizer.zero_grad() + loss = self.forward(graph) + self.backward(loss) + self.scheduler.step() + return loss + + def forward(self, graph): + # forward pass + with autocast(enabled=self.config.amp): + pred = self.model(graph.ndata["x"], graph.edata["x"], graph) + loss = self.criterion(pred, graph.ndata["y"]) + return loss + + def backward(self, loss): + # backward pass + if self.config.amp: + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + else: + loss.backward() + self.optimizer.step() + + def get_lr(self): + # get the learning rate + for param_group in self.optimizer.param_groups: + return param_group["lr"] + + # @torch.no_grad() + # def validation(self): + # # enable train mode + # self.model.eval() + # error = 0 + # for graph in self.validation_dataloader: + # graph = graph.to(self.dist.device) + # pred = self.model(graph.ndata["x"], graph.edata["x"], graph) + # gt = graph.ndata["y"] + # error += relative_lp_error(pred, gt) + # error = error / len(self.validation_dataloader) + # self.wb.log({"val_error (%)": error}) + # self.rank_zero_logger.info(f"Validation error (%): {error}") + + +def setup_config(wandb_config): + constant = Constants(**wandb_config) + + return constant + + +def main(project: Optional[str] = "modulus_gnn", entity: Optional[str] = "limitingfactor", **kwargs: Any): + # initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + + if project is None: + project = "modulus_gnn" + + if project is None: + entity = "limitingfactor" + + # initialize loggers + run = initialize_wandb( + 
project=project, + entity=entity, + mode="online" + ) # Wandb logger + + C = setup_config(wandb_config=run.config) + if kwargs["activation_fn"]: + C.activation_fn = kwargs["activation_fn"] + + # save constants to JSON file + if dist.rank == 0: + os.makedirs(C.ckpt_path, exist_ok=True) + with open( + os.path.join(C.ckpt_path, C.ckpt_name.replace(".pt", ".json")), "w" + ) as json_file: + json_file.write(C.json(indent=4)) + + logger = PythonLogger("main") # General python logger + rank_zero_logger = RankZeroLoggingWrapper(logger, dist) # Rank 0 logger + logger.file_logging() + + trainer = MGNTrainer(wb, dist, rank_zero_logger, C) + start = time.time() + rank_zero_logger.info("Training started...") + for epoch in range(trainer.epoch_init, C.epochs): + loss_agg = 0 + for graph in trainer.dataloader: + loss = trainer.train(graph) + loss_agg += loss.detach().cpu().numpy() + loss_agg /= len(trainer.dataloader) + rank_zero_logger.info( + f"epoch: {epoch}, loss: {loss_agg:10.3e}, time per epoch: {(time.time() - start):10.3e}" + ) + wb.log({"loss_train": loss_agg}) + + # save checkpoint + if dist.world_size > 1: + torch.distributed.barrier() + if dist.rank == 0: + save_checkpoint( + os.path.join(C.ckpt_path, C.ckpt_name), + models=trainer.model, + optimizer=trainer.optimizer, + scheduler=trainer.scheduler, + scaler=trainer.scaler, + epoch=epoch, + ) + logger.info(f"Saved model on rank {dist.rank}") + start = time.time() + rank_zero_logger.info("Training completed!") + + +def get_options(): + parser = argparse.ArgumentParser() + parser.add_argument("--entity", "-e", type=str, default=None) + parser.add_argument("--project", "-p", type=str, default=None) + parser.add_argument("--activation_fn", type=str, default=None) + args = parser.parse_args() + + return vars(args) + + +if __name__ == "__main__": + opts = get_options() + + main(**opts) diff --git a/modulus/launch/logging/wandb.py b/modulus/launch/logging/wandb.py index 19cf6c1..23e2f64 100644 --- a/modulus/launch/logging/wandb.py +++ b/modulus/launch/logging/wandb.py @@ -91,7 +91,7 @@ def initialize_wandb( if not os.path.exists(wandb_dir): os.makedirs(wandb_dir) - wandb.init( + run = wandb.init( project=project, entity=entity, sync_tensorboard=sync_tensorboard, @@ -104,6 +104,8 @@ def initialize_wandb( save_code=save_code, ) + return run + def alert(title, text, duration=300, level=0, is_master=True): """Send alert.""" diff --git a/modulus/launch/utils/checkpoint.py b/modulus/launch/utils/checkpoint.py index ceb6a9c..c0ee60f 100644 --- a/modulus/launch/utils/checkpoint.py +++ b/modulus/launch/utils/checkpoint.py @@ -357,7 +357,8 @@ def load_checkpoint( checkpoint_logging.success("Loaded scheduler state dictionary") # Scaler state dict - if "scaler_state_dict" in checkpoint_dict: + if scaler and "scaler_state_dict" in checkpoint_dict: + print(checkpoint_dict["scaler_state_dict"]) scaler.load_state_dict(checkpoint_dict["scaler_state_dict"]) checkpoint_logging.success("Loaded grad scaler state dictionary") diff --git a/setup.py b/setup.py index df03908..97ecd3f 100644 --- a/setup.py +++ b/setup.py @@ -14,4 +14,6 @@ from setuptools import setup -setup() +setup( + version="0.3.0a0" +)
From a9e41814df30c7038913a831690a33bd9f1e7866 Mon Sep 17 00:00:00 2001 From: LimitingFactor Date: Sun, 24 Sep 2023 13:05:18 +0100 Subject: [PATCH 3/4] updated to work with wandb as a docker container and now has validation Signed-off-by: LimitingFactor --- .dockerignore | 3 - examples/cfd/vortex_shedding_mgn/constants.py | 36 ++-- examples/cfd/vortex_shedding_mgn/launch.sh
| 17 ++ .../cfd/vortex_shedding_mgn/wandb_train.py | 168 ++++++++++++------ modulus/launch/utils/checkpoint.py | 1 - setup.py | 4 +- 6 files changed, 145 insertions(+), 84 deletions(-) delete mode 100644 .dockerignore create mode 100755 examples/cfd/vortex_shedding_mgn/launch.sh diff --git a/.dockerignore b/.dockerignore deleted file mode 100644 index 99fd87a..0000000 --- a/.dockerignore +++ /dev/null @@ -1,3 +0,0 @@ -examples/cfd/vortex_shedding_mgn/checkpoints* -examples/cfd/vortex_shedding_mgn/raw_dataset* -examples/cfd/vortex_shedding_mgn/wandb* \ No newline at end of file diff --git a/examples/cfd/vortex_shedding_mgn/constants.py b/examples/cfd/vortex_shedding_mgn/constants.py index b8005a3..29310f8 100644 --- a/examples/cfd/vortex_shedding_mgn/constants.py +++ b/examples/cfd/vortex_shedding_mgn/constants.py @@ -12,35 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -from pathlib import Path +from typing import Tuple + from pydantic import BaseModel -from typing import Tuple, Optional class Constants(BaseModel): """vortex shedding constants""" + # wb configs + wandb_mode: str = "online" + watch_model: bool = True + # Model name - model_name: str = "test_2" + model_name: str = "training" # data configs - data_dir: str = "/home/swifta/modulus/datasets/cylinder_flow/cylinder_flow" + data_dir: str = "/datasets/cylinder_flow/cylinder_flow" # training configs - batch_size: int = 1 - epochs: int = 25 - num_training_samples: int = 1000 - num_training_time_steps: int = 600 + epochs: int = 5 # 25 + training_batch_size: int = 11 + num_training_samples: int = 22 # 1000 + num_training_time_steps: int = 100 # 600 training_noise_std: float = 0.02 - num_valid_samples: int = 100 - num_valid_time_steps: int = 600 + valid_batch_size: int = 1 # Must be 1 for now + num_valid_samples: int = 4 # 100 + num_valid_time_steps: int = 200 # 600 lr: float = 0.0001 lr_decay_rate: float = 0.9999991 - ckpt_path: str = "checkpoints_test_3" - ckpt_name: str = "test_3.pt" + ckpt_path: str = "/workspace/checkpoints_training_5" + ckpt_name: str = "model.pt" # Mesh Graph Net Setup num_input_features: int = 6 @@ -59,7 +63,7 @@ class Constants(BaseModel): aggregation: str = "sum" do_concat_trick: bool = False num_processor_checkpoint_segments: int = 0 - activation_fn: str = "silu" + activation_fn: str = "elu" # performance configs amp: bool = False @@ -71,7 +75,3 @@ class Constants(BaseModel): viz_vars: Tuple[str, ...] = ("u", "v", "p") frame_skip: int = 10 frame_interval: int = 1 - - # wb configs - wandb_mode: str = "online" - watch_model: bool = True diff --git a/examples/cfd/vortex_shedding_mgn/launch.sh b/examples/cfd/vortex_shedding_mgn/launch.sh new file mode 100755 index 0000000..342769d --- /dev/null +++ b/examples/cfd/vortex_shedding_mgn/launch.sh @@ -0,0 +1,17 @@ +python -m pip uninstall nvidia-modulus nvidia-modulus.sym nvidia-modulus.launch -y + +cd /modulus/ +python -m pip install -e . + +cd /modulus-sym/ +python -m pip install -e . + +cd /modulus-launch/ +python -m pip install -e . 
+ +cd /modulus-launch/examples/cfd/vortex_shedding_mgn/ +git config --global --add safe.directory /modulus-launch + +pip install wandb --upgrade + +python wandb_train.py "$@" diff --git a/examples/cfd/vortex_shedding_mgn/wandb_train.py b/examples/cfd/vortex_shedding_mgn/wandb_train.py index 028c0b9..07384c8 100644 --- a/examples/cfd/vortex_shedding_mgn/wandb_train.py +++ b/examples/cfd/vortex_shedding_mgn/wandb_train.py @@ -16,7 +16,7 @@ import os import time from typing import Optional, Any - +from tqdm import tqdm, trange import torch from dgl.dataloading import GraphDataLoader from torch.cuda.amp import autocast, GradScaler @@ -55,7 +55,7 @@ def __init__(self, wb, dist, rank_zero_logger, config): split="train", num_samples=config.num_training_samples, num_steps=config.num_training_time_steps, - noise_std=config.noise_std, + noise_std=config.training_noise_std, force_reload=False, verbose=False, ) @@ -63,7 +63,7 @@ def __init__(self, wb, dist, rank_zero_logger, config): # instantiate dataloader self.dataloader = GraphDataLoader( dataset, - batch_size=config.batch_size, + batch_size=config.training_batch_size, shuffle=True, drop_last=True, pin_memory=True, @@ -72,46 +72,22 @@ def __init__(self, wb, dist, rank_zero_logger, config): # instantiate validation dataset rank_zero_logger.info("Loading the validation dataset...") - valid_dataset = VortexSheddingDataset( + self.valid_dataset = VortexSheddingDataset( name="vortex_shedding_valid", data_dir=config.data_dir, split="valid", num_samples=config.num_valid_samples, num_steps=config.num_valid_time_steps, - noise_std=config.noise_std, force_reload=False, verbose=False, ) # instantiate validation dataloader self.valid_dataloader = GraphDataLoader( - valid_dataset, - batch_size=config.batch_size, + self.valid_dataset, + batch_size=config.valid_batch_size, shuffle=False, - drop_last=True, - pin_memory=True, - use_ddp=False, - ) - - # instantiate test dataset - rank_zero_logger.info("Loading the test dataset...") - test_dataset = VortexSheddingDataset( - name="vortex_shedding_test", - data_dir=config.data_dir, - split="test", - num_samples=config.num_test_samples, - num_steps=config.num_test_time_steps, - noise_std=config.noise_std, - force_reload=False, - verbose=False, - ) - - # instantiate test dataloader - self.test_dataloader = GraphDataLoader( - test_dataset, - batch_size=config.batch_size, - shuffle=False, - drop_last=True, + drop_last=False, pin_memory=True, use_ddp=False, ) @@ -210,20 +186,75 @@ def get_lr(self): for param_group in self.optimizer.param_groups: return param_group["lr"] - # @torch.no_grad() - # def validation(self): - # # enable train mode - # self.model.eval() - # error = 0 - # for graph in self.validation_dataloader: - # graph = graph.to(self.dist.device) - # pred = self.model(graph.ndata["x"], graph.edata["x"], graph) - # gt = graph.ndata["y"] - # error += relative_lp_error(pred, gt) - # error = error / len(self.validation_dataloader) - # self.wb.log({"val_error (%)": error}) - # self.rank_zero_logger.info(f"Validation error (%): {error}") + @torch.no_grad() + def validation(self): + # enable eval mode + self.model.eval() + + self.pred = [] + stats = { + key: value.to(self.dist.device) for key, value in self.valid_dataset.node_stats.items() + } + loss_valid_agg = 0 + for i, (graph, cells, mask) in enumerate(self.valid_dataloader): + graph = graph.to(self.dist.device) + + # denormalize data + graph.ndata["x"][:, 0:2] = self.valid_dataset.denormalize( + graph.ndata["x"][:, 0:2], stats["velocity_mean"], 
stats["velocity_std"] + ) + graph.ndata["y"][:, 0:2] = self.valid_dataset.denormalize( + graph.ndata["y"][:, 0:2], + stats["velocity_diff_mean"], + stats["velocity_diff_std"], + ) + graph.ndata["y"][:, [2]] = self.valid_dataset.denormalize( + graph.ndata["y"][:, [2]], + stats["pressure_mean"], + stats["pressure_std"], + ) + + # inference step + invar = graph.ndata["x"].clone() + + if i % (self.config.num_valid_time_steps - 1) != 0: + invar[:, 0:2] = self.pred[i - 1][:, 0:2].clone() + i += 1 + invar[:, 0:2] = self.valid_dataset.normalize_node( + invar[:, 0:2], stats["velocity_mean"], stats["velocity_std"] + ) + # Get the prediction + pred_i = self.model(invar, graph.edata["x"], graph).detach() # predict + + # denormalize prediction + pred_i[:, 0:2] = self.valid_dataset.denormalize( + pred_i[:, 0:2], stats["velocity_diff_mean"], stats["velocity_diff_std"] + ) + pred_i[:, 2] = self.valid_dataset.denormalize( + pred_i[:, 2], stats["pressure_mean"], stats["pressure_std"] + ) + + loss = self.criterion(pred_i, graph.ndata["y"]) + loss_valid_agg += loss.detach().cpu().numpy() + + invar[:, 0:2] = self.valid_dataset.denormalize( + invar[:, 0:2], stats["velocity_mean"], stats["velocity_std"] + ) + + # do not update the "wall_boundary" & "outflow" nodes + mask = torch.cat((mask, mask), dim=-1).to(self.dist.device) + pred_i[:, 0:2] = torch.where(mask, pred_i[:, 0:2], torch.zeros_like(pred_i[:, 0:2])) + + # integration + self.pred.append(torch.cat(((pred_i[:, 0:2] + invar[:, 0:2]), pred_i[:, [2]]), dim=-1).cpu()) + + # Don't need to store this beyond the vailidation + self.pred = [] + + loss_valid_agg /= len(self.valid_dataloader) + + return loss_valid_agg def setup_config(wandb_config): constant = Constants(**wandb_config) @@ -236,12 +267,6 @@ def main(project: Optional[str] = "modulus_gnn", entity: Optional[str] = "limiti DistributedManager.initialize() dist = DistributedManager() - if project is None: - project = "modulus_gnn" - - if project is None: - entity = "limitingfactor" - # initialize loggers run = initialize_wandb( project=project, @@ -250,8 +275,6 @@ def main(project: Optional[str] = "modulus_gnn", entity: Optional[str] = "limiti ) # Wandb logger C = setup_config(wandb_config=run.config) - if kwargs["activation_fn"]: - C.activation_fn = kwargs["activation_fn"] # save constants to JSON file if dist.rank == 0: @@ -269,15 +292,39 @@ def main(project: Optional[str] = "modulus_gnn", entity: Optional[str] = "limiti start = time.time() rank_zero_logger.info("Training started...") for epoch in range(trainer.epoch_init, C.epochs): - loss_agg = 0 - for graph in trainer.dataloader: + rank_zero_logger.info(f"Training epoch {epoch}") + + # Train the model + tmp_start = time.time() + loss_train_agg = 0 + for graph in tqdm(trainer.dataloader): loss = trainer.train(graph) - loss_agg += loss.detach().cpu().numpy() - loss_agg /= len(trainer.dataloader) + loss_train_agg += loss.detach().cpu().numpy() + loss_train_agg /= len(trainer.dataloader) + time_per_epoch_train = (time.time() - tmp_start) + + # Run the validation rollout + rank_zero_logger.info(f"Validating epoch {epoch}") + tmp_start = time.time() + loss_valid_agg = trainer.validation() + time_per_epoch_valid = (time.time() - tmp_start) + + # Logging + time_per_epoch = (time.time() - start) rank_zero_logger.info( - f"epoch: {epoch}, loss: {loss_agg:10.3e}, time per epoch: {(time.time() - start):10.3e}" + f"epoch: {epoch}, " + f"loss/train: {loss_train_agg:10.3e}, " + f"loss/valid: {loss_valid_agg:10.3e}, " + f"time per epoch: {time_per_epoch:10.3e}" ) - 
wb.log({"loss_train": loss_agg}) + wb.log({ + "loss/train": loss_train_agg, + "loss/valid": loss_valid_agg, + "learning rate": trainer.get_lr(), + "time_per_epoch/train": time_per_epoch_train, + "time_per_epoch/valid": time_per_epoch_valid, + "time_per_epoch/total": time_per_epoch + }) # save checkpoint if dist.world_size > 1: @@ -292,15 +339,18 @@ def main(project: Optional[str] = "modulus_gnn", entity: Optional[str] = "limiti epoch=epoch, ) logger.info(f"Saved model on rank {dist.rank}") + start = time.time() + rank_zero_logger.info("Training completed!") def get_options(): parser = argparse.ArgumentParser() + parser.add_argument("--entity", "-e", type=str, default=None) parser.add_argument("--project", "-p", type=str, default=None) - parser.add_argument("--activation_fn", type=str, default=None) + args = parser.parse_args() return vars(args) diff --git a/modulus/launch/utils/checkpoint.py b/modulus/launch/utils/checkpoint.py index c0ee60f..0026c6e 100644 --- a/modulus/launch/utils/checkpoint.py +++ b/modulus/launch/utils/checkpoint.py @@ -358,7 +358,6 @@ def load_checkpoint( # Scaler state dict if scaler and "scaler_state_dict" in checkpoint_dict: - print(checkpoint_dict["scaler_state_dict"]) scaler.load_state_dict(checkpoint_dict["scaler_state_dict"]) checkpoint_logging.success("Loaded grad scaler state dictionary") diff --git a/setup.py b/setup.py index 97ecd3f..df03908 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,4 @@ from setuptools import setup -setup( - version="0.3.0a0" -) +setup() From 750d2cb07612a890c64a4a8ba4d2891bb6a2ca30 Mon Sep 17 00:00:00 2001 From: LimitingFactor Date: Sun, 24 Sep 2023 16:31:33 +0100 Subject: [PATCH 4/4] constant change Signed-off-by: LimitingFactor --- examples/cfd/vortex_shedding_mgn/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cfd/vortex_shedding_mgn/constants.py b/examples/cfd/vortex_shedding_mgn/constants.py index 29310f8..e63dd4f 100644 --- a/examples/cfd/vortex_shedding_mgn/constants.py +++ b/examples/cfd/vortex_shedding_mgn/constants.py @@ -43,7 +43,7 @@ class Constants(BaseModel): lr: float = 0.0001 lr_decay_rate: float = 0.9999991 - ckpt_path: str = "/workspace/checkpoints_training_5" + ckpt_path: str = "/workspace/checkpoints_training_6" ckpt_name: str = "model.pt" # Mesh Graph Net Setup
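
Note on the override mechanism used by wandb_train.py: Constants is a pydantic BaseModel, so setup_config(wandb_config=run.config) lets any hyperparameter supplied by a wandb sweep override the matching default while unspecified fields keep their values. Below is a minimal, self-contained sketch of that behaviour; the cut-down model and the sweep values are illustrative assumptions, not taken from the patches.

    # Sketch only: a reduced stand-in for the example's Constants model,
    # showing how Constants(**run.config) merges sweep overrides with defaults.
    from pydantic import BaseModel

    class Constants(BaseModel):
        lr: float = 0.0001
        lr_decay_rate: float = 0.9999991
        training_batch_size: int = 11

    def setup_config(wandb_config):
        # mirrors setup_config in wandb_train.py: keys present in the sweep
        # config override the matching defaults
        return Constants(**wandb_config)

    # assumed sweep values as they would arrive via wandb's run.config
    C = setup_config({"lr": 3e-4, "training_batch_size": 4})
    print(C.lr)             # 0.0003    (overridden by the sweep)
    print(C.lr_decay_rate)  # 0.9999991 (default retained)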