Skip to content

Commit 72a6b37

Browse files
committed
ensemble: add class path to config and load model via this class
1 parent f60b2d8 commit 72a6b37

File tree

3 files changed

+81
-56
lines changed

3 files changed

+81
-56
lines changed

chebai/custom_typehints/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

chebai/custom_typehints/model.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

chebai/models/ensemble.py

Lines changed: 81 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -1,42 +1,54 @@
1+
import importlib
12
import os.path
23
from abc import ABC, abstractmethod
34
from typing import Any, Dict, Optional, Tuple, Union
45

56
import torch
7+
from lightning.pytorch import LightningModule
68
from torch import Tensor
79

8-
from chebai.custom_typehints import ModelConfig
9-
from chebai.models import ChebaiBaseNet, Electra
10+
from chebai.models import ChebaiBaseNet
1011
from chebai.preprocessing.structures import XYData
1112

1213

1314
class _EnsembleBase(ChebaiBaseNet, ABC):
14-
def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
15+
def __init__(self, model_configs: Dict[str, Dict], **kwargs):
1516
super().__init__(**kwargs)
17+
self._validate_model_configs(model_configs)
1618

17-
self.models: Dict[str, ChebaiBaseNet] = {}
18-
self.model_configs: Dict[str, ModelConfig] = model_configs
19+
self.models: Dict[str, LightningModule] = {}
20+
self.model_configs = model_configs
1921

2022
for model_name in self.model_configs:
21-
model_path = self.model_configs[model_name]["path"]
22-
if not os.path.exists(model_path):
23+
model_ckpt_path = self.model_configs[model_name]["ckpt_path"]
24+
model_class_path = self.model_configs[model_name]["class_path"]
25+
if not os.path.exists(model_ckpt_path):
2326
raise FileNotFoundError(
24-
f"Model path '{model_path}' for '{model_name}' does not exist."
27+
f"Model path '{model_ckpt_path}' for '{model_name}' does not exist."
2528
)
2629

27-
# Attempt to load the model to check validity
30+
class_name = model_class_path.split(".")[-1]
31+
module_path = ".".join(model_class_path.split(".")[:-1])
32+
2833
try:
29-
self.models[model_name] = Electra.load_from_checkpoint(
30-
model_path, map_location=self.device
31-
)
34+
module = importlib.import_module(module_path)
35+
lightning_cls: LightningModule = getattr(module, class_name)
36+
37+
model = lightning_cls.load_from_checkpoint(model_ckpt_path)
38+
model.eval()
39+
model.freeze()
40+
self.models[model_name] = model
41+
42+
except ModuleNotFoundError:
43+
print(f"Module '{module_path}' not found!")
44+
except AttributeError:
45+
print(f"Class '{class_name}' not found in '{module_path}'!")
46+
3247
except Exception as e:
3348
raise RuntimeError(
34-
f"Failed to load model '{model_name}' from {model_path}: {e}"
49+
f"Failed to load model '{model_name}' from {model_ckpt_path}: \n {e}"
3550
)
3651

37-
for model in self.models.values():
38-
model.freeze()
39-
4052
# TODO: Later discuss whether this threshold should be independent of metric threshold or not ?
4153
# if kwargs.get("threshold") is None:
4254
# first_metric_key = next(iter(self.train_metrics)) # Get the first key
@@ -45,27 +57,12 @@ def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
4557
# else:
4658
# self.threshold = int(kwargs["threshold"])
4759

48-
@abstractmethod
49-
def _get_prediction_and_labels(
50-
self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
51-
) -> (torch.Tensor, torch.Tensor):
52-
pass
53-
54-
55-
class ChebiEnsemble(_EnsembleBase):
56-
57-
NAME = "ChebiEnsemble"
58-
59-
def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
60-
self._validate_model_configs(model_configs)
61-
super().__init__(model_configs, **kwargs)
62-
# Add a dummy trainable parameter
63-
self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True))
64-
6560
@classmethod
66-
def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]):
61+
def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
6762
path_set = set()
68-
required_keys = {"path", "TPV", "FPV"}
63+
class_set = set()
64+
65+
required_keys = {"class_path", "ckpt_path"}
6966

7067
for model_name, config in model_configs.items():
7168
missing_keys = required_keys - config.keys()
@@ -75,27 +72,65 @@ def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]):
7572
f"Missing keys {missing_keys} in model '{model_name}' configuration."
7673
)
7774

78-
model_path = config["path"]
75+
model_path = config["ckpt_path"]
76+
class_path = config["class_path"]
7977

8078
# if model_path in path_set:
8179
# raise ValueError(
8280
# f"Duplicate model path detected: '{model_path}'. Each model must have a unique path."
8381
# )
8482

83+
# if class_path not in class_set:
84+
# raise ValueError(
85+
# f"Duplicate class path detected: '{class_path}'. Each model must have a unique path."
86+
# )
87+
8588
path_set.add(model_path)
89+
class_set.add(class_path)
8690

87-
# Validate 'tpv' and 'fpv' are either floats or convertible to float
88-
for key in ["TPV", "FPV"]:
89-
try:
90-
value = float(config[key])
91-
if value < 0:
92-
raise ValueError(
93-
f"'{key}' in model '{model_name}' must be non-negative, but got {value}."
94-
)
95-
except (TypeError, ValueError):
91+
cls._extra_validation(model_name, config)
92+
93+
@classmethod
def _extra_validation(cls, model_name: str, config: Dict[str, Any]):
    """Subclass hook for additional per-model config validation.

    Called from ``_validate_model_configs`` for every model entry after the
    shared required-key checks. The base implementation performs no extra
    checks and accepts any config; subclasses (e.g. ``ChebiEnsemble``)
    override it to enforce their own required keys.
    """
    # Intentionally a no-op in the base ensemble class.
    return None
96+
97+
@abstractmethod
def _get_prediction_and_labels(
    self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Map a raw model output and labels to (predictions, labels).

    Abstract: every concrete ensemble must implement this.

    Args:
        data: Batch data dict as produced by the data pipeline
            (exact schema depends on the caller — not visible here).
        labels: Ground-truth label tensor for the batch.
        output: Raw output of the ensemble's forward pass.

    Returns:
        A ``(predictions, labels)`` pair of tensors suitable for metric
        computation.
    """
    # NOTE(review): annotation changed from the tuple literal
    # `(torch.Tensor, torch.Tensor)` to `Tuple[...]` — a tuple literal is
    # not a valid type hint; `Tuple` is already imported at the top of the
    # file.
    pass
102+
103+
104+
class ChebiEnsemble(_EnsembleBase):
105+
106+
NAME = "ChebiEnsemble"
107+
108+
def __init__(self, model_configs: Dict[str, Dict], **kwargs):
    """Build the ensemble from per-model configurations.

    Validation and checkpoint loading of the member models happen in the
    base class; this subclass only registers one extra trainable parameter
    (presumably so the otherwise-frozen ensemble still exposes a trainable
    parameter — TODO confirm with the training setup).
    """
    super().__init__(model_configs, **kwargs)

    # Single trainable scalar; the wrapped models themselves are frozen
    # by the base-class loader.
    dummy = torch.randn(1, requires_grad=True)
    self.dummy_param = torch.nn.Parameter(dummy)
113+
114+
@classmethod
def _extra_validation(cls, model_name: str, config: Dict[str, Any]):
    """Validate the ChebiEnsemble-specific keys of one model config.

    Args:
        model_name: Name of the model entry (used in error messages only).
        config: Per-model configuration; must contain "TPV" and "FPV".

    Raises:
        AttributeError: If "TPV" and/or "FPV" is missing from ``config``.
        ValueError: If a value is not convertible to float, or is negative.
    """
    if "TPV" not in config or "FPV" not in config:
        raise AttributeError(
            f"Missing keys 'TPV' and/or 'FPV' in model '{model_name}' configuration."
        )

    # Validate 'TPV' and 'FPV' are non-negative floats (or convertible).
    for key in ["TPV", "FPV"]:
        try:
            value = float(config[key])
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"'{key}' in model '{model_name}' must be a float or convertible to float, but got {config[key]}."
            ) from err
        # Bug fix: previously this check lived *inside* the try block, so the
        # ValueError it raised was immediately caught by the except clause
        # above and re-raised with the misleading "must be a float or
        # convertible" message — the "must be non-negative" message was
        # unreachable. Checking after the try keeps the intended message.
        if value < 0:
            raise ValueError(
                f"'{key}' in model '{model_name}' must be non-negative, but got {value}."
            )
99134

100135
def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
101136
predictions = {}

0 commit comments

Comments
 (0)