From 8ea1d572ff514bc2cf6d26936778582412b82f68 Mon Sep 17 00:00:00 2001
From: Victor Schmidt
Date: Thu, 20 Jun 2024 12:53:42 -0400
Subject: [PATCH 01/13] rely on checkpoint config not on package configs

---
 ocpmodels/common/gfn.py | 63 ++++++++++++++++++++++++++++++-----------
 1 file changed, 46 insertions(+), 17 deletions(-)

diff --git a/ocpmodels/common/gfn.py b/ocpmodels/common/gfn.py
index 0a1f45521c..0c0560f690 100644
--- a/ocpmodels/common/gfn.py
+++ b/ocpmodels/common/gfn.py
@@ -1,16 +1,17 @@
+import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Callable, Union, List
-
-import os
+from typing import Callable, List, Union
 
+import torch
 import torch.nn as nn
-from torch_geometric.data.data import Data
 from torch_geometric.data.batch import Batch
+from torch_geometric.data.data import Data
 
-from ocpmodels.common.utils import make_trainer_from_dir, resolve
-from ocpmodels.models.faenet import FAENet
+from ocpmodels.common.registry import registry
+from ocpmodels.common.utils import resolve, setup_imports
 from ocpmodels.datasets.data_transforms import get_transforms
+from ocpmodels.models.faenet import FAENet
 
 
 class FAENetWrapper(nn.Module):
@@ -190,6 +191,37 @@ def parse_loc() -> str:
     return loc
 
 
+def reset_data_paths(config):
+    """
+    Reset config data paths to defaults, instead of SLURM temporary paths (inplace).
+
+    Args:
+        config (dict): The trainer config dictionary to modify.
+
+    Returns:
+        dict: The modified config dictionary.
+    """
+    ds_configs = deepcopy(config["dataset"])
+    task_name = config["task"]["name"]
+    if task_name != "is2re":
+        raise NotImplementedError(
+            "Only the is2re task is currently supported for resetting data paths."
+            + " To implement this for other tasks, modify how `base_path` is constructed"
+            " in `reset_data_paths()`"
+        )
+    base_path = Path("/network/projects/ocp/oc20/is2re")
+    for name, ds_config in ds_configs.items():
+        if not isinstance(ds_config, dict):
+            continue
+        if "slurm" in ds_config["src"].lower():
+            ds_config["src"] = str(
+                base_path / ds_config["split"] / Path(ds_config["src"]).name
+            )
+        config["dataset"][name] = ds_config
+
+    return config
+
+
 def find_ckpt(ckpt_paths: dict, release: str) -> Path:
     """
     Finds a checkpoint in a dictionary of paths, based on the current cluster name and
@@ -223,7 +255,7 @@ def find_ckpt(ckpt_paths: dict, release: str) -> Path:
         if path.is_file():
             return path
         path = path / release
-        ckpts = list(path.glob("**/*.ckpt"))
+        ckpts = list(path.glob("**/*.pt"))
         if len(ckpts) == 0:
             raise ValueError(f"No FAENet proxy checkpoint found at {str(path)}.")
         if len(ckpts) > 1:
@@ -256,18 +288,15 @@ def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple:
     Returns:
         tuple: (model, loaders) where loaders is a dict of loaders for the model.
    """
+    setup_imports()
     ckpt_path = find_ckpt(ckpt_paths, release)
     assert ckpt_path.exists(), f"Path {ckpt_path} does not exist."
-    trainer = make_trainer_from_dir(
-        ckpt_path,
-        mode="continue",
-        overrides={
-            "is_debug": True,
-            "silent": True,
-            "cp_data_to_tmpdir": False,
-        },
-        silent=True,
-    )
+    config = torch.load(ckpt_path, map_location="cpu")["config"]
+    config["is_debug"] = True
+    config["silent"] = True
+    config["cp_data_to_tmpdir"] = False
+    config = reset_data_paths(config)
+    trainer = registry.get_trainer_class(config["trainer"])(**config)
 
     wrapper = FAENetWrapper(
         faenet=trainer.model,
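Note: the effect of this patch is that a proxy model can be rebuilt from nothing but the checkpoint file, since OCP checkpoints store the full trainer config under their "config" key. A minimal sketch of the pattern (the checkpoint layout and registry call come from the diff above; the path is a placeholder):

```python
import torch

from ocpmodels.common.registry import registry
from ocpmodels.common.utils import setup_imports

setup_imports()  # populate the registry with trainer/model classes

# OCP checkpoints embed the trainer config; no package config files needed.
ckpt = torch.load("/path/to/best_checkpoint.pt", map_location="cpu")
config = ckpt["config"]

# Rebuild the exact trainer that produced the checkpoint.
trainer_cls = registry.get_trainer_class(config["trainer"])
trainer = trainer_cls(**config)
model = trainer.model
```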
From ac1456047201cf350c1a8147c11413b4fb9c2a9c Mon Sep 17 00:00:00 2001
From: vict0rsch
Date: Tue, 24 Sep 2024 19:12:49 +0200
Subject: [PATCH 02/13] add `prevent_load` kwarg dict to `load()`

---
 ocpmodels/trainers/base_trainer.py | 75 ++++++++++++++++++++++--------
 1 file changed, 55 insertions(+), 20 deletions(-)

diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py
index e871027efe..2aa7911a7c 100644
--- a/ocpmodels/trainers/base_trainer.py
+++ b/ocpmodels/trainers/base_trainer.py
@@ -4,6 +4,7 @@
 This source code is licensed under the MIT license found in the
 LICENSE file in the root directory of this source tree.
 """
+
 import datetime
 import errno
 import logging
@@ -56,7 +57,7 @@
 
 @registry.register_trainer("base")
 class BaseTrainer(ABC):
-    def __init__(self, load=True, **kwargs):
+    def __init__(self, **kwargs):
         run_dir = kwargs["run_dir"]
 
         model_name = kwargs["model"].pop(
@@ -76,9 +77,14 @@ def __init__(self, load=True, **kwargs):
         }
 
         self.sigterm = False
-        self.objective = None
         self.epoch = 0
         self.step = 0
+        self.objective = None
+        self.logger = None
+        self.parallel_collater = None
+        self.ema_decay = None
+        self.clip_grad_norm = None
+        self.scheduler = None
         self.cpu = self.config["cpu"]
         self.task_name = self.config["task"].get("name", self.config.get("name"))
         assert self.task_name, "Specify task name (got {})".format(self.task_name)
@@ -90,6 +96,7 @@ def __init__(self, load=True, **kwargs):
         self.datasets = {}
         self.samplers = {}
         self.loaders = {}
+        self.normalizers = {}
         self.early_stopper = EarlyStopper(
             patience=self.config["optim"].get("es_patience") or 15,
             min_abs_change=self.config["optim"].get("es_min_abs_change") or 1e-5,
@@ -189,21 +196,49 @@ def __init__(self, load=True, **kwargs):
             )
             self.config["is_disconnected"] = True
 
-        self.load()
+        self.load(self.config.get("prevent_load"))
 
         self.evaluator = Evaluator(
             task=self.task_name,
             model_regresses_forces=self.config["model"].get("regress_forces", ""),
         )
 
-    def load(self):
-        self.load_seed_from_config()
-        self.load_logger()
-        self.load_datasets()
-        self.load_task()
-        self.load_model()
-        self.load_loss()
-        self.load_optimizer()
-        self.load_extras()
+    def load(self, prevent_load={}):
+        """Load all components of the trainer.
+
+        Arbitrary components can be prevented from loading by specifying them in the
+        ``prevent_load`` dictionary. Allowed keys are:
+
+        - ``seed``
+        - ``logger``
+        - ``datasets``
+        - ``task``
+        - ``model``
+        - ``checkpoint``
+        - ``optimizer``
+        - ``extras``
+
+        Parameters
+        ----------
+        prevent_load : dict, optional
+            Dictionary describing loading events that should be prevented, by default {}
+        """
+        prevent_load = prevent_load or {}
+        if "seed" not in prevent_load:
+            self.load_seed_from_config()
+        if "logger" not in prevent_load:
+            self.load_logger()
+        if "datasets" not in prevent_load:
+            self.load_datasets()
+        if "task" not in prevent_load:
+            self.load_task()
+        if "model" not in prevent_load:
+            self.load_model()
+        if "checkpoint" not in prevent_load:
+            self.load_loss()
+        if "optimizer" not in prevent_load:
+            self.load_optimizer()
+        if "extras" not in prevent_load:
+            self.load_extras()
 
     def load_seed_from_config(self):
         # https://pytorch.org/docs/stable/notes/randomness.html
@@ -220,7 +255,6 @@ def load_seed_from_config(self):
         torch.backends.cudnn.benchmark = False
 
     def load_logger(self):
-        self.logger = None
         if not self.is_debug and dist_utils.is_master() and not self.is_hpo:
             assert self.config["logger"] is not None, "Specify logger in config"
 
@@ -380,7 +414,6 @@ def load_datasets(self):
 
         # Normalizer for the dataset.
         # Compute mean, std of training set labels.
-        self.normalizers = {}
         if self.normalizer.get("normalize_labels", False):
             if "target_mean" in self.normalizer:
                 self.normalizers["target"] = Normalizer(
@@ -619,9 +652,11 @@ def save(
             "step": self.step,
             "state_dict": self.model.state_dict(),
             "optimizer": self.optimizer.state_dict(),
-            "scheduler": self.scheduler.scheduler.state_dict()
-            if self.scheduler.scheduler_type != "Null"
-            else None,
+            "scheduler": (
+                self.scheduler.scheduler.state_dict()
+                if self.scheduler.scheduler_type != "Null"
+                else None
+            ),
             "normalizers": {
                 key: value.state_dict() for key, value in self.normalizers.items()
             },
             "amp": self.scaler.state_dict() if self.scaler else None,
         }
         if self.scheduler.warmup_scheduler is not None:
-            ckpt_dict[
-                "warmup_scheduler"
-            ] = self.scheduler.warmup_scheduler.state_dict()
+            ckpt_dict["warmup_scheduler"] = (
+                self.scheduler.warmup_scheduler.state_dict()
+            )
 
         save_checkpoint(
             ckpt_dict,
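Note: at this stage the skip test is a key-membership check, so any key that appears in ``prevent_load`` disables the matching component (the values are ignored); note also that the ``checkpoint`` key actually guards ``load_loss()`` here, a mismatch PATCH 04 resolves by renaming the key to ``loss``. A sketch of caller-side usage, building on PATCH 01's checkpoint-driven construction (the path is a placeholder):

```python
import torch

from ocpmodels.common.registry import registry
from ocpmodels.common.utils import setup_imports

setup_imports()
config = torch.load("/path/to/checkpoint.pt", map_location="cpu")["config"]

# Any key present in `prevent_load` makes BaseTrainer.load() skip that component.
config["prevent_load"] = {"logger": True, "optimizer": True, "extras": True}

trainer = registry.get_trainer_class(config["trainer"])(**config)
```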
From 82407720ba52c44dd9d87cf9d66cd23a956f0425 Mon Sep 17 00:00:00 2001
From: vict0rsch
Date: Tue, 24 Sep 2024 19:18:37 +0200
Subject: [PATCH 03/13] add default `prevent_load`

---
 configs/models/tasks/is2re.yaml | 1 +
 configs/models/tasks/qm7x.yaml  | 1 +
 configs/models/tasks/qm9.yaml   | 1 +
 configs/models/tasks/s2ef.yaml  | 2 ++
 4 files changed, 5 insertions(+)

diff --git a/configs/models/tasks/is2re.yaml b/configs/models/tasks/is2re.yaml
index 4b1fe4304f..209b4b2567 100644
--- a/configs/models/tasks/is2re.yaml
+++ b/configs/models/tasks/is2re.yaml
@@ -1,6 +1,7 @@
 default:
   trainer: single
   logger: wandb
+  prevent_load: {}
 
   task:
     dataset: single_point_lmdb
diff --git a/configs/models/tasks/qm7x.yaml b/configs/models/tasks/qm7x.yaml
index defb898e91..7c5f590e55 100644
--- a/configs/models/tasks/qm7x.yaml
+++ b/configs/models/tasks/qm7x.yaml
@@ -1,6 +1,7 @@
 default:
   trainer: single
   logger: wandb
+  prevent_load: {}
   eval_on_test: True
 
   model:
diff --git a/configs/models/tasks/qm9.yaml b/configs/models/tasks/qm9.yaml
index b13954b393..76371e50d6 100644
--- a/configs/models/tasks/qm9.yaml
+++ b/configs/models/tasks/qm9.yaml
@@ -2,6 +2,7 @@ default:
   trainer: single
   logger: wandb
   eval_on_test: True
+  prevent_load: {}
 
   model:
     otf_graph: False
diff --git a/configs/models/tasks/s2ef.yaml b/configs/models/tasks/s2ef.yaml
index ef62591945..9a54254adc 100644
--- a/configs/models/tasks/s2ef.yaml
+++ b/configs/models/tasks/s2ef.yaml
@@ -1,6 +1,8 @@
 default:
   trainer: single
   logger: wandb
+  prevent_load: {}
+
   task:
     dataset: trajectory_lmdb
     description: "Regressing to energies and forces for DFT trajectories from OCP"
From 124abcf19d0af2d8b2980537109e3454abc0564a Mon Sep 17 00:00:00 2001
From: vict0rsch
Date: Tue, 24 Sep 2024 19:18:50 +0200
Subject: [PATCH 04/13] use `dict.get`

---
 ocpmodels/trainers/base_trainer.py | 34 +++++++++++++++---------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py
index 2aa7911a7c..625b26a707 100644
--- a/ocpmodels/trainers/base_trainer.py
+++ b/ocpmodels/trainers/base_trainer.py
@@ -208,36 +208,36 @@ def load(self, prevent_load={}):
         Arbitrary components can be prevented from loading by specifying them in the
         ``prevent_load`` dictionary. Allowed keys are:
 
-        - ``seed``
-        - ``logger``
-        - ``datasets``
-        - ``task``
-        - ``model``
-        - ``checkpoint``
-        - ``optimizer``
-        - ``extras``
+        - "seed"
+        - "logger"
+        - "datasets"
+        - "task"
+        - "model"
+        - "loss"
+        - "optimizer"
+        - "extras"
 
         Parameters
         ----------
         prevent_load : dict, optional
-            Dictionary describing loading events that should be prevented, by default {}
+            Dictionary describing loading events that should be prevented, by default ``{}``
         """
         prevent_load = prevent_load or {}
-        if "seed" not in prevent_load:
+        if prevent_load.get("seed"):
             self.load_seed_from_config()
-        if "logger" not in prevent_load:
+        if prevent_load.get("logger"):
             self.load_logger()
-        if "datasets" not in prevent_load:
+        if prevent_load.get("datasets"):
             self.load_datasets()
-        if "task" not in prevent_load:
+        if prevent_load.get("task"):
             self.load_task()
-        if "model" not in prevent_load:
+        if prevent_load.get("model"):
             self.load_model()
-        if "checkpoint" not in prevent_load:
+        if prevent_load.get("loss"):
             self.load_loss()
-        if "optimizer" not in prevent_load:
+        if prevent_load.get("optimizer"):
             self.load_optimizer()
-        if "extras" not in prevent_load:
+        if prevent_load.get("extras"):
             self.load_extras()
 
     def load_seed_from_config(self):
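Note: switching from key-membership to ``dict.get`` silently inverts the semantics, because ``prevent_load.get("logger")`` is truthy exactly when the component should be *skipped*, yet it now guards the ``load_*`` call. A quick illustration of the bug this introduces (fixed two patches later, in PATCH 06):

```python
prevent_load = {"logger": True}  # intent: do NOT load the logger

# PATCH 02 semantics -- skip when the key is present:
assert ("logger" not in prevent_load) is False    # load_logger() correctly skipped

# PATCH 04 semantics -- the guard is accidentally inverted:
assert bool(prevent_load.get("logger")) is True   # load_logger() now runs (bug)

# PATCH 06 restores the intent:
assert (not prevent_load.get("logger")) is False  # load_logger() skipped again
```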
From 210d977868c40288687862ac66e99feb7c9eec3e Mon Sep 17 00:00:00 2001
From: vict0rsch
Date: Tue, 24 Sep 2024 19:19:10 +0200
Subject: [PATCH 05/13] prevent loading non-required trainer bits

---
 ocpmodels/common/gfn.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/ocpmodels/common/gfn.py b/ocpmodels/common/gfn.py
index 0c0560f690..833b8577d4 100644
--- a/ocpmodels/common/gfn.py
+++ b/ocpmodels/common/gfn.py
@@ -295,6 +295,13 @@ def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple:
     config["is_debug"] = True
     config["silent"] = True
     config["cp_data_to_tmpdir"] = False
+    config["prevent_loaders"] = {
+        "logger": True,
+        "loss": True,
+        "datasets": True,
+        "optimizer": True,
+        "extras": True,
+    }
     config = reset_data_paths(config)
     trainer = registry.get_trainer_class(config["trainer"])(**config)

From 265e64d4d5ef8231d54f136bb4a4bda0177eb25e Mon Sep 17 00:00:00 2001
From: vict0rsch
Date: Tue, 24 Sep 2024 19:19:42 +0200
Subject: [PATCH 06/13] typo in logic

---
 ocpmodels/trainers/base_trainer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py
index 625b26a707..40b446b2d8 100644
--- a/ocpmodels/trainers/base_trainer.py
+++ b/ocpmodels/trainers/base_trainer.py
@@ -223,21 +223,21 @@ def load(self, prevent_load={}):
             Dictionary describing loading events that should be prevented, by default ``{}``
         """
         prevent_load = prevent_load or {}
-        if prevent_load.get("seed"):
+        if not prevent_load.get("seed"):
             self.load_seed_from_config()
-        if prevent_load.get("logger"):
+        if not prevent_load.get("logger"):
             self.load_logger()
-        if prevent_load.get("datasets"):
+        if not prevent_load.get("datasets"):
             self.load_datasets()
-        if prevent_load.get("task"):
+        if not prevent_load.get("task"):
             self.load_task()
-        if prevent_load.get("model"):
+        if not prevent_load.get("model"):
             self.load_model()
-        if prevent_load.get("loss"):
+        if not prevent_load.get("loss"):
             self.load_loss()
-        if prevent_load.get("optimizer"):
+        if not prevent_load.get("optimizer"):
             self.load_optimizer()
-        if prevent_load.get("extras"):
+        if not prevent_load.get("extras"):
             self.load_extras()
 
     def load_seed_from_config(self):

From ecd424eefeacfc688229351ea9266ad2cf4f2a2a Mon Sep 17 00:00:00 2001
From: Victor Schmidt
Date: Tue, 8 Oct 2024 04:16:44 -0400
Subject: [PATCH 07/13] Ignore cosmosis req

---
 ocpmodels/datasets/qm7x.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/ocpmodels/datasets/qm7x.py b/ocpmodels/datasets/qm7x.py
index 7ecb7e0bf3..683476cadc 100644
--- a/ocpmodels/datasets/qm7x.py
+++ b/ocpmodels/datasets/qm7x.py
@@ -20,10 +20,18 @@
 from torch_geometric.data import Data
 from tqdm import tqdm
 
-from cosmosis.dataset import CDataset
 from ocpmodels.common.registry import registry
 from ocpmodels.common.utils import ROOT
 
+CDataset = object
+try:
+    from cosmosis.dataset import CDataset
+except ImportError:
+    print("\nWarning: `cosmosis` is not installed. `QM7X` will not be available.\n")
+    print("See https://github.com/icanswim/cosmosis")
+    print(f"(message from {Path(__file__).resolve()})\n")
+
+
 try:
     import orjson as json  # noqa: F401
 except:  # noqa: E722

From f864704e449fda74072326393f2b8c418b19f950 Mon Sep 17 00:00:00 2001
From: Victor Schmidt
Date: Tue, 8 Oct 2024 04:24:59 -0400
Subject: [PATCH 08/13] fix spherenet

---
 ocpmodels/datasets/qm7x.py    |  9 ++++++---
 ocpmodels/models/spherenet.py | 16 +++++++++++++---
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/ocpmodels/datasets/qm7x.py b/ocpmodels/datasets/qm7x.py
index 683476cadc..1345333f5e 100644
--- a/ocpmodels/datasets/qm7x.py
+++ b/ocpmodels/datasets/qm7x.py
@@ -27,9 +27,11 @@
 try:
     from cosmosis.dataset import CDataset
 except ImportError:
-    print("\nWarning: `cosmosis` is not installed. `QM7X` will not be available.\n")
-    print("See https://github.com/icanswim/cosmosis")
-    print(f"(message from {Path(__file__).resolve()})\n")
+    print(
+        "Warning: `cosmosis` is not installed. `QM7X` will not be available.",
+        "See https://github.com/icanswim/cosmosis",
+    )
+    print(f"(message from {Path(__file__).resolve()})")
 
 
 try:
@@ -41,6 +43,7 @@
         "`orjson` is not installed. ",
         "Consider `pip install orjson` to speed up json loading.",
     )
+    print(f"(message from {Path(__file__).resolve()})")
 
 
 class Molecule:
diff --git a/ocpmodels/models/spherenet.py b/ocpmodels/models/spherenet.py
index df0024fe8e..d0627a4661 100644
--- a/ocpmodels/models/spherenet.py
+++ b/ocpmodels/models/spherenet.py
@@ -1,9 +1,19 @@
-from dig.threedgraph.method import SphereNet as DIGSphereNet
-from ocpmodels.models.base_model import BaseModel
+from copy import deepcopy
+
 import torch
+
 from ocpmodels.common.registry import registry
 from ocpmodels.common.utils import conditional_grad
-from copy import deepcopy
+from ocpmodels.models.base_model import BaseModel
+
+DIGSphereNet = None
+try:
+    from dig.threedgraph.method import SphereNet as DIGSphereNet
+except ImportError:
+    from pathlib import Path
+
+    print("Warning: `dig` is not installed. `SphereNet` will not be available.")
+    print(f"(message from {Path(__file__).resolve()})\n")
 
 
 @registry.register_model("spherenet")
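Note: the guard used for `cosmosis` and `dig` (bind a fallback, attempt the import, warn instead of raising) keeps modules importable without their optional dependencies; downstream code can then test the fallback and fail late, only when the feature is actually used. A sketch of the consuming side under that assumption — the `QM7X` class name comes from this file, but the lazy check in `__init__` is illustrative, not part of the patch:

```python
from ocpmodels.common.registry import registry

CDataset = object  # fallback so the subclass below still parses
try:
    from cosmosis.dataset import CDataset  # optional dependency
except ImportError:
    print("Warning: `cosmosis` is not installed. `QM7X` will not be available.")


@registry.register_dataset("qm7x")
class QM7X(CDataset):
    def __init__(self, *args, **kwargs):
        # Fail late and loudly only when the optional feature is actually used.
        if CDataset is object:
            raise ImportError("QM7X requires `cosmosis`; install it first.")
        super().__init__(*args, **kwargs)
```

Because the fallback is `object`, defining `QM7X` never fails at import time; it just stops behaving as a cosmosis dataset until the dependency is installed.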
", "Consider `pip install orjson` to speed up json loading.", ) + print(f"(message from {Path(__file__).resolve()})") class Molecule: diff --git a/ocpmodels/models/spherenet.py b/ocpmodels/models/spherenet.py index df0024fe8e..d0627a4661 100644 --- a/ocpmodels/models/spherenet.py +++ b/ocpmodels/models/spherenet.py @@ -1,9 +1,19 @@ -from dig.threedgraph.method import SphereNet as DIGSphereNet -from ocpmodels.models.base_model import BaseModel +from copy import deepcopy + import torch + from ocpmodels.common.registry import registry from ocpmodels.common.utils import conditional_grad -from copy import deepcopy +from ocpmodels.models.base_model import BaseModel + +DIGSphereNet = None +try: + from dig.threedgraph.method import SphereNet as DIGSphereNet +except ImportError: + from pathlib import Path + + print("Warning: `dig` is not installed. `SphereNet` will not be available.") + print(f"(message from {Path(__file__).resolve()})\n") @registry.register_model("spherenet") From 2e39210ddb255af3bd37139bcf78e3e87e7101bc Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 8 Oct 2024 04:26:35 -0400 Subject: [PATCH 09/13] fix comenet --- ocpmodels/models/comenet.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/ocpmodels/models/comenet.py b/ocpmodels/models/comenet.py index ec8ada6e54..3aa8562a6e 100644 --- a/ocpmodels/models/comenet.py +++ b/ocpmodels/models/comenet.py @@ -1,9 +1,19 @@ -from dig.threedgraph.method import ComENet as DIGComENet -from ocpmodels.models.base_model import BaseModel +from copy import deepcopy + import torch + from ocpmodels.common.registry import registry from ocpmodels.common.utils import conditional_grad -from copy import deepcopy +from ocpmodels.models.base_model import BaseModel + +DIGComENet = None +try: + from dig.threedgraph.method import ComENet as DIGComENet +except ImportError: + from pathlib import Path + + print("Warning: `dig` is not installed. 
From 5cb82c1fdb8a8281f1c5261e6ecec7215547651c Mon Sep 17 00:00:00 2001
From: Victor Schmidt
Date: Tue, 8 Oct 2024 06:19:26 -0400
Subject: [PATCH 10/13] move warmup import

---
 ocpmodels/modules/scheduler.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/ocpmodels/modules/scheduler.py b/ocpmodels/modules/scheduler.py
index af4107ebb9..4e9d0fd634 100644
--- a/ocpmodels/modules/scheduler.py
+++ b/ocpmodels/modules/scheduler.py
@@ -1,10 +1,11 @@
 """scheduler.py
 """
+
 import inspect
+
 import torch.optim.lr_scheduler as lr_scheduler
 
 from ocpmodels.common.utils import warmup_lr_lambda
-import pytorch_warmup as warmup
 
 
 class LRScheduler:
@@ -54,6 +55,8 @@ def scheduler_lambda_fn(x):
                 if not self.silent:
                     print(f"Using fidelity_max_steps for scheduler -> {T_max}")
         if self.optim_config["warmup_steps"] > 0:
+            import pytorch_warmup as warmup
+
             self.warmup_scheduler = warmup.ExponentialWarmup(
                 self.optimizer, warmup_period=self.optim_config["warmup_steps"]
            )

From 7bcd6b50c6ff24c20dc6d183307116bf4ff5f954 Mon Sep 17 00:00:00 2001
From: Victor Schmidt
Date: Tue, 8 Oct 2024 06:20:10 -0400
Subject: [PATCH 11/13] proper install from pyproject.toml in addition to readme instructions

---
 pyproject.toml | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index fed6c97593..738fcd9d40 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,57 @@
+[build-system]
+# A list of packages that are needed to build your package:
+requires = ["setuptools"] # REQUIRED if [build-system] table is used
+# The name of the Python object that frontends will use to perform the build:
+build-backend = "setuptools.build_meta" # If not defined, then legacy behavior can happen.
+
+[project]
+
+name = "ocpmodels" # REQUIRED, is the only field that cannot be marked as dynamic.
+version = "0.1.0" # REQUIRED, although can be dynamic
+description = "RolnickLab's OCP fork"
+readme = "README.md"
+requires-python = ">=3.8"
+license = { file = "LICENSE.md" }
+
+
+dependencies = [
+    "ase>=3.19.3",
+    "black>=23.1.0",
+    "CatKit @ git+https://github.com/vict0rsch/CatKit.git@df7f1aa7a47eb7b8022452fa77e4ce60cd006a7d",
+    "dive-into-graphs @ git+https://github.com/divelab/DIG.git@55c40a7b0938d3804d9f265193cb45a2fe80da8c",
+    "e3nn==0.5.1",
+    "flake8>=6.0.0",
+    "h5py>=3.8.0",
+    "lmdb>=1.4.0",
+    "matplotlib>=3.7.0",
+    "mendeleev>=0.12",
+    "minydra==0.1.6",
+    "orjson>=3.8",
+    "pytorch-warmup>=0.1",
+    "pymatgen>=2023.2",
+    "PyYAML>=6.0",
+    "rdkit>=2022.9.5",
+    "rich",
+    "ruamel.yaml",
+    "scikit-learn",
+    "scikit-optimize",
+    "tensorboard",
+    "torch>=1.12",
+    "torch_geometric==2.3.0",
+    "tqdm>=4.66",
+    "wandb",
+]
+
+[project.optional-dependencies]
+geom = [
+    "pyg_lib",
+    "torch_scatter",
+    "torch_sparse",
+    "torch_cluster",
+    "torch_spline_conv",
+]
+dev = ["ipdb", "ipython", "pytest"]
+
 [tool.black]
 line-length = 88
 include = '\.pyi?$'
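Note: with these tables in place the fork installs like any modern package, and the heavyweight graph extensions stay opt-in: `pip install -e .` for the core, or `pip install -e ".[geom,dev]"` to add the PyG companion packages and the development tools. The packages in `geom` typically still need a torch-matching find-links index passed to pip, which a bare pyproject cannot encode, hence keeping the README instructions alongside.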
+version = "0.1.0" # REQUIRED, although can be dynamic +description = "RolnickLab's OCP fork" +readme = "README.md" +requires-python = ">=3.8" +license = { file = "LICENSE.md" } + + +dependencies = [ + "ase>=3.19.3", + "black>=23.1.0", + "CatKit @ git+https://github.com/vict0rsch/CatKit.git@df7f1aa7a47eb7b8022452fa77e4ce60cd006a7d", + "dive-into-graphs @ git+https://github.com/divelab/DIG.git@55c40a7b0938d3804d9f265193cb45a2fe80da8c", + "e3nn==0.5.1", + "flake8>=6.0.0", + "h5py>=3.8.0", + "lmdb>=1.4.0", + "matplotlib>=3.7.0", + "mendeleev>=0.12", + "minydra==0.1.6", + "orjson>=3.8", + "pytorch-warmup>=0.1", + "pymatgen>=2023.2", + "PyYAML>=6.0", + "rdkit>=2022.9.5", + "rich", + "ruamel.yaml", + "scikit-learn", + "scikit-optimize", + "tensorboard", + "torch>=1.12", + "torch_geometric==2.3.0", + "tqdm>=4.66", + "wandb", +] + +[project.optional-dependencies] +geom = [ + "pyg_lib", + "torch_scatter", + "torch_sparse", + "torch_cluster", + "torch_spline_conv", +] +dev = ["ipdb", "ipython", "pytest"] + [tool.black] line-length = 88 include = '\.pyi?$' From ee3d007049ddf5ac71207dd9fb21ae6757f66f4d Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 8 Oct 2024 06:22:35 -0400 Subject: [PATCH 12/13] typo --- ocpmodels/common/gfn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocpmodels/common/gfn.py b/ocpmodels/common/gfn.py index 833b8577d4..cc6b3657a2 100644 --- a/ocpmodels/common/gfn.py +++ b/ocpmodels/common/gfn.py @@ -295,7 +295,7 @@ def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple: config["is_debug"] = True config["silent"] = True config["cp_data_to_tmpdir"] = False - config["prevent_loaders"] = { + config["prevent_load"] = { "logger": True, "loss": True, "datasets": True, From 9127a5a22e4be1811328420a662e9f98c6fc98fd Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 8 Oct 2024 09:42:23 -0400 Subject: [PATCH 13/13] refactor to `skip_modules` --- ocpmodels/common/utils.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index b7a39d391b..bf27099853 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -755,7 +755,16 @@ def add_edge_distance_to_graph( # Copied from https://github.com/facebookresearch/mmf/blob/master/mmf/utils/env.py#L89. -def setup_imports(skip_imports=[]): +def setup_imports(skip_modules=[]): + """Automatically load all of the modules, so that they register within the registry. + + Parameters + ---------- + skip_modules : list, optional + List of modules (as ``str``) to skip while importing, by default []. Use module + names not paths, for instance, to skip ``ocpmodels.models.gemnet_oc.gemnet_oc``, + use ``skip_modules=["gemnet_oc"]``. 
+ """ from ocpmodels.common.registry import registry try: @@ -803,7 +812,7 @@ def setup_imports(skip_imports=[]): splits = f.split(os.sep) file_name = splits[-1] module_name = file_name[: file_name.find(".py")] - if module_name not in skip_imports: + if module_name not in skip_modules: importlib.import_module("ocpmodels.%s.%s" % (key[1:], module_name)) # manual model imports @@ -1191,7 +1200,7 @@ def build_config(args, args_override=[], dict_overrides={}, silent=None): # load config from `model-task-split` pattern config = load_config(args.config) - # overwride with command-line args, including default values + # override with command-line args, including default values config = merge_dicts(config, args_dict_with_defaults) # override with build_config()'s overrides config = merge_dicts(config, overrides) @@ -1801,7 +1810,7 @@ def make_script_trainer(str_args=[], overrides={}, silent=False, mode="train"): return trainer -def make_config_from_dir(path, mode, overrides={}, silent=None, skip_imports=[]): +def make_config_from_dir(path, mode, overrides={}, silent=None, skip_modules=[]): """ Make a config from a directory. This is useful when restarting or continuing from a previous run. @@ -1838,11 +1847,11 @@ def make_config_from_dir(path, mode, overrides={}, silent=None, skip_imports=[]) config = build_config(default_args, silent=silent) config = merge_dicts(config, overrides) - setup_imports(skip_imports=skip_imports) + setup_imports(skip_modules=skip_modules) return config -def make_trainer_from_dir(path, mode, overrides={}, silent=None, skip_imports=[]): +def make_trainer_from_dir(path, mode, overrides={}, silent=None, skip_modules=[]): """ Make a trainer from a directory. @@ -1858,7 +1867,7 @@ def make_trainer_from_dir(path, mode, overrides={}, silent=None, skip_imports=[] Returns: Trainer: The loaded trainer. """ - config = make_config_from_dir(path, mode, overrides, silent, skip_imports) + config = make_config_from_dir(path, mode, overrides, silent, skip_modules) return registry.get_trainer_class(config["trainer"])(**config)