22from types import ModuleType
33from abc import ABCMeta , abstractmethod
44from importlib import import_module
5+ from dataclasses import dataclass
56from pathlib import Path
67import ast
78from agentstack import conf
89from agentstack .exceptions import ValidationError
910from agentstack .generation import InsertionPoint
1011from agentstack .utils import get_framework
12+ from agentstack import packaging
1113from agentstack .generation import asttools
1214from agentstack .agents import AgentConfig , get_all_agent_names
1315from agentstack .tasks import TaskConfig , get_all_task_names
2830DEFAULT_FRAMEWORK = CREWAI
2931
3032
33+ @dataclass
34+ class Provider :
35+ """
36+ An LLM provider definition.
37+
38+ Used to reference required dependencies, and provide attributes for an
39+ import statement.
40+ """
41+
42+ class_name : str # The class we need to import to use the provider
43+ module_name : str # The module we import from
44+ dependencies : list [str ] # The dependency we need to install to get the module
45+
46+ def install_dependencies (self ):
47+ """Install the dependencies for the provider."""
48+ for dependency in self .dependencies :
49+ packaging .install (dependency )
50+
51+
3152class FrameworkModule (Protocol ):
3253 """
3354 Protocol spec for a framework implementation module.
@@ -76,19 +97,13 @@ def wrap_tool(self, tool_func: Callable) -> Callable:
7697 """
7798 ...
7899
79- def get_agent_tool_names (self , agent_name : str ) -> list [str ]:
80- """
81- Get a list of tool names in an agent in the user's project.
82- """
83- ...
84-
85- def add_agent (self , agent : 'AgentConfig' , position : Optional ['InsertionPoint' ] = None ) -> None :
100+ def add_agent (self , agent : 'AgentConfig' , position : Optional [InsertionPoint ] = None ) -> None :
86101 """
87102 Add an agent to the user's project.
88103 """
89104 ...
90105
91- def add_task (self , task : 'TaskConfig' , position : Optional [' InsertionPoint' ] = None ) -> None :
106+ def add_task (self , task : 'TaskConfig' , position : Optional [InsertionPoint ] = None ) -> None :
92107 """
93108 Add a task to the user's project.
94109 """
@@ -113,6 +128,7 @@ class BaseEntrypointFile(asttools.File, metaclass=ABCMeta):
113128 and the `run` method with a method named `run` which accepts `inputs` as a
114129 keyword argument.
115130
131+ Usually, it looks something like this:
116132 ```
117133 class UserStack:
118134 @agentstack.task
@@ -128,7 +144,7 @@ def run(self, inputs: list):
128144 ```
129145 """
130146
131- base_class_pattern = r'\w+Stack$'
147+ base_class_pattern : str = r'\w+Stack$'
132148 agent_decorator_name : str = 'agent'
133149 task_decorator_name : str = 'task'
134150
@@ -248,6 +264,51 @@ def add_agent_method(self, agent: AgentConfig):
248264 if not self .source [pos :].startswith ('\n ' ):
249265 code += '\n \n '
250266 self .edit_node_range (pos , pos , code )
267+
268+ @abstractmethod
269+ def get_agent_tools (self , agent_name : str ) -> ast .List :
270+ """Get the list of tools used by an agent as an AST List node."""
271+ ...
272+
273+ def get_agent_tool_nodes (self , agent_name : str ) -> list [ast .Starred ]:
274+ """Get a list of all ast nodes that define agentstack tools used by the agent."""
275+ agent_tools_node = self .get_agent_tools (agent_name )
276+ return asttools .find_tool_nodes (agent_tools_node )
277+
278+ def get_agent_tool_names (self , agent_name : str ) -> list [str ]:
279+ """Get a list of all tools used by the agent."""
280+ # Tools are identified by the item name of an `agentstack.tools` attribute node.
281+ tool_names : list [str ] = []
282+ for node in self .get_agent_tool_nodes (agent_name ):
283+ # ignore type checking here since `get_agent_tool_nodes` is exhaustive
284+ tool_names .append (node .value .slice .value ) # type: ignore[attr-defined]
285+ return tool_names
286+
287+ def add_agent_tools (self , agent_name : str , tool : ToolConfig ):
288+ """Modify the existing tools list to add a new tool."""
289+ existing_node : ast .List = self .get_agent_tools (agent_name )
290+ existing_elts : list [ast .expr ] = existing_node .elts
291+
292+ if not tool .name in self .get_agent_tool_names (agent_name ):
293+ existing_elts .append (asttools .create_tool_node (tool .name ))
294+
295+ new_node = ast .List (elts = existing_elts , ctx = ast .Load ())
296+ start , end = self .get_node_range (existing_node )
297+ self .edit_node_range (start , end , new_node )
298+
299+ def remove_agent_tools (self , agent_name : str , tool : ToolConfig ):
300+ """Modify the existing tools list to remove a tool."""
301+ existing_node : ast .List = self .get_agent_tools (agent_name )
302+ start , end = self .get_node_range (existing_node )
303+
304+ # we're referencing the internal node list from two directions here,
305+ # so it's important that the node tree doesn't get re-parsed in between
306+ for node in self .get_agent_tool_nodes (agent_name ):
307+ # ignore type checking here since `get_agent_tool_nodes` is exhaustive
308+ if tool .name == node .value .slice .value : # type: ignore[attr-defined]
309+ existing_node .elts .remove (node )
310+
311+ self .edit_node_range (start , end , existing_node )
251312
252313
253314def get_framework_module (framework : str ) -> FrameworkModule :
@@ -335,6 +396,7 @@ def add_tool(tool: ToolConfig, agent_name: str):
335396 The tool will have already been installed in the user's application and have
336397 all dependencies installed. We're just handling code generation here.
337398 """
399+ # since this is a write operation, delegate to the framework impl.
338400 module = get_framework_module (get_framework ())
339401 return module .add_tool (tool , agent_name )
340402
@@ -343,6 +405,7 @@ def remove_tool(tool: ToolConfig, agent_name: str):
343405 """
344406 Remove a tool from the user's project.
345407 """
408+ # since this is a write operation, delegate to the framework impl.
346409 module = get_framework_module (get_framework ())
347410 return module .remove_tool (tool , agent_name )
348411
@@ -401,10 +464,11 @@ def get_agent_method_names() -> list[str]:
401464
402465def get_agent_tool_names (agent_name : str ) -> list [str ]:
403466 """
404- Get a list of tool names in the user's project.
467+ Get a list of tool names in the user's project for a given agent .
405468 """
406469 module = get_framework_module (get_framework ())
407- return module .get_agent_tool_names (agent_name )
470+ entrypoint = module .get_entrypoint ()
471+ return entrypoint .get_agent_tool_names (agent_name )
408472
409473
410474def add_agent (agent : 'AgentConfig' , position : Optional [InsertionPoint ] = None ):
0 commit comments