Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions chebai/preprocessing/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,17 +205,19 @@ def _read_data(self, raw_data: str) -> List[int]:
if mol is not None:
raw_data = Chem.MolToSmiles(mol, canonical=True)
except Exception as e:
print(f"RDKit failed to process {raw_data}")
print(f"RDKit failed to canonicalize the SMILES: {raw_data}")
print(f"\t{e}")
try:
mol = Chem.MolFromSmiles(raw_data.strip())
Comment thread
aditya0by0 marked this conversation as resolved.
if mol is None:
raise ValueError(f"Invalid SMILES: {raw_data}")
return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
except ValueError as e:
print(f"could not process {raw_data}")
print(f"\t{e}")
print(f"\tError: {e}")
return None

def _back_to_smiles(self, smiles_encoded):

token_file = self.reader.token_path
token_coding = {}
counter = 0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"torch",
"transformers",
"pysmiles==1.1.2",
"rdkit",
"rdkit==2024.3.6",
"lightning==2.5.1",
]

Expand Down
32 changes: 18 additions & 14 deletions tests/unit/readers/testChemDataReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,22 @@ def test_read_data(self) -> None:
"""
Test the _read_data method with a SMILES string to ensure it correctly tokenizes the string.
"""
raw_data = "CC(=O)NC1[Mg-2]"
raw_data = "CC(=O)NC1CC1[Mg-2]"
# Expected output as per the tokens already in the cache, and ")" getting added to it.
expected_output: List[int] = [
EMBEDDING_OFFSET + 0, # C
EMBEDDING_OFFSET + 0, # C
EMBEDDING_OFFSET + 5, # =
EMBEDDING_OFFSET + 3, # O
EMBEDDING_OFFSET + 1, # N
EMBEDDING_OFFSET + len(self.reader.cache), # (
EMBEDDING_OFFSET + 2, # C
EMBEDDING_OFFSET + 5, # (
EMBEDDING_OFFSET + 3, # =
EMBEDDING_OFFSET + 1, # O
EMBEDDING_OFFSET + len(self.reader.cache), # ) - new token
EMBEDDING_OFFSET + 2, # N
EMBEDDING_OFFSET + 0, # C
EMBEDDING_OFFSET + 4, # 1
EMBEDDING_OFFSET + len(self.reader.cache) + 1, # [Mg-2]
EMBEDDING_OFFSET + 0, # C
EMBEDDING_OFFSET + 0, # C
EMBEDDING_OFFSET + 4, # 1
EMBEDDING_OFFSET + len(self.reader.cache) + 1, # [Mg-2] - new token
]
result = self.reader._read_data(raw_data)
self.assertEqual(
Expand Down Expand Up @@ -99,13 +102,14 @@ def test_read_data_with_invalid_input(self) -> None:
Test the _read_data method with an invalid input.
The invalid token should prompt a return value None
"""
raw_data = "%INVALID%"

result = self.reader._read_data(raw_data)
self.assertIsNone(
result,
"The output for invalid token '%INVALID%' should be None.",
)
# see https://github.com/ChEB-AI/python-chebai/issues/137
raw_datas = ["%INVALID%", "ADADAD", "ADASDAD", "CC(=O)NC1[Mg-2]"]
for raw_data in raw_datas:
result = self.reader._read_data(raw_data)
self.assertIsNone(
result,
f"The output for invalid token '{raw_data}' should be None.",
)

@patch("builtins.open", new_callable=mock_open)
def test_finish_method_for_new_tokens(self, mock_file: mock_open) -> None:
Expand Down