Skip to content

Commit b32e6c5

Browse files
authored
Merge pull request #141 from ChEB-AI/fix/file_not_found_for_loss
BCE Loss unable to locate processed files
2 parents c9c08dc + a5ea56a commit b32e6c5

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,9 +1173,7 @@ def _retrieve_splits_from_csv(self) -> None:
11731173
splits_df = pd.read_csv(self.splits_file_path)
11741174

11751175
filename = self.processed_file_names_dict["data"]
1176-
data = self.load_processed_data_from_file(
1177-
os.path.join(self.processed_dir, filename)
1178-
)
1176+
data = self.load_processed_data_from_file(filename)
11791177
df_data = pd.DataFrame(data)
11801178

11811179
if self.apply_id_filter:
@@ -1254,8 +1252,23 @@ def load_processed_data(
12541252
# If filename is provided
12551253
return self.load_processed_data_from_file(filename)
12561254

1257-
def load_processed_data_from_file(self, filename):
1258-
return torch.load(os.path.join(filename), weights_only=False)
1255+
def load_processed_data_from_file(self, filename: str) -> list[dict[str, Any]]:
1256+
"""Load processed data from a file.
1257+
1258+
Pass only the filename, not a full path: this method joins it with `self.processed_dir` before loading.
1259+
1260+
Args:
1261+
filename (str): The name of the file to load the processed data from.
1262+
1263+
Returns:
1264+
list[dict[str, Any]]: The loaded processed data.
1265+
1266+
Example:
1267+
data = self.load_processed_data_from_file('data.pt')
1268+
"""
1269+
return torch.load(
1270+
os.path.join(self.processed_dir, filename), weights_only=False
1271+
)
12591272

12601273
# ------------------------------ Phase: Raw Properties -----------------------------------
12611274
@property

chebai/preprocessing/datasets/chebi.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -516,9 +516,7 @@ def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
516516
"""
517517
try:
518518
filename = self.processed_file_names_dict["data"]
519-
data_chebi_version = self.load_processed_data_from_file(
520-
os.path.join(self.processed_dir, filename)
521-
)
519+
data_chebi_version = self.load_processed_data_from_file(filename)
522520
except FileNotFoundError:
523521
raise FileNotFoundError(
524522
"File data.pt doesn't exists. "

0 commit comments

Comments
 (0)