Skip to content

Commit e52bb01

Browse files
committed
Bring agent tool modification into shared impl.
1 parent b129b41 commit e52bb01

6 files changed

Lines changed: 124 additions & 269 deletions

File tree

agentstack/frameworks/__init__.py

Lines changed: 75 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
from types import ModuleType
33
from abc import ABCMeta, abstractmethod
44
from importlib import import_module
5+
from dataclasses import dataclass
56
from pathlib import Path
67
import ast
78
from agentstack import conf
89
from agentstack.exceptions import ValidationError
910
from agentstack.generation import InsertionPoint
1011
from agentstack.utils import get_framework
12+
from agentstack import packaging
1113
from agentstack.generation import asttools
1214
from agentstack.agents import AgentConfig, get_all_agent_names
1315
from agentstack.tasks import TaskConfig, get_all_task_names
@@ -28,6 +30,25 @@
2830
DEFAULT_FRAMEWORK = CREWAI
2931

3032

33+
@dataclass
34+
class Provider:
35+
"""
36+
An LLM provider definition.
37+
38+
Used to reference required dependencies, and provide attributes for an
39+
import statement.
40+
"""
41+
42+
class_name: str # The class we need to import to use the provider
43+
module_name: str # The module we import from
44+
dependencies: list[str] # The dependency we need to install to get the module
45+
46+
def install_dependencies(self):
47+
"""Install the dependencies for the provider."""
48+
for dependency in self.dependencies:
49+
packaging.install(dependency)
50+
51+
3152
class FrameworkModule(Protocol):
3253
"""
3354
Protocol spec for a framework implementation module.
@@ -76,19 +97,13 @@ def wrap_tool(self, tool_func: Callable) -> Callable:
7697
"""
7798
...
7899

79-
def get_agent_tool_names(self, agent_name: str) -> list[str]:
80-
"""
81-
Get a list of tool names in an agent in the user's project.
82-
"""
83-
...
84-
85-
def add_agent(self, agent: 'AgentConfig', position: Optional['InsertionPoint'] = None) -> None:
100+
def add_agent(self, agent: 'AgentConfig', position: Optional[InsertionPoint] = None) -> None:
86101
"""
87102
Add an agent to the user's project.
88103
"""
89104
...
90105

91-
def add_task(self, task: 'TaskConfig', position: Optional['InsertionPoint'] = None) -> None:
106+
def add_task(self, task: 'TaskConfig', position: Optional[InsertionPoint] = None) -> None:
92107
"""
93108
Add a task to the user's project.
94109
"""
@@ -113,6 +128,7 @@ class BaseEntrypointFile(asttools.File, metaclass=ABCMeta):
113128
and the `run` method with a method named `run` which accepts `inputs` as a
114129
keyword argument.
115130
131+
Usually, it looks something like this:
116132
```
117133
class UserStack:
118134
@agentstack.task
@@ -128,7 +144,7 @@ def run(self, inputs: list):
128144
```
129145
"""
130146

131-
base_class_pattern = r'\w+Stack$'
147+
base_class_pattern: str = r'\w+Stack$'
132148
agent_decorator_name: str = 'agent'
133149
task_decorator_name: str = 'task'
134150

@@ -248,6 +264,51 @@ def add_agent_method(self, agent: AgentConfig):
248264
if not self.source[pos:].startswith('\n'):
249265
code += '\n\n'
250266
self.edit_node_range(pos, pos, code)
267+
268+
@abstractmethod
269+
def get_agent_tools(self, agent_name: str) -> ast.List:
270+
"""Get the list of tools used by an agent as an AST List node."""
271+
...
272+
273+
def get_agent_tool_nodes(self, agent_name: str) -> list[ast.Starred]:
274+
"""Get a list of all ast nodes that define agentstack tools used by the agent."""
275+
agent_tools_node = self.get_agent_tools(agent_name)
276+
return asttools.find_tool_nodes(agent_tools_node)
277+
278+
def get_agent_tool_names(self, agent_name: str) -> list[str]:
279+
"""Get a list of all tools used by the agent."""
280+
# Tools are identified by the item name of an `agentstack.tools` attribute node.
281+
tool_names: list[str] = []
282+
for node in self.get_agent_tool_nodes(agent_name):
283+
# ignore type checking here since `get_agent_tool_nodes` is exhaustive
284+
tool_names.append(node.value.slice.value) # type: ignore[attr-defined]
285+
return tool_names
286+
287+
def add_agent_tools(self, agent_name: str, tool: ToolConfig):
288+
"""Modify the existing tools list to add a new tool."""
289+
existing_node: ast.List = self.get_agent_tools(agent_name)
290+
existing_elts: list[ast.expr] = existing_node.elts
291+
292+
if not tool.name in self.get_agent_tool_names(agent_name):
293+
existing_elts.append(asttools.create_tool_node(tool.name))
294+
295+
new_node = ast.List(elts=existing_elts, ctx=ast.Load())
296+
start, end = self.get_node_range(existing_node)
297+
self.edit_node_range(start, end, new_node)
298+
299+
def remove_agent_tools(self, agent_name: str, tool: ToolConfig):
300+
"""Modify the existing tools list to remove a tool."""
301+
existing_node: ast.List = self.get_agent_tools(agent_name)
302+
start, end = self.get_node_range(existing_node)
303+
304+
# we're referencing the internal node list from two directions here,
305+
# so it's important that the node tree doesn't get re-parsed in between
306+
for node in self.get_agent_tool_nodes(agent_name):
307+
# ignore type checking here since `get_agent_tool_nodes` is exhaustive
308+
if tool.name == node.value.slice.value: # type: ignore[attr-defined]
309+
existing_node.elts.remove(node)
310+
311+
self.edit_node_range(start, end, existing_node)
251312

252313

253314
def get_framework_module(framework: str) -> FrameworkModule:
@@ -335,6 +396,7 @@ def add_tool(tool: ToolConfig, agent_name: str):
335396
The tool will have already been installed in the user's application and have
336397
all dependencies installed. We're just handling code generation here.
337398
"""
399+
# since this is a write operation, delegate to the framework impl.
338400
module = get_framework_module(get_framework())
339401
return module.add_tool(tool, agent_name)
340402

@@ -343,6 +405,7 @@ def remove_tool(tool: ToolConfig, agent_name: str):
343405
"""
344406
Remove a tool from the user's project.
345407
"""
408+
# since this is a write operation, delegate to the framework impl.
346409
module = get_framework_module(get_framework())
347410
return module.remove_tool(tool, agent_name)
348411

@@ -401,10 +464,11 @@ def get_agent_method_names() -> list[str]:
401464

402465
def get_agent_tool_names(agent_name: str) -> list[str]:
403466
"""
404-
Get a list of tool names in the user's project.
467+
Get a list of tool names in the user's project for a given agent.
405468
"""
406469
module = get_framework_module(get_framework())
407-
return module.get_agent_tool_names(agent_name)
470+
entrypoint = module.get_entrypoint()
471+
return entrypoint.get_agent_tool_names(agent_name)
408472

409473

410474
def add_agent(agent: 'AgentConfig', position: Optional[InsertionPoint] = None):

agentstack/frameworks/crewai.py

Lines changed: 3 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def {agent.name}(self) -> Agent:
7171

7272
def get_agent_tools(self, agent_name: str) -> ast.List:
7373
"""
74-
Get the tools used by an agent as AST nodes.
74+
Get the list of tools used by an agent as an AST List node.
7575
7676
Tool definitions are inside of the methods marked with an `@agent` decorator.
7777
The method returns a new class instance with the tools as a list of callables
@@ -100,65 +100,6 @@ def get_agent_tools(self, agent_name: str) -> ast.List:
100100

101101
return tools_kwarg.value
102102

103-
def get_agent_tool_nodes(self, agent_name: str) -> list[ast.Starred]:
104-
"""
105-
Get a list of all ast nodes that define agentstack tools used by the agent.
106-
"""
107-
agent_tools_node = self.get_agent_tools(agent_name)
108-
return asttools.find_tool_nodes(agent_tools_node)
109-
110-
def get_agent_tool_names(self, agent_name: str) -> list[str]:
111-
"""
112-
Get a list of all tools used by the agent.
113-
114-
Tools are identified by the item name of an `agentstack.tools` attribute node.
115-
"""
116-
tool_names: list[str] = []
117-
for node in self.get_agent_tool_nodes(agent_name):
118-
# ignore type checking here since `get_agent_tool_nodes` is exhaustive
119-
tool_names.append(node.value.slice.value) # type: ignore[attr-defined]
120-
return tool_names
121-
122-
def add_agent_tools(self, agent_name: str, tool: ToolConfig):
123-
"""
124-
Add new tools to be used by an agent.
125-
126-
Tool definitions are inside the methods marked with an `@agent` decorator.
127-
The method returns a new class instance with the tools as a list of callables
128-
under the kwarg `tools`.
129-
"""
130-
method = asttools.find_method(self.get_agent_methods(), agent_name)
131-
if method is None:
132-
raise ValidationError(f"`@agent` method `{agent_name}` does not exist in {ENTRYPOINT}")
133-
134-
existing_node: ast.List = self.get_agent_tools(agent_name)
135-
existing_elts: list[ast.expr] = existing_node.elts
136-
137-
new_tool_nodes: list[ast.expr] = []
138-
if not tool.name in self.get_agent_tool_names(agent_name):
139-
existing_elts.append(asttools.create_tool_node(tool.name))
140-
141-
new_node = ast.List(elts=existing_elts, ctx=ast.Load())
142-
start, end = self.get_node_range(existing_node)
143-
self.edit_node_range(start, end, new_node)
144-
145-
def remove_agent_tools(self, agent_name: str, tool: ToolConfig):
146-
"""
147-
Remove tools from an agent belonging to `tool`.
148-
"""
149-
existing_node: ast.List = self.get_agent_tools(agent_name)
150-
start, end = self.get_node_range(existing_node)
151-
152-
# modify the existing node to remove any matching tools
153-
# we're referencing the internal node list from two directions here,
154-
# so it's important that the node tree doesn't get re-parsed in between
155-
for node in self.get_agent_tool_nodes(agent_name):
156-
# ignore type checking here since `get_agent_tool_nodes` is exhaustive
157-
if tool.name == node.value.slice.value: # type: ignore[attr-defined]
158-
existing_node.elts.remove(node)
159-
160-
self.edit_node_range(start, end, existing_node)
161-
162103

163104
def get_entrypoint() -> CrewFile:
164105
"""
@@ -184,7 +125,7 @@ def parse_llm(llm: str) -> tuple[str, str]:
184125
return provider, model
185126

186127

187-
def add_task(task: TaskConfig, position: Optional['InsertionPoint'] = None) -> None:
128+
def add_task(task: TaskConfig, position: Optional[InsertionPoint] = None) -> None:
188129
"""
189130
Add a task method to the CrewAI entrypoint.
190131
"""
@@ -195,15 +136,7 @@ def add_task(task: TaskConfig, position: Optional['InsertionPoint'] = None) -> N
195136
entrypoint.add_task_method(task)
196137

197138

198-
def get_agent_tool_names(agent_name: str) -> list[Any]:
199-
"""
200-
Get a list of tools used by an agent.
201-
"""
202-
with get_entrypoint() as entrypoint:
203-
return entrypoint.get_agent_tool_names(agent_name)
204-
205-
206-
def add_agent(agent: AgentConfig, position: Optional['InsertionPoint'] = None) -> None:
139+
def add_agent(agent: AgentConfig, position: Optional[InsertionPoint] = None) -> None:
207140
"""
208141
Add an agent method to the CrewAI entrypoint.
209142
"""

0 commit comments

Comments
 (0)