Skip to content

Commit 43dadc4

Browse files
authored
Fix PubChem pretraining (#159)
* allow input_dim and out_dim of 0
* don't pass config to base model
* catch missing labels if no labels exist
* only create classes.txt path for dynamic datasets
* change message for loading properties
1 parent a55f0f3 commit 43dadc4

File tree

5 files changed

+32
-12
lines changed

5 files changed

+32
-12
lines changed

chebai/models/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def __init__(
4848
if exclude_hyperparameter_logging is None:
4949
exclude_hyperparameter_logging = tuple()
5050
self.criterion = criterion
51-
assert out_dim is not None and out_dim > 0, "out_dim must be specified"
52-
assert input_dim is not None and input_dim > 0, "input_dim must be specified"
51+
assert out_dim is not None, "out_dim must be specified"
52+
assert input_dim is not None, "input_dim must be specified"
5353
self.out_dim = out_dim
5454
self.input_dim = input_dim
5555
print(

chebai/models/electra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ class ElectraPre(ChebaiBaseNet):
3939
replace_p (float): Probability of replacing tokens during training.
4040
"""
4141

42-
def __init__(self, config: Dict[str, Any] = None, **kwargs: Any):
43-
super().__init__(config=config, **kwargs)
42+
def __init__(self, config: Dict[str, Any], **kwargs: Any):
43+
super().__init__(**kwargs)
4444

4545
self.generator_config = ElectraConfig(**config["generator"])
4646
self.generator = ElectraForMaskedLM(self.generator_config)

chebai/preprocessing/collate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,11 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
8787
*((d["features"], d["labels"], d.get("ident")) for d in data)
8888
)
8989
missing_labels = [
90-
d.get("missing_labels", [False for _ in y[0]]) for d in data
90+
d.get(
91+
"missing_labels",
92+
[False for _ in y[0]] if y[0] is not None else [False],
93+
)
94+
for d in data
9195
]
9296

9397
if any(x is not None for x in y):

chebai/preprocessing/datasets/base.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -514,11 +514,13 @@ def setup(self, *args, **kwargs) -> None:
514514

515515
rank_zero_info(f"Check for processed data in {self.processed_dir}")
516516
rank_zero_info(f"Cross-validation enabled: {self.use_inner_cross_validation}")
517-
rank_zero_info(f"Looking for files: {self.processed_file_names}")
518517
if any(
519518
not os.path.isfile(os.path.join(self.processed_dir, f))
520519
for f in self.processed_file_names
521520
):
521+
rank_zero_info(
522+
f"Did not find one of: {', '.join(self.processed_file_names)} in {self.processed_dir}"
523+
)
522524
self.setup_processed()
523525

524526
self._after_setup(**kwargs)
@@ -627,17 +629,17 @@ def raw_file_names_dict(self) -> dict:
627629
raise NotImplementedError
628630

629631
@property
630-
def classes_txt_file_path(self) -> str:
632+
def classes_txt_file_path(self) -> Optional[str]:
631633
"""
632-
Returns the filename for the classes text file.
634+
Returns the filename for the classes text file (for labeled datasets that produce a list of labels).
633635
634636
Returns:
635-
str: The filename for the classes text file.
637+
Optional[str]: The filename for the classes text file.
636638
"""
637639
# This property also used in following places:
638640
# - chebai/result/prediction.py: to load class names for csv columns names
639641
# - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path`
640-
return os.path.join(self.processed_dir_main, "classes.txt")
642+
return None
641643

642644

643645
class MergedDataset(XYBaseDataModule):
@@ -1264,3 +1266,16 @@ def processed_file_names_dict(self) -> dict:
12641266
if self.n_token_limit is not None:
12651267
return {"data": f"data_maxlen{self.n_token_limit}.pt"}
12661268
return {"data": "data.pt"}
1269+
1270+
@property
1271+
def classes_txt_file_path(self) -> str:
1272+
"""
1273+
Returns the filename for the classes text file.
1274+
1275+
Returns:
1276+
str: The filename for the classes text file.
1277+
"""
1278+
# This property also used in following places:
1279+
# - chebai/result/prediction.py: to load class names for csv columns names
1280+
# - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path`
1281+
return os.path.join(self.processed_dir_main, "classes.txt")

chebai/preprocessing/datasets/pubchem.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,9 @@ def _set_processed_data_props(self):
195195
self._num_of_labels = 0
196196
self._feature_vector_size = 0
197197

198-
print(f"Number of labels for loaded data: {self._num_of_labels}")
199-
print(f"Feature vector size: {self._feature_vector_size}")
198+
print(
199+
f"Number of labels and feature vector size set to: {self._num_of_labels} / {self._feature_vector_size} (default values, not used for self-supervised learning)"
200+
)
200201

201202
def _perform_data_preparation(self, *args, **kwargs):
202203
"""

0 commit comments

Comments
 (0)