Skip to content

Commit 8734dc8

Browse files
authored
fixes for tox21 dataset (#163)
* fixes for tox21 dataset
* ruff fixes
1 parent f910a39 commit 8734dc8

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

chebai/preprocessing/datasets/pubchem.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,7 @@ def _load_dict(input_file_path: str) -> Generator[dict, None, None]:
106106
Yields:
107107
dict: Dictionary containing 'features', 'labels' (None), and 'ident' fields.
108108
"""
109+
# pubchem IDs are here
109110
with open(input_file_path, "r") as input_file:
110111
for row in input_file:
111112
ident, smiles = row.split("\t")

chebai/preprocessing/datasets/tox21.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -161,7 +161,7 @@ def _load_dict(self, input_file_path: str) -> List[Dict]:
161161
for row in reader:
162162
smiles = row["smiles"]
163163
labels = [
164-
bool(int(float(label))) if len(label) > 1 else None
164+
bool(int(float(label))) if len(label) >= 1 else None
165165
for label in (row[k] for k in self.HEADERS)
166166
]
167167
# group = int(row["group"])

chebai/preprocessing/reader.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -98,8 +98,11 @@ def _read_components(self, row: Dict[str, Any]) -> Dict[str, Any]:
9898
under the additional `missing_labels` keyword."""
9999
labels = self._get_raw_label(row)
100100
additional_kwargs = self._get_additional_kwargs(row)
101-
if any(label is None for label in labels):
102-
additional_kwargs["missing_labels"] = [label is None for label in labels]
101+
if labels is not None:
102+
if any(label is None for label in labels):
103+
additional_kwargs["missing_labels"] = [
104+
label is None for label in labels
105+
]
103106
return dict(
104107
features=self._get_raw_data(row),
105108
labels=labels,

0 commit comments

Comments
 (0)