-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample.py
More file actions
257 lines (213 loc) · 11.1 KB
/
example.py
File metadata and controls
257 lines (213 loc) · 11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import json
import logging
from abc import ABC, abstractmethod
from collections import Counter

from rdkit import Chem
from rdkit.Chem import AllChem

from constants import GROUP_TO_SMARTS, GROUP_TO_ADD_SMILES
# Module-level logger; BaseTool.__call__ emits start/end debug messages through it.
logger = logging.getLogger(__name__)


class BaseTool(ABC):
    """Abstract base for callable tools exposed through a 'text' or 'code' interface.

    Subclasses set the descriptive class attributes below and implement
    ``_run_base``. Calling an instance dispatches on ``self.interface``:

    * ``'text'``: only the first positional argument is forwarded to
      ``run_text``; the result of ``_run_base`` is converted with ``str``.
    * ``'code'``: all positional and keyword arguments are forwarded to
      ``run_code`` and the raw result of ``_run_base`` is returned.
    """

    # Interfaces accepted by __init__ and dispatched by __call__.
    _SUPPORTED_INTERFACES = ('text', 'code')

    name: str
    func_name: str
    description: str
    func_doc: tuple
    func_description: str
    input_output_description: str
    examples: list

    def __init__(self, init=True, interface='text') -> None:
        """Create the tool.

        Args:
            init: when truthy, call ``self._init_modules()`` after setup.
            interface: ``'text'`` or ``'code'``; selects how ``__call__`` dispatches.

        Raises:
            ValueError: if ``interface`` is not supported.
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if interface not in self._SUPPORTED_INTERFACES:
            raise ValueError("Interface '%s' is not supported. Please use 'text' or 'code'." % interface)
        self.interface = interface
        super().__init__()
        if init:
            self._init_modules()

    def _init_modules(self):
        """Hook for subclasses to set up resources; the base implementation does nothing."""

    def __call__(self, *args, **kwargs):
        """Run the tool through the configured interface and return its result."""
        # Fall back to the class name so subclasses that omit ``name`` still work.
        tool_name = getattr(type(self), 'name', type(self).__name__)
        logger.debug("===== Starting tool %s =====", tool_name)
        if self.interface == 'text':
            # Text mode takes a single query; extra arguments are ignored.
            r = self.run_text(args[0])
        elif self.interface == 'code':
            r = self.run_code(*args, **kwargs)
        else:
            # Only reachable if ``self.interface`` was changed after construction.
            raise NotImplementedError("Interface '%s' is not supported. Please use 'text' or 'code'." % self.interface)
        logger.debug("----- Ending tool %s -----", tool_name)
        return r

    def run_text(self, query, *args, **kwargs):
        """Public text-mode entry point; delegates to ``_run_text``."""
        return self._run_text(query, *args, **kwargs)

    def run_code(self, *args, **kwargs):
        """Public code-mode entry point; delegates to ``_run_code``."""
        return self._run_code(*args, **kwargs)

    def _run_text(self, query, *args, **kwargs):
        """Run ``_run_base`` and return its result converted to ``str``."""
        return str(self._run_base(query, *args, **kwargs))

    def _run_code(self, *args, **kwargs):
        """Run ``_run_base`` and return its result unchanged."""
        return self._run_base(*args, **kwargs)

    @abstractmethod
    def _run_base(self, *args, **kwargs):
        """Core tool logic; must be implemented by subclasses."""
        raise NotImplementedError

    def run(self, query, *args, **kwargs):
        """Deprecated entry point; always raises ``DeprecationWarning``."""
        raise DeprecationWarning("The run function is deprecated. Please modify the implementation.")
class CountMolAtoms(BaseTool):
    """Report the total atom count of a SMILES molecule and a per-element breakdown."""

    name = "CountMolAtoms"
    func_name = 'count_molecule_atoms'
    description = "Count the number of atoms in a molecule. Input SMILES, returns the types of atoms and their numbers."
    func_doc = ("smiles: str", "str")
    func_description = description
    input_output_description = "input: SMILES, output: the atom number and atom type in this SMILES."
    examples = [
        {'input': 'CCO', 'output': 'There are altogether 3 atoms (omitting hydrogen atoms). The types and corresponding numbers are: {"C": 2, "O": 1}'},
    ]

    def _run_base(self, smiles: str, *args, **kwargs) -> str:
        """Count the atoms in ``smiles``.

        Args:
            smiles: molecule in SMILES notation. It is parsed without
                sanitization, so only unparsable strings are rejected.

        Returns:
            A sentence with the total atom count and a JSON object mapping
            element symbol to count, ordered by first appearance, or an
            ``"Error: ..."`` string if the SMILES cannot be parsed.
        """
        mol = Chem.MolFromSmiles(smiles, sanitize=False)
        if mol is None:
            return "Error: Invalid SMILES string"
        # Only hydrogens recorded as explicit H counts (e.g. on bracket atoms)
        # become atoms here; implicit hydrogens are not added.
        # NOTE(review): such explicit hydrogens are included in the count even
        # though the message says hydrogens are omitted — confirm intent.
        mol = Chem.rdmolops.AddHs(mol, explicitOnly=True)
        num_atoms = mol.GetNumAtoms()
        # Counter keeps first-occurrence order, so the output is deterministic;
        # iterating a set() made the order depend on the interpreter's hash seed.
        atom_type_counts = Counter(atom.GetSymbol() for atom in mol.GetAtoms())
        # json.dumps matches the documented example format ({"C": 2, "O": 1}).
        text = "There are altogether %d atoms (omitting hydrogen atoms). The types and corresponding numbers are: %s" % (num_atoms, json.dumps(dict(atom_type_counts)))
        return text
class ReplaceFunctionalGroup(BaseTool):
    """Replace every match of a named functional group with another named group."""

    name = "ReplaceFunctionalGroup"
    func_name = 'replace_functional_group'
    description = "Replace a functional group in a molecule with another functional group using functional group names from constants.py."
    func_doc = ("base_smiles: str, old_group_name: str, new_group_name: str", "str")
    func_description = description
    input_output_description = "input: base molecule SMILES, old group name, new group name, output: modified molecule SMILES"
    examples = [
        {'input': 'CCO hydroxyl primary_amine', 'output': 'CCN'},
        {'input': 'CC=O aldehyde carboxyl', 'output': 'CC(=O)O'},
    ]

    # Fallback SMILES for common simple substituents, consulted only when the
    # new group name is not a key of GROUP_TO_ADD_SMILES.
    _SIMPLE_GROUPS = {
        'hydrogen': '[H]',
        'methyl': 'C',
        'ethyl': 'CC',
        'hydroxyl': 'O',
        'primary_amine': 'N',
        'chloro': 'Cl',
        'bromo': 'Br',
        'fluoro': 'F',
        'iodo': 'I'
    }

    def _run_text(self, query, *args, **kwargs):
        """Parse a whitespace-separated '<base_smiles> <old_group_name> <new_group_name>' query.

        Text mode passes a single string, which cannot satisfy the three-argument
        ``_run_base`` directly; other call shapes are delegated unchanged.
        """
        if args or kwargs or not isinstance(query, str):
            return super()._run_text(query, *args, **kwargs)
        parts = query.split()
        if len(parts) != 3:
            return f"Error: Expected input '<base_smiles> <old_group_name> <new_group_name>', got: {query}"
        return super()._run_text(*parts)

    def _run_base(self, base_smiles: str, old_group_name: str, new_group_name: str, *args, **kwargs) -> str:
        """Replace all occurrences of ``old_group_name`` in ``base_smiles`` with ``new_group_name``.

        Args:
            base_smiles: molecule to modify.
            old_group_name: key into GROUP_TO_SMARTS identifying the group to replace.
            new_group_name: key into GROUP_TO_ADD_SMILES (with '*' attachment
                markers stripped) or into the built-in simple-group table.

        Returns:
            Canonical SMILES of the modified molecule, a ``"Warning: ..."``
            string if the old group is absent, or an ``"Error: ..."`` string.
        """
        mol = Chem.MolFromSmiles(base_smiles)
        if not mol:
            return "Error: Invalid base molecule SMILES"

        # Look up the SMARTS pattern for the group being replaced.
        if old_group_name not in GROUP_TO_SMARTS:
            available_groups = ", ".join(list(GROUP_TO_SMARTS.keys())[:10]) + "..."
            return f"Error: Unknown old group name '{old_group_name}'. Available groups: {available_groups}"
        old_group_smarts = GROUP_TO_SMARTS[old_group_name]
        pattern = Chem.MolFromSmarts(old_group_smarts)
        if not pattern:
            return f"Error: Invalid SMARTS pattern for '{old_group_name}': {old_group_smarts}"

        # Resolve the replacement SMILES: prefer GROUP_TO_ADD_SMILES (dropping '*'),
        # then fall back to the simple-group table.
        if new_group_name in GROUP_TO_ADD_SMILES:
            new_group_smiles = GROUP_TO_ADD_SMILES[new_group_name].replace('*', '')
        elif new_group_name in self._SIMPLE_GROUPS:
            new_group_smiles = self._SIMPLE_GROUPS[new_group_name]
        else:
            available_groups = ", ".join(list(GROUP_TO_ADD_SMILES.keys())[:10]) + "..."
            return f"Error: Unknown new group name '{new_group_name}'. Available groups: {available_groups}"
        replacement = Chem.MolFromSmiles(new_group_smiles)
        if not replacement:
            return f"Error: Invalid SMILES for '{new_group_name}': {new_group_smiles}"

        # ReplaceSubstructs returns the unchanged molecule (never an empty tuple)
        # when nothing matches, so check for a match explicitly.
        if not mol.HasSubstructMatch(pattern):
            return f"Warning: No '{old_group_name}' functional group found in molecule, returning original: {base_smiles}"

        # With replaceAll=True the result is a one-element tuple.
        result_mol = AllChem.ReplaceSubstructs(mol, pattern, replacement, replaceAll=True)[0]
        try:
            Chem.SanitizeMol(result_mol)
            return Chem.MolToSmiles(result_mol)
        except Exception:
            return "Error: Failed to sanitize resulting molecule"
class RemoveFunctionalGroup(BaseTool):
    """Delete every match of a named functional group from a molecule."""

    name = "RemoveFunctionalGroup"
    func_name = 'remove_functional_group'
    description = "Remove a specific functional group from a molecule using functional group names from constants.py."
    func_doc = ("base_smiles: str, group_name: str", "str")
    func_description = description
    input_output_description = "input: base molecule SMILES, group name to remove, output: modified molecule SMILES"
    examples = [
        {'input': 'CCO hydroxyl', 'output': 'CC'},
        {'input': 'CCCl halo', 'output': 'CC'},
    ]

    def _run_text(self, query, *args, **kwargs):
        """Parse a whitespace-separated '<base_smiles> <group_name>' query.

        Text mode passes a single string, which cannot satisfy the two-argument
        ``_run_base`` directly; other call shapes are delegated unchanged.
        """
        if args or kwargs or not isinstance(query, str):
            return super()._run_text(query, *args, **kwargs)
        parts = query.split()
        if len(parts) != 2:
            return f"Error: Expected input '<base_smiles> <group_name>', got: {query}"
        return super()._run_text(*parts)

    def _run_base(self, base_smiles: str, group_name: str, *args, **kwargs) -> str:
        """Remove all atoms matched by ``group_name`` from ``base_smiles``.

        Args:
            base_smiles: molecule to modify.
            group_name: key into GROUP_TO_SMARTS identifying the group to delete.

        Returns:
            Canonical SMILES of the remaining molecule, a ``"Warning: ..."``
            string if the group is absent, or an ``"Error: ..."`` string.
        """
        mol = Chem.MolFromSmiles(base_smiles)
        if not mol:
            return "Error: Invalid base molecule SMILES"

        # Look up the SMARTS pattern for the functional group.
        if group_name not in GROUP_TO_SMARTS:
            available_groups = ", ".join(list(GROUP_TO_SMARTS.keys())[:10]) + "..."
            return f"Error: Unknown group name '{group_name}'. Available groups: {available_groups}"
        group_smarts = GROUP_TO_SMARTS[group_name]
        pattern = Chem.MolFromSmarts(group_smarts)
        if not pattern:
            return f"Error: Invalid SMARTS pattern for '{group_name}': {group_smarts}"

        # Detect absence by substructure match rather than comparing canonical
        # output with the (possibly non-canonical) input string.
        if not mol.HasSubstructMatch(pattern):
            return f"Warning: No '{group_name}' functional group found in molecule, returning original: {base_smiles}"

        # Delete the matched atoms; onlyFrags=False removes matches even when
        # they are attached to the rest of the molecule.
        try:
            result_mol = AllChem.DeleteSubstructs(mol, pattern, onlyFrags=False)
            Chem.SanitizeMol(result_mol)
            return Chem.MolToSmiles(result_mol)
        except Exception:
            return "Error: Failed to remove functional group or sanitize resulting molecule"
class AddFunctionalGroup(BaseTool):
    """Attach a named functional group to a given atom in place of one of its hydrogens."""

    name = "AddFunctionalGroup"
    func_name = 'add_functional_group'
    description = "Add a functional group to a molecule at a specified atom index by replacing a hydrogen atom using functional group names from constants.py."
    func_doc = ("base_smiles: str, group_name: str, atom_index_to_attach: int", "str")
    func_description = description
    input_output_description = "input: base molecule SMILES, group name, atom index, output: modified molecule SMILES"
    examples = [
        {'input': 'CCC carboxyl 0', 'output': 'CCCC(=O)O'},
        {'input': 'CCC hydroxyl 2', 'output': 'CCCO'},
    ]

    def _run_text(self, query, *args, **kwargs):
        """Parse a whitespace-separated '<base_smiles> <group_name> <atom_index>' query.

        Text mode passes a single string, which cannot satisfy the three-argument
        ``_run_base`` directly; other call shapes are delegated unchanged. The
        index stays a string here and is converted by ``_run_base``.
        """
        if args or kwargs or not isinstance(query, str):
            return super()._run_text(query, *args, **kwargs)
        parts = query.split()
        if len(parts) != 3:
            return f"Error: Expected input '<base_smiles> <group_name> <atom_index_to_attach>', got: {query}"
        return super()._run_text(*parts)

    def _run_base(self, base_smiles: str, group_name: str, atom_index_to_attach: int, *args, **kwargs) -> str:
        """Bond ``group_name`` to atom ``atom_index_to_attach`` of ``base_smiles``.

        Args:
            base_smiles: molecule to modify.
            group_name: key into GROUP_TO_ADD_SMILES. After '*' markers are
                stripped, the group's first atom becomes the attachment point.
            atom_index_to_attach: 0-based atom index; ints and int-like
                strings are accepted.

        Returns:
            Canonical SMILES of the new molecule or an ``"Error: ..."`` string.
        """
        mol = Chem.MolFromSmiles(base_smiles)
        if not mol:
            return "Error: Invalid base molecule SMILES"

        try:
            atom_index_to_attach = int(atom_index_to_attach)
        except (TypeError, ValueError):
            return f"Error: Invalid atom index '{atom_index_to_attach}'"
        num_atoms = mol.GetNumAtoms()
        # Reject negative indices too; RDKit does not wrap them like Python lists.
        if not 0 <= atom_index_to_attach < num_atoms:
            return f"Error: Atom index {atom_index_to_attach} out of range (molecule has {num_atoms} atoms)"

        # Look up the SMILES of the group to add.
        if group_name not in GROUP_TO_ADD_SMILES:
            available_groups = ", ".join(list(GROUP_TO_ADD_SMILES.keys())[:10]) + "..."
            return f"Error: Unknown group name '{group_name}'. Available groups: {available_groups}"
        group_to_add_smiles = GROUP_TO_ADD_SMILES[group_name]

        # The new single bond replaces a hydrogen, so the target must have one.
        target_atom = mol.GetAtomWithIdx(atom_index_to_attach)
        if target_atom.GetTotalNumHs() == 0:
            return f"Error: Atom at index {atom_index_to_attach} has no hydrogen atoms to replace"

        # Parse the group with the attachment-point marker (*) removed.
        group_smiles_clean = group_to_add_smiles.replace('*', '')
        group_mol = Chem.MolFromSmiles(group_smiles_clean)
        if not group_mol:
            return f"Error: Invalid SMILES for functional group '{group_name}': {group_smiles_clean}"

        # RWMol allows editing existing atoms as well as adding atoms and bonds.
        rw_mol = Chem.RWMol(mol)
        # Copy the group's atoms in; map group atom index -> new index in rw_mol.
        atom_index_map = {atom.GetIdx(): rw_mol.AddAtom(atom) for atom in group_mol.GetAtoms()}
        # Recreate the group's internal bonds.
        for bond in group_mol.GetBonds():
            rw_mol.AddBond(
                atom_index_map[bond.GetBeginAtomIdx()],
                atom_index_map[bond.GetEndAtomIdx()],
                bond.GetBondType(),
            )

        if atom_index_map:
            # Connect the group's first atom to the target atom.
            rw_mol.AddBond(atom_index_to_attach, atom_index_map[0], Chem.BondType.SINGLE)
            # Atoms written with explicit H counts (e.g. [nH], [C@@H]) keep that
            # count through sanitization. Consume one explicit H for the new bond,
            # otherwise the atom would be over-valent.
            # NOTE(review): stereochemistry at the anchor atom is not adjusted.
            anchor = rw_mol.GetAtomWithIdx(atom_index_to_attach)
            explicit_hs = anchor.GetNumExplicitHs()
            if explicit_hs > 0:
                anchor.SetNumExplicitHs(explicit_hs - 1)

        try:
            result_mol = rw_mol.GetMol()
            Chem.SanitizeMol(result_mol)
            return Chem.MolToSmiles(result_mol)
        except Exception as e:
            return f"Error: Failed to create or sanitize resulting molecule: {str(e)}"