Skip to content

Commit 5294e3e

Browse files
committed
Clenaup tool_generation, add ToolConfig.module_name, add tests for tool_generation frontend
1 parent 14a0cf0 commit 5294e3e

5 files changed

Lines changed: 122 additions & 64 deletions

File tree

Lines changed: 44 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,17 @@
11
import os, sys
2-
from typing import Optional, Any, List
3-
import importlib.resources
2+
from typing import Optional, Union, Any
43
from pathlib import Path
5-
import json
6-
import sys
7-
from typing import Optional, List, Dict, Union
8-
9-
from . import get_agent_names
10-
from .gen_utils import insert_code_after_tag, string_in_file
11-
from ..utils import open_json_file, get_framework, term_color
12-
import os
134
import shutil
145
import fileinput
156
import ast
167

8+
from agentstack import frameworks
179
from agentstack import packaging
1810
from agentstack import ValidationError
19-
from agentstack.utils import get_package_path
11+
from agentstack.utils import term_color
2012
from agentstack.tools import ToolConfig
2113
from agentstack.generation import astools
2214
from agentstack.generation.files import ConfigFile, EnvFile
23-
from agentstack import frameworks
24-
from .gen_utils import insert_code_after_tag, string_in_file
25-
from ..utils import open_json_file, get_framework, term_color
2615

2716

2817
# This is the filename of the location of tool imports in the user's project.
@@ -44,7 +33,7 @@ def get_import_for_tool(self, tool: ToolConfig) -> ast.Import:
4433
raises a ValidationError if the tool is imported multiple times.
4534
"""
4635
all_imports = astools.get_all_imports(self.tree)
47-
tool_imports = [i for i in all_imports if tool.name in i.names[0].name]
36+
tool_imports = [i for i in all_imports if tool.module_name == i.module]
4837

4938
if len(tool_imports) > 1:
5039
raise ValidationError(f"Multiple imports for tool {tool.name} found in {self.filename}")
@@ -77,106 +66,100 @@ def remove_import_for_tool(self, framework: str, tool: ToolConfig):
7766
Remove an import for a tool.
7867
raises a ValidationError if the tool is not imported.
7968
"""
80-
tool_import = self.get_imports_for_tool(tool)
69+
tool_import = self.get_import_for_tool(tool)
8170
if not tool_import:
8271
raise ValidationError(f"Tool {tool.name} not imported in {self.filename}")
8372

8473
start, end = self.get_node_range(tool_import)
8574
self.edit_node_range(start, end, "")
8675

8776

88-
def add_tool(tool_name: str, agents: Optional[List[str]] = [], path: Optional[str] = None):
89-
if path:
90-
path = path.endswith('/') and path or path + '/'
91-
else:
92-
path = './'
93-
94-
framework = get_framework(path)
77+
def add_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional[Path] = None):
78+
if path is None: path = Path()
9579
agentstack_config = ConfigFile(path)
80+
framework = agentstack_config.framework
9681

9782
if tool_name in agentstack_config.tools:
9883
print(term_color(f'Tool {tool_name} is already installed', 'red'))
9984
sys.exit(1)
10085

101-
tool_data = ToolConfig.from_tool_name(tool_name)
102-
tool_file_path = tool_data.get_impl_file_path(framework)
86+
tool = ToolConfig.from_tool_name(tool_name)
87+
tool_file_path = tool.get_impl_file_path(framework)
10388

104-
if tool_data.packages:
105-
packaging.install(' '.join(tool_data.packages))
106-
shutil.copy(tool_file_path, f'{path}src/tools/{tool_name}_tool.py') # Move tool from package to project
89+
if tool.packages:
90+
packaging.install(' '.join(tool.packages))
91+
92+
# Move tool from package to project
93+
shutil.copy(tool_file_path, path/f'src/tools/{tool.module_name}.py')
10794

10895
try: # Edit the user's project tool init file to include the tool
10996
with ToolsInitFile(path/TOOLS_INIT_FILENAME) as tools_init:
110-
tools_init.add_import_for_tool(tool_data)
97+
tools_init.add_import_for_tool(framework, tool)
11198
except ValidationError as e:
11299
print(term_color(f"Error adding tool:\n{e}", 'red'))
113-
sys.exit(1)
114100

115101
# Edit the framework entrypoint file to include the tool in the agent definition
116-
if not len(agents): # If no agents are specified, add the tool to all agents
102+
if not agents: # If no agents are specified, add the tool to all agents
117103
agents = frameworks.get_agent_names(framework, path)
118104
for agent_name in agents:
119-
frameworks.add_tool(framework, tool_data, agent_name, path)
105+
frameworks.add_tool(framework, tool, agent_name, path)
120106

121-
if tool_data.env: # add environment variables which don't exist
107+
if tool.env: # add environment variables which don't exist
122108
with EnvFile(path) as env:
123-
for var, value in tool_data.env.items():
109+
for var, value in tool.env.items():
124110
env.append_if_new(var, value)
125111
with EnvFile(path, filename=".env.example") as env:
126-
for var, value in tool_data.env.items():
112+
for var, value in tool.env.items():
127113
env.append_if_new(var, value)
128114

129-
if tool_data.post_install:
130-
os.system(tool_data.post_install)
115+
if tool.post_install:
116+
os.system(tool.post_install)
131117

132118
with agentstack_config as config:
133-
config.tools.append(tool_name)
134-
135-
print(term_color(f'🔨 Tool {tool_name} added to agentstack project successfully', 'green'))
136-
if tool_data.cta:
137-
print(term_color(f'🪩 {tool_data.cta}', 'blue'))
119+
config.tools.append(tool.name)
138120

121+
print(term_color(f'🔨 Tool {tool.name} added to agentstack project successfully', 'green'))
122+
if tool.cta:
123+
print(term_color(f'🪩 {tool.cta}', 'blue'))
139124

140-
def remove_tool(tool_name: str, agents: Optional[List[str]] = [], path: Optional[str] = None):
141-
if path:
142-
path = path.endswith('/') and path or path + '/'
143-
else:
144-
path = './'
145125

146-
framework = get_framework()
126+
def remove_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional[Path] = None):
127+
if path is None: path = Path()
147128
agentstack_config = ConfigFile(path)
129+
framework = agentstack_config.framework
148130

149131
if not tool_name in agentstack_config.tools:
150132
print(term_color(f'Tool {tool_name} is not installed', 'red'))
151133
sys.exit(1)
152134

153-
tool_data = ToolConfig.from_tool_name(tool_name)
154-
if tool_data.packages:
155-
packaging.remove(' '.join(tool_data.packages))
135+
tool = ToolConfig.from_tool_name(tool_name)
136+
if tool.packages:
137+
packaging.remove(' '.join(tool.packages))
138+
156139
try:
157-
os.remove(f'{path}src/tools/{tool_name}_tool.py')
140+
os.remove(path/f'src/tools/{tool.module_name}.py')
158141
except FileNotFoundError:
159-
print(f'"src/tools/{tool_name}_tool.py" not found')
142+
print(f'"src/tools/{tool.module_name}.py" not found')
160143

161144
try: # Edit the user's project tool init file to exclude the tool
162145
with ToolsInitFile(path/TOOLS_INIT_FILENAME) as tools_init:
163-
tools_init.remove_import_for_tool(tool_data)
146+
tools_init.remove_import_for_tool(framework, tool)
164147
except ValidationError as e:
165148
print(term_color(f"Error removing tool:\n{e}", 'red'))
166-
sys.exit(1)
167149

168150
# Edit the framework entrypoint file to exclude the tool in the agent definition
169-
if not len(agents): # If no agents are specified, remove the tool from all agents
151+
if not agents: # If no agents are specified, remove the tool from all agents
170152
agents = frameworks.get_agent_names(framework, path)
171153
for agent_name in agents:
172-
frameworks.remove_tool(framework, tool_data, agent_name, path)
154+
frameworks.remove_tool(framework, tool, agent_name, path)
173155

174-
if tool_data.post_remove:
175-
os.system(tool_data.post_remove)
156+
if tool.post_remove:
157+
os.system(tool.post_remove)
176158
# We don't remove the .env variables to preserve user data.
177159

178160
with agentstack_config as config:
179-
config.tools.remove(tool_name)
161+
config.tools.remove(tool.name)
180162

181-
print(term_color(f'🔨 Tool {tool_name}', 'green'), term_color('removed', 'red'), term_color('from agentstack project successfully', 'green'))
163+
print(term_color(f'🔨 Tool {tool_name}', 'green'), term_color('removed', 'red'),
164+
term_color('from agentstack project successfully', 'green'))
182165

agentstack/tools.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,15 @@ def from_json(cls, path: Path) -> 'ToolConfig':
4141
print(f"{' '.join(error['loc'])}: {error['msg']}")
4242
sys.exit(1)
4343

44+
@property
45+
def module_name(self) -> str:
46+
return f"{self.name}_tool"
47+
4448
def get_import_statement(self, framework: str) -> str:
45-
return f"from .{self.name}_tool import {', '.join(self.tools)}"
49+
return f"from .{self.module_name} import {', '.join(self.tools)}"
4650

4751
def get_impl_file_path(self, framework: str) -> Path:
48-
return get_package_path()/f'templates/{framework}/tools/{self.name}_tool.py'
52+
return get_package_path()/f'templates/{framework}/tools/{self.module_name}.py'
4953

5054
def get_all_tool_paths() -> list[Path]:
5155
paths = []

tests/test_frameworks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import shutil
44
import unittest
55
from parameterized import parameterized_class
6+
67
from agentstack import ValidationError
78
from agentstack import frameworks
89
from agentstack.tools import ToolConfig

tests/test_tool_generation.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import os, sys
2+
from pathlib import Path
3+
import shutil
4+
import unittest
5+
from parameterized import parameterized_class
6+
7+
from agentstack import frameworks
8+
from agentstack.tools import get_all_tools, ToolConfig
9+
from agentstack.generation.files import ConfigFile
10+
from agentstack.generation.tool_generation import add_tool, remove_tool, TOOLS_INIT_FILENAME
11+
12+
13+
BASE_PATH = Path(__file__).parent
14+
15+
# TODO parameterize all tools
16+
@parameterized_class([
17+
{"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS
18+
])
19+
class TestToolGeneration(unittest.TestCase):
20+
def setUp(self):
21+
self.project_dir = BASE_PATH/'tmp'/'tool_generation'
22+
23+
os.makedirs(self.project_dir)
24+
os.makedirs(self.project_dir/'src')
25+
os.makedirs(self.project_dir/'src'/'tools')
26+
(self.project_dir/'src'/'__init__.py').touch()
27+
(self.project_dir/TOOLS_INIT_FILENAME).touch()
28+
29+
# populate the entrypoint
30+
entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir)
31+
shutil.copy(BASE_PATH/f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path)
32+
33+
# set the framework in agentstack.json
34+
shutil.copy(BASE_PATH/'fixtures'/'agentstack.json', self.project_dir/'agentstack.json')
35+
with ConfigFile(self.project_dir) as config:
36+
config.framework = self.framework
37+
38+
def tearDown(self):
39+
shutil.rmtree(self.project_dir)
40+
41+
def test_add_tool(self):
42+
tool_conf = ToolConfig.from_tool_name('agent_connect')
43+
add_tool('agent_connect', path=self.project_dir)
44+
45+
entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir)
46+
entrypoint_src = open(entrypoint_path).read()
47+
tools_init_src = open(self.project_dir/TOOLS_INIT_FILENAME).read()
48+
49+
# TODO verify tool is added to all agents (this is covered in test_frameworks.py)
50+
#assert 'agent_connect' in entrypoint_src
51+
assert f'from .{tool_conf.module_name} import' in tools_init_src
52+
assert (self.project_dir/'src'/'tools'/f'{tool_conf.module_name}.py').exists()
53+
assert 'agent_connect' in open(self.project_dir/'agentstack.json').read()
54+
55+
def test_remove_tool(self):
56+
tool_conf = ToolConfig.from_tool_name('agent_connect')
57+
add_tool('agent_connect', path=self.project_dir)
58+
remove_tool('agent_connect', path=self.project_dir)
59+
60+
entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir)
61+
entrypoint_src = open(entrypoint_path).read()
62+
tools_init_src = open(self.project_dir/TOOLS_INIT_FILENAME).read()
63+
64+
# TODO verify tool is removed from all agents (this is covered in test_frameworks.py)
65+
#assert 'agent_connect' not in entrypoint_src
66+
assert f'from .{tool_conf.module_name} import' not in tools_init_src
67+
assert not (self.project_dir/'src'/'tools'/f'{tool_conf.module_name}.py').exists()
68+
assert 'agent_connect' not in open(self.project_dir/'agentstack.json').read()
69+

tests/test_tool_generation_init.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import shutil
44
import unittest
55
from parameterized import parameterized_class
6+
67
from agentstack import ValidationError
78
from agentstack import frameworks
89
from agentstack.tools import ToolConfig
@@ -17,7 +18,7 @@
1718
])
1819
class TestToolGenerationInit(unittest.TestCase):
1920
def setUp(self):
20-
self.project_dir = BASE_PATH/'tmp'/'tool_generation'
21+
self.project_dir = BASE_PATH/'tmp'/'tool_generation_init'
2122
os.makedirs(self.project_dir)
2223
os.makedirs(self.project_dir/'src')
2324
os.makedirs(self.project_dir/'src'/'tools')

0 commit comments

Comments
 (0)