
Commit efac8aa

Merge pull request #138 from ChEB-AI/fix/read_data
Raise error for invalid smiles and return None
2 parents 0a66ef4 + 2ead405 commit efac8aa

File tree

3 files changed: 48 additions, 26 deletions

chebai/preprocessing/reader.py

Lines changed: 14 additions & 11 deletions
```diff
@@ -199,23 +199,26 @@ def _read_data(self, raw_data: str) -> List[int]:
         Returns:
             List[int]: A list of integers representing the indices of the SMILES tokens.
         """
-        if self.canonicalize_smiles:
-            try:
-                mol = Chem.MolFromSmiles(raw_data.strip())
-                if mol is not None:
-                    raw_data = Chem.MolToSmiles(mol, canonical=True)
-            except Exception as e:
-                print(f"RDKit failed to process {raw_data}")
-                print(f"\t{e}")
         try:
-            return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
+            mol = Chem.MolFromSmiles(raw_data.strip())
+            if mol is None:
+                raise ValueError(f"Invalid SMILES: {raw_data}")
         except ValueError as e:
             print(f"could not process {raw_data}")
-            print(f"\t{e}")
+            print(f"\tError: {e}")
             return None
 
-    def _back_to_smiles(self, smiles_encoded):
+        if self.canonicalize_smiles:
+            try:
+                raw_data = Chem.MolToSmiles(mol, canonical=True)
+            except Exception as e:
+                print(f"RDKit failed to canonicalize the SMILES: {raw_data}")
+                print(f"\t{e}")
+                return None
+
+        return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
 
+    def _back_to_smiles(self, smiles_encoded):
         token_file = self.reader.token_path
         token_coding = {}
         counter = 0
```
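The order of operations is the substance of the fix: `Chem.MolFromSmiles` reports an unparsable SMILES by returning `None` rather than raising, so the old code's `if mol is not None` guard silently passed the invalid string through to the tokenizer. The new code validates first, turning `None` into a `ValueError` that produces an explicit `return None`, and only canonicalizes a successfully parsed `Mol`. A minimal standalone sketch of that control flow (the function name and `canonicalize` parameter are illustrative, not the project's API; the RDKit calls are the ones used in the patch):

```python
from typing import Optional

from rdkit import Chem


def read_smiles_sketch(raw_data: str, canonicalize: bool = True) -> Optional[str]:
    """Validate first, canonicalize second; return None on any failure."""
    try:
        mol = Chem.MolFromSmiles(raw_data.strip())  # None for invalid SMILES
        if mol is None:
            raise ValueError(f"Invalid SMILES: {raw_data}")
    except ValueError as e:
        print(f"could not process {raw_data}")
        print(f"\tError: {e}")
        return None

    if canonicalize:
        try:
            raw_data = Chem.MolToSmiles(mol, canonical=True)
        except Exception as e:
            print(f"RDKit failed to canonicalize the SMILES: {raw_data}")
            print(f"\t{e}")
            return None

    return raw_data  # the real method then tokenizes this string into indices


print(read_smiles_sketch("CC(=O)NC1CC1[Mg-2]"))  # a canonical SMILES string
print(read_smiles_sketch("%INVALID%"))           # None
```

Note that `_tokenize` now runs after validation, so it never sees a string RDKit could not parse.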

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -21,7 +21,7 @@ dependencies = [
     "torch",
     "transformers",
     "pysmiles==1.1.2",
-    "rdkit",
+    "rdkit==2024.3.6",
     "lightning==2.5.1",
 ]
 
```
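Pinning the exact RDKit release keeps the behavior above reproducible: parsing strictness and canonical SMILES output can shift between RDKit versions, which would silently change the token stream the reader produces. A quick standard-library check that an environment matches the pin (assumes RDKit was installed as the `rdkit` distribution from PyPI):

```python
from importlib.metadata import version

# pyproject.toml now requires exactly this release.
assert version("rdkit") == "2024.3.6", f"found rdkit {version('rdkit')}"
```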

tests/unit/readers/testChemDataReader.py

Lines changed: 33 additions & 14 deletions
```diff
@@ -42,19 +42,22 @@ def test_read_data(self) -> None:
         """
         Test the _read_data method with a SMILES string to ensure it correctly tokenizes the string.
         """
-        raw_data = "CC(=O)NC1[Mg-2]"
+        raw_data = "CC(=O)NC1CC1[Mg-2]"
         # Expected output as per the tokens already in the cache, and ")" getting added to it.
         expected_output: List[int] = [
             EMBEDDING_OFFSET + 0,  # C
             EMBEDDING_OFFSET + 0,  # C
-            EMBEDDING_OFFSET + 5,  # =
-            EMBEDDING_OFFSET + 3,  # O
-            EMBEDDING_OFFSET + 1,  # N
-            EMBEDDING_OFFSET + len(self.reader.cache),  # (
-            EMBEDDING_OFFSET + 2,  # C
+            EMBEDDING_OFFSET + 5,  # (
+            EMBEDDING_OFFSET + 3,  # =
+            EMBEDDING_OFFSET + 1,  # O
+            EMBEDDING_OFFSET + len(self.reader.cache),  # ) - new token
+            EMBEDDING_OFFSET + 2,  # N
             EMBEDDING_OFFSET + 0,  # C
             EMBEDDING_OFFSET + 4,  # 1
-            EMBEDDING_OFFSET + len(self.reader.cache) + 1,  # [Mg-2]
+            EMBEDDING_OFFSET + 0,  # C
+            EMBEDDING_OFFSET + 0,  # C
+            EMBEDDING_OFFSET + 4,  # 1
+            EMBEDDING_OFFSET + len(self.reader.cache) + 1,  # [Mg-2] - new token
         ]
         result = self.reader._read_data(raw_data)
         self.assertEqual(
@@ -99,13 +102,29 @@ def test_read_data_with_invalid_input(self) -> None:
         Test the _read_data method with an invalid input.
         The invalid token should prompt a return value None
         """
-        raw_data = "%INVALID%"
-
-        result = self.reader._read_data(raw_data)
-        self.assertIsNone(
-            result,
-            "The output for invalid token '%INVALID%' should be None.",
-        )
+        # see https://github.com/ChEB-AI/python-chebai/issues/137
+        raw_datas = ["%INVALID%", "ADADAD", "ADASDAD", "CC(=O)NC1[Mg-2]"]
+        for raw_data in raw_datas:
+            result = self.reader._read_data(raw_data)
+            self.assertIsNone(
+                result,
+                f"The output for invalid token '{raw_data}' should be None.",
+            )
+
+    def test_read_data_with_invalid_input_with_no_canonicalize(self) -> None:
+        """
+        Test the _read_data method with an invalid input.
+        The invalid token should prompt a return value None
+        """
+        self.reader.canonicalize_smiles = False
+        raw_datas = ["%INVALID%", "ADADAD", "ADASDAD", "CC(=O)NC1[Mg-2]"]
+        for raw_data in raw_datas:
+            result = self.reader._read_data(raw_data)
+            self.assertIsNone(
+                result,
+                f"The output for invalid token '{raw_data}' should be None.",
+            )
+        self.reader.canonicalize_smiles = True  # Reset to original state
 
     @patch("builtins.open", new_callable=mock_open)
     def test_finish_method_for_new_tokens(self, mock_file: mock_open) -> None:
```
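Note why `CC(=O)NC1[Mg-2]` moved from the tokenization test into the invalid-input list: its ring-bond label `1` is opened after `NC` but never closed, so RDKit cannot parse it, while the replacement `CC(=O)NC1CC1[Mg-2]` closes the ring. A two-line check of that distinction (assumes RDKit is installed):

```python
from rdkit import Chem

print(Chem.MolFromSmiles("CC(=O)NC1[Mg-2]"))     # None: ring bond '1' never closes
print(Chem.MolFromSmiles("CC(=O)NC1CC1[Mg-2]"))  # a Mol object: ring closed via CC1
```

The second test repeats the loop with `canonicalize_smiles` disabled, confirming that invalid SMILES now return `None` on the validation path alone rather than as a side effect of canonicalization.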
