1+ import importlib
12import os .path
23from abc import ABC , abstractmethod
34from typing import Any , Dict , Optional , Tuple , Union
45
56import torch
7+ from lightning .pytorch import LightningModule
68from torch import Tensor
79
8- from chebai .custom_typehints import ModelConfig
9- from chebai .models import ChebaiBaseNet , Electra
10+ from chebai .models import ChebaiBaseNet
1011from chebai .preprocessing .structures import XYData
1112
1213
class _EnsembleBase(ChebaiBaseNet, ABC):
    """Abstract base class for ensembles of pretrained Lightning models.

    For every entry in ``model_configs`` the constructor resolves the model
    class from its dotted ``class_path``, loads the weights from
    ``ckpt_path``, freezes the model, and stores it in ``self.models``.

    Args:
        model_configs: Mapping from model name to its configuration dict.
            Every config must contain at least ``class_path`` (dotted import
            path of the LightningModule subclass) and ``ckpt_path``
            (checkpoint file on disk); subclasses may require more keys via
            ``_extra_validation``.
        **kwargs: Forwarded unchanged to ``ChebaiBaseNet.__init__``.

    Raises:
        FileNotFoundError: If a configured checkpoint path does not exist.
        RuntimeError: If loading a checkpoint fails for any reason other
            than a missing module/class.
    """

    def __init__(self, model_configs: Dict[str, Dict], **kwargs):
        super().__init__(**kwargs)
        self._validate_model_configs(model_configs)

        self.models: Dict[str, LightningModule] = {}
        self.model_configs = model_configs

        for model_name in self.model_configs:
            model_ckpt_path = self.model_configs[model_name]["ckpt_path"]
            model_class_path = self.model_configs[model_name]["class_path"]
            if not os.path.exists(model_ckpt_path):
                raise FileNotFoundError(
                    f"Model path '{model_ckpt_path}' for '{model_name}' does not exist."
                )

            # Split "pkg.module.ClassName" into module path and class name.
            class_name = model_class_path.split(".")[-1]
            module_path = ".".join(model_class_path.split(".")[:-1])

            try:
                module = importlib.import_module(module_path)
                lightning_cls: LightningModule = getattr(module, class_name)

                model = lightning_cls.load_from_checkpoint(model_ckpt_path)
                # Members act as frozen predictors; only the ensemble's own
                # parameters (if any) remain trainable.
                model.eval()
                model.freeze()
                self.models[model_name] = model

            # NOTE(review): a missing module/class is only printed, so that
            # member is silently absent from ``self.models`` afterwards —
            # confirm this best-effort behavior is intended.
            except ModuleNotFoundError:
                print(f"Module '{module_path}' not found!")
            except AttributeError:
                print(f"Class '{class_name}' not found in '{module_path}'!")

            except Exception as e:
                # Chain the original exception so the root cause is kept.
                raise RuntimeError(
                    f"Failed to load model '{model_name}' from {model_ckpt_path}: \n{e}"
                ) from e

        # TODO: Later discuss whether this threshold should be independent of metric threshold or not?
        # if kwargs.get("threshold") is None:
        #     first_metric_key = next(iter(self.train_metrics))  # Get the first key
        #     ...  # (middle of this commented block not visible in source)
        # else:
        #     self.threshold = int(kwargs["threshold"])

    @classmethod
    def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
        """Check that every model config carries the required keys.

        Args:
            model_configs: Mapping from model name to configuration dict.

        Raises:
            AttributeError: If ``class_path`` or ``ckpt_path`` is missing
                from any model's configuration.
        """
        path_set = set()
        class_set = set()

        required_keys = {"class_path", "ckpt_path"}

        for model_name, config in model_configs.items():
            missing_keys = required_keys - config.keys()
            if missing_keys:
                # NOTE(review): the opening line of this raise was hidden by a
                # fold in the reviewed diff; AttributeError matches the sibling
                # _extra_validation — TODO confirm the exception type.
                raise AttributeError(
                    f"Missing keys {missing_keys} in model '{model_name}' configuration."
                )

            model_path = config["ckpt_path"]
            class_path = config["class_path"]

            # if model_path in path_set:
            #     raise ValueError(
            #         f"Duplicate model path detected: '{model_path}'. Each model must have a unique path."
            #     )

            # NOTE(review): this disabled check is inverted — it would trigger
            # for *unseen* class paths; should read "in class_set" if it is
            # ever re-enabled.
            # if class_path not in class_set:
            #     raise ValueError(
            #         f"Duplicate class path detected: '{class_path}'. Each model must have a unique path."
            #     )

            path_set.add(model_path)
            class_set.add(class_path)

            cls._extra_validation(model_name, config)

    @classmethod
    def _extra_validation(cls, model_name: str, config: Dict[str, Any]):
        """Hook for subclasses to validate additional per-model config keys."""
        pass

    @abstractmethod
    def _get_prediction_and_labels(
        self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map a batch, its labels, and the raw model output to (preds, labels)."""
        pass
102+
103+
104+ class ChebiEnsemble (_EnsembleBase ):
105+
106+ NAME = "ChebiEnsemble"
107+
108+ def __init__ (self , model_configs : Dict [str , Dict ], ** kwargs ):
109+ super ().__init__ (model_configs , ** kwargs )
110+
111+ # Add a dummy trainable parameter
112+ self .dummy_param = torch .nn .Parameter (torch .randn (1 , requires_grad = True ))
113+
114+ @classmethod
115+ def _extra_validation (cls , model_name : str , config : Dict [str , Any ]):
116+
117+ if "TPV" not in config .keys () or "FPV" not in config .keys ():
118+ raise AttributeError (
119+ f"Missing keys 'TPV' and/or 'FPV' in model '{ model_name } ' configuration."
120+ )
121+
122+ # Validate 'tpv' and 'fpv' are either floats or convertible to float
123+ for key in ["TPV" , "FPV" ]:
124+ try :
125+ value = float (config [key ])
126+ if value < 0 :
96127 raise ValueError (
97- f"'{ key } ' in model '{ model_name } ' must be a float or convertible to float , but got { config [ key ] } ."
128+ f"'{ key } ' in model '{ model_name } ' must be non-negative , but got { value } ."
98129 )
130+ except (TypeError , ValueError ):
131+ raise ValueError (
132+ f"'{ key } ' in model '{ model_name } ' must be a float or convertible to float, but got { config [key ]} ."
133+ )
99134
100135 def forward (self , data : Dict [str , Tensor ], ** kwargs : Any ) -> Dict [str , Any ]:
101136 predictions = {}
# (removed: "0 commit comments" — GitHub page artifact left over from copying the diff)