diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml new file mode 100644 index 0000000..2002811 --- /dev/null +++ b/.github/workflows/linter.yml @@ -0,0 +1,40 @@ +name: Linter & Formatter + +on: + workflow_dispatch: + pull_request: + types: [opened, synchronize] + paths: + - "**/*.py" + - "pyproject.toml" + + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + ruff_linter: + runs-on: ubuntu-24.04 + + steps: + - name: Sync repository + uses: eProsima/eProsima-CI/external/checkout@v0 + + - name: Set up Python + uses: eProsima/eProsima-CI/external/setup-python@v0 + with: + python-version: "3.12" + + - name: Install Ruff + uses: eProsima/eProsima-CI/ubuntu/install_python_packages@v0 + with: + packages: ruff==0.14.13 + + - name: Ruff format (check) + run: | + ruff format --check . + + - name: Ruff lint (check) + run: | + ruff check . diff --git a/README.md b/README.md index eeabbbf..effce25 100644 --- a/README.md +++ b/README.md @@ -106,4 +106,20 @@ To do so, source Vulcanexus and then run the following command in the terminal w ```bash source /opt/vulcanexus/${VULCANEXUS_DISTRO}/setup.bash && \ export PYTHONPATH='//lib/python3.x/site-packages':$PYTHONPATH -``` \ No newline at end of file +``` + +## Developers + +This repository uses `ruff` as formatter and linter. +Use the following commands to ensure that any contribution follows the project's style guidelines: + +```bash +ruff check --fix . +ruff format . 
+``` + +Python package `ruff` can be installed directly with: + +```bash +python3 -m pip install ruff==0.14.13 +``` diff --git a/pyproject.toml b/pyproject.toml index 6e0441d..a7ae25c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,3 +44,15 @@ vulcanai = ["py.typed", "*.pyi", "**/*.pyi"] [project.scripts] vulcanai-console = "vulcanai.console.console:main" + +# Linter options +[tool.ruff] +line-length = 119 +extend-exclude = [".venv", "build", "dist"] + +[tool.ruff.lint] +# E: pycodestyle errors, F: pyflakes, I: import sorting +select = ["E", "F", "I"] + +[tool.ruff.lint.isort] +known-first-party = ["vulcanai"] diff --git a/src/vulcanai/__init__.py b/src/vulcanai/__init__.py index b61ed4d..32f08dc 100644 --- a/src/vulcanai/__init__.py +++ b/src/vulcanai/__init__.py @@ -16,9 +16,7 @@ from types import ModuleType _SUBPACKAGES = ("core", "tools", "console", "models") -_submods: dict[str, ModuleType] = { - name: import_module(f"{__name__}.{name}") for name in _SUBPACKAGES -} +_submods: dict[str, ModuleType] = {name: import_module(f"{__name__}.{name}") for name in _SUBPACKAGES} __all__ = sorted({sym for m in _submods.values() for sym in getattr(m, "__all__", ())}) @@ -28,11 +26,13 @@ for sym in getattr(mod, "__all__", ()): _MODULE_INDEX.setdefault(sym, mod) + def __getattr__(name: str): mod = _MODULE_INDEX.get(name) if not mod: raise AttributeError(f"module '{__name__}' has no attribute '{name}'") return getattr(mod, name) + def __dir__(): return sorted(list(globals().keys()) + __all__) diff --git a/src/vulcanai/__init__.pyi b/src/vulcanai/__init__.pyi index f828bad..d932dd8 100644 --- a/src/vulcanai/__init__.pyi +++ b/src/vulcanai/__init__.pyi @@ -12,15 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from .console import VulcanAILogger, VulcanConsole from .core import ( - ToolManager, PlanManager, IterativeManager, TimelineEvent, - Agent, PlanExecutor, Blackboard, - ArgValue, Step, PlanNode, GlobalPlan, PlanValidator, + Agent, + ArgValue, + Blackboard, + GlobalPlan, + IterativeManager, + PlanExecutor, + PlanManager, + PlanNode, + PlanValidator, + Step, + TimelineEvent, + ToolManager, ) -from .console import VulcanConsole, VulcanAILogger from .models import GeminiModel, OllamaModel, OpenAIModel from .tools import ( - AtomicTool, CompositeTool, ValidationTool, ToolRegistry, vulcanai_tool, + AtomicTool, + CompositeTool, + ToolRegistry, + ValidationTool, + vulcanai_tool, ) __all__ = [ diff --git a/src/vulcanai/console/__init__.py b/src/vulcanai/console/__init__.py index dd75589..4341810 100644 --- a/src/vulcanai/console/__init__.py +++ b/src/vulcanai/console/__init__.py @@ -23,17 +23,19 @@ __all__ = list(_EXPORTS.keys()) + def __getattr__(name: str): target = _EXPORTS.get(name) if not target: raise AttributeError(f"module '{__name__}' has no attribute '{name}'") - module_name, attr_name = target.split(':') + module_name, attr_name = target.split(":") if module_name.startswith("."): module = import_module(module_name, package=__name__) else: module = import_module(module_name) return getattr(module, attr_name) + def __dir__() -> list[str]: """Make dir() show the public API.""" return sorted(list(globals().keys()) + __all__) diff --git a/src/vulcanai/console/__init__.pyi b/src/vulcanai/console/__init__.pyi index 697035c..367e4da 100644 --- a/src/vulcanai/console/__init__.pyi +++ b/src/vulcanai/console/__init__.pyi @@ -16,5 +16,6 @@ from .console import VulcanConsole from .logger import VulcanAILogger __all__ = [ - "VulcanConsole", "VulcanAILogger", + "VulcanConsole", + "VulcanAILogger", ] diff --git a/src/vulcanai/console/console.py b/src/vulcanai/console/console.py index c101167..074ddbd 100644 --- a/src/vulcanai/console/console.py +++ 
b/src/vulcanai/console/console.py @@ -16,10 +16,10 @@ import argparse import asyncio -import pyperclip # To paste the clipboard into the terminal import sys import threading +import pyperclip # To paste the clipboard into the terminal from textual import events, work from textual.app import App, ComposeResult from textual.binding import Binding @@ -28,15 +28,16 @@ from textual.markup import escape # To remove potential errors in textual terminal from textual.widgets import Input, Static -from vulcanai.console.widget_custom_log_text_area import CustomLogTextArea from vulcanai.console.logger import VulcanAILogger from vulcanai.console.modal_screens import CheckListModal, RadioListModal, ReverseSearchModal -from vulcanai.console.utils import attach_ros_logger_to_console, common_prefix, SpinnerHook, StreamToTextual +from vulcanai.console.utils import SpinnerHook, StreamToTextual, attach_ros_logger_to_console, common_prefix +from vulcanai.console.widget_custom_log_text_area import CustomLogTextArea from vulcanai.console.widget_spinner import SpinnerStatus class TextualLogSink: """A default console that prints to standard output.""" + def __init__(self, textual_console) -> None: self.console = textual_console @@ -45,7 +46,6 @@ def write(self, msg: str, color: str = "") -> None: class VulcanConsole(App): - # CSS Styles # Two panels: left (log + input) and right (history + variables) # Right panel: 48 characters length @@ -111,11 +111,17 @@ class VulcanConsole(App): Binding("down", "history_next", show=False), ] - def __init__(self, - model: str = "gpt-5-nano", k: int = 7, iterative: bool = False, - register_from_file:str = "", tools_from_entrypoints: str = "", - user_context: str = "", main_node = None): - super().__init__() # Textual lib + def __init__( + self, + model: str = "gpt-5-nano", + k: int = 7, + iterative: bool = False, + register_from_file: str = "", + tools_from_entrypoints: str = "", + user_context: str = "", + main_node=None, + ): + super().__init__() # 
Textual lib # -- Main variables -- # Manager instance @@ -164,7 +170,6 @@ def __init__(self, self.suggestion_index = -1 self.suggestion_index_changed = threading.Event() - async def on_mouse_down(self, event: MouseEvent) -> None: """ Function used to paste the string for the user clipboard @@ -203,8 +208,7 @@ def compose(self) -> ComposeResult: color_tmp = VulcanAILogger.vulcanai_theme["vulcanai"] - vulcanai_title_slant = \ -f"""[{color_tmp}] + vulcanai_title_slant = f"""[{color_tmp}] _ __ __ ___ ____ | | / /_ __/ /________ ____ / | / _/ | | / / / / / / ___/ __ `/ __ \/ /| | / / @@ -229,7 +233,7 @@ def compose(self) -> ComposeResult: # Title Area yield Static(vulcanai_title_slant, id="history_title") # Variable info Area - yield Static(f" Loading info...", id="variables") + yield Static(" Loading info...", id="variables") # History Area with VerticalScroll(id="history_scroll"): # NOTE: markup=True so [bold reverse] works @@ -293,7 +297,7 @@ def worker() -> None: self.manager.bb["console"] = self # Add the shared node to the console manager blackboard to be used by tools - if self.main_node != None: + if self.main_node is not None: self.manager.bb["main_node"] = self.main_node attach_ros_logger_to_console(self, self.main_node) @@ -307,14 +311,14 @@ def worker() -> None: # Activate the terminal input self.set_input_enabled(True) - async def queriestrap(self, user_input: str="") -> None: + async def queriestrap(self, user_input: str = "") -> None: """ Function used to handle user requests. Print information at runtime execution of a function, without blocking the main thread so Textual Log does not freeze. """ - def worker(user_input: str="") -> None: + def worker(user_input: str = "") -> None: """ Worker function to run in a separate thread. 
@@ -343,8 +347,8 @@ def worker(user_input: str="") -> None: self.last_bb = result.get("blackboard", None) # Print the backboard state - bb_ret = result.get('blackboard', None) - bb_ret = str(bb_ret).replace('<', '\'').replace('>', '\'') + bb_ret = result.get("blackboard", None) + bb_ret = str(bb_ret).replace("<", "'").replace(">", "'") self.logger.log_console(f"Output of plan: {bb_ret}") except KeyboardInterrupt: @@ -394,14 +398,14 @@ def _update_history_panel(self) -> None: cmd_esc = escape(cmd) prefix = "" tmp_color = "" - if len(cmd_esc) > 0 and cmd_esc[0] != '/': + if len(cmd_esc) > 0 and cmd_esc[0] != "/": tmp_color = VulcanAILogger.vulcanai_theme["vulcanai"] prefix = f" [{tmp_color}][Plan {plan_count}][/{tmp_color}]\n" plan_count += 1 else: tmp_color = VulcanAILogger.vulcanai_theme["console"] - text = f"{prefix} [{tmp_color}]{i+1}:[/{tmp_color}] {escape(cmd)}" + text = f"{prefix} [{tmp_color}]{i + 1}:[/{tmp_color}] {escape(cmd)}" if self.history_index is not None and self.history_index == i: # Highlight current selection text = f"[bold reverse]{text}[/]" @@ -415,7 +419,10 @@ def _update_variables_panel(self) -> None: the current variables info (model, k, history_depth). 
""" - text = f" AI model: {self.model.replace('ollama-', '')}\n K = {self.manager.k}\n history_depth = {self.manager.history_depth}" + text = ( + f" AI model: {self.model.replace('ollama-', '')}\n K = {self.manager.k}\n" + f" history_depth = {self.manager.history_depth}" + ) kvalue_widget = self.query_one("#variables", Static) kvalue_widget.update(text) @@ -431,7 +438,6 @@ async def open_checklist(self, tools_list: list[str], active_tools_num: int) -> if selected is None: self.logger.log_msg("Selection cancelled.") else: - # Iterate over all tools and activate/deactivate accordingly # to the selection made by the user for tool_tmp in tools_list: @@ -457,14 +463,13 @@ async def open_radiolist(self, option_list: list[str], tool: str = "") -> str: selected = await self.push_screen_wait(RadioListModal(option_list)) if selected is None: - self.logger.log_tool(f"Suggestion cancelled", tool_name=tool) + self.logger.log_tool("Suggestion cancelled", tool_name=tool) self.suggestion_index = -2 return - self.logger.log_tool(f"Selected suggestion: \"{option_list[selected]}\"", tool_name=tool) + self.logger.log_tool(f'Selected suggestion: "{option_list[selected]}"', tool_name=tool) self.suggestion_index = selected - self.suggestion_index_changed.set() # signal change - + self.suggestion_index_changed.set() # signal change # endregion @@ -479,8 +484,10 @@ def cmd_help(self, _) -> None: "/help - Show this help message\n" "/tools - List available tools\n" "/edit_tools - Edit the list of available tools\n" - "/change_k 'int' - Change the 'k' value for the top_k algorithm selection or show the current value if no 'int' is provided\n" - "/history 'int' - Change the history depth or show the current value if no 'int' is provided\n" + "/change_k 'int' - Change the 'k' value for the top_k algorithm selection" + " or show the current value if no 'int' is provided\n" + "/history 'int' - Change the history depth or show the current value if no" + " 'int' is provided\n" "/show_history - Show 
the current history\n" "/clear_history - Clear the history\n" "/plan - Show the last generated plan\n" @@ -489,7 +496,8 @@ def cmd_help(self, _) -> None: "/clear - Clears the console screen\n" "/exit - Exit the console\n" "Query any other text to process it with the LLM and execute the plan generated.\n\n" - "Add --image='path' to include images in the query. It can be used multiple times to add more images.\n" + "Add --image='path' to include images in the query. It can be used multiple times to add" + " more images.\n" "Example: 'user_prompt' --image=/path/to/image1 --image=/path/to/image2'\n" "___________________\n" "Available keybinds:\n" @@ -502,16 +510,17 @@ def cmd_help(self, _) -> None: "Ctrl+K - Clears from the cursor to then end of the line\n" "Ctrl+W - Delete the word before the cursor\n" "Ctrl+'left/right' - Move cursor backward/forward by one word\n" - "Ctrl+R - Reverse search through command history (try typing part of a previous command).\n" + "Ctrl+R - Reverse search through command history (try typing part of a" + " previous command).\n" ] ) self.logger.log_console(table, "console") def cmd_tools(self, _) -> None: tmp_msg = f"(current index k={self.manager.k})" - tool_msg = ("_" * len(tmp_msg)) + '\n' - tool_msg += f"Available tools:\n" - tool_msg += tmp_msg + '\n' + ("‾" * len(tmp_msg)) + '\n' + tool_msg = ("_" * len(tmp_msg)) + "\n" + tool_msg += "Available tools:\n" + tool_msg += tmp_msg + "\n" + ("‾" * len(tmp_msg)) + "\n" for tool in self.manager.registry.tools.values(): tool_msg += f"- {tool.name}: {tool.description}\n" @@ -561,16 +570,16 @@ def cmd_show_history(self, _) -> None: self.logger.log_console("No history available.") return - history_msg = \ - "________________\n" + \ - "Current history:\n" + \ - "(oldest first)\n" + \ - "‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾\n" + history_msg = ( + "________________\n" + "Current history:\n" + "(oldest first)\n" + "‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾\n" + ) self.logger.log_console(history_msg, "console") for i, (user_text, plan_summary) 
in enumerate(self.manager.history): - user_req_cmd = user_text.split('\n') - self.logger.log_msg(f"{i+1}. [USER] >>> {user_req_cmd[1]}\n",) + user_req_cmd = user_text.split("\n") + self.logger.log_msg( + f"{i + 1}. [USER] >>> {user_req_cmd[1]}\n", + ) self.logger.log_msg(f"Plan summary: {plan_summary}\n") def cmd_clear_history(self, _) -> None: @@ -649,7 +658,7 @@ def cmd_blackboard_state(self, _) -> None: self.logger.log_console("Lastest blackboard state:") # Parse the blackboard to avoid <...> issues in textual last_bb_parsed = str(self.last_bb) - last_bb_parsed = last_bb_parsed.replace('<', '\'').replace('>', '\'') + last_bb_parsed = last_bb_parsed.replace("<", "'").replace(">", "'") self.logger.log_console(last_bb_parsed) else: self.logger.log_console("No blackboard available.") @@ -664,9 +673,7 @@ def cmd_quit(self, _) -> None: # region Logging - def add_line(self, input: str, - color: str = "", - subprocess_flag: bool = False) -> None: + def add_line(self, input: str, color: str = "", subprocess_flag: bool = False) -> None: """ Function used to write an input in the VulcanAI terminal. """ @@ -679,7 +686,6 @@ def add_line(self, input: str, color_begin = f"<{color}>" color_end = f"" - # Append each line; deque automatically truncates old ones for line in lines: line_processed = line @@ -687,7 +693,7 @@ def add_line(self, input: str, line_processed = escape(line) text = f"{color_begin}{line_processed}{color_end}" if not self.left_pannel.append_line(text): - self.logger.log_console(f"Warning: Trying to add an empty line.") + self.logger.log_console("Warning: Trying to add an empty line.") def delete_last_line(self): """ @@ -818,7 +824,6 @@ async def _paste_clipboard(self) -> None: cmd_input.cursor_position = cursor + len(paste_text) cmd_input.focus() - async def on_key(self, event: events.Key) -> None: """ Function used to handle key events in the terminal input box. 
@@ -897,7 +902,7 @@ async def on_key(self, event: events.Key) -> None: # Iterate backwards to find the start of the previous word while i > 0: - if value[i] == ' ': + if value[i] == " ": # First space after a word if first_char: i += 1 @@ -916,7 +921,7 @@ async def on_key(self, event: events.Key) -> None: return # - escape/ctrl+delete: Delete after cursor --------------------------- - if key in ("ctrl+delete", "escape") : + if key in ("ctrl+delete", "escape"): # Input value and cursor position value = cmd_input.value cursor = cmd_input.cursor_position @@ -928,13 +933,13 @@ async def on_key(self, event: events.Key) -> None: # Iterate forwards to find the end of the next word while i < n: - if(value[i] == ' '): + if value[i] == " ": if first_char: break else: first_char = True i += 1 - count +=1 + count += 1 # Update input value and cursor position cmd_input.value = value[:cursor] + value[i:] @@ -1021,9 +1026,9 @@ def action_stop_streaming_task(self) -> None: Binding("ctrl+c", "stop_streaming_task", ...), """ - if self.stream_task != None and not self.stream_task.done(): + if self.stream_task is not None and not self.stream_task.done(): # Cancel the streaming task - self.stream_task.cancel() # Triggers CancelledError in the task + self.stream_task.cancel() # Triggers CancelledError in the task self.stream_task = None else: @@ -1079,7 +1084,6 @@ def run_console(self) -> None: """ self.run() - def init_manager(self) -> None: """ Function used to initialize VulcanAI Manager. 
@@ -1106,39 +1110,46 @@ def get_images(self, user_input: str) -> None: for part in parts: if part.startswith("--image="): - images.append(part[len("--image="):]) + images.append(part[len("--image=") :]) return images def main() -> None: parser = argparse.ArgumentParser(description="VulcanAI Interactive Console") parser.add_argument( - "--model", type=str, default="gpt-5-nano", - help="LLM model to used in the agent (ej: gpt-5-nano, gemini-2.0-flash, etc.)" + "--model", + type=str, + default="gpt-5-nano", + help="LLM model to used in the agent (ej: gpt-5-nano, gemini-2.0-flash, etc.)", ) parser.add_argument( - "--register-from-file", type=str, nargs="*", default=[], - help="Register tools from a python file (or multiple files)" + "--register-from-file", + type=str, + nargs="*", + default=[], + help="Register tools from a python file (or multiple files)", ) parser.add_argument( - "--register-from-entry-point", type=str, nargs="*", default=[], - help="Register tools from a python entry-point (or multiple entry-points)" + "--register-from-entry-point", + type=str, + nargs="*", + default=[], + help="Register tools from a python entry-point (or multiple entry-points)", ) + parser.add_argument("-k", type=int, default=7, help="Maximum number of tools to pass to the LLM") parser.add_argument( - "-k", type=int, default=7, - help="Maximum number of tools to pass to the LLM" - ) - parser.add_argument( - "-i", "--iterative", action="store_true", default=False, - help="Enable Iterative Manager (default: off)" + "-i", "--iterative", action="store_true", default=False, help="Enable Iterative Manager (default: off)" ) - args = parser.parse_args() - console = VulcanConsole(register_from_file=args.register_from_file, - tools_from_entrypoints=args.register_from_entry_point, - model=args.model, k=args.k, iterative=args.iterative) + console = VulcanConsole( + register_from_file=args.register_from_file, + tools_from_entrypoints=args.register_from_entry_point, + model=args.model, + 
k=args.k, + iterative=args.iterative, + ) console.run_console() diff --git a/src/vulcanai/console/logger.py b/src/vulcanai/console/logger.py index a225cdd..a6b6f88 100644 --- a/src/vulcanai/console/logger.py +++ b/src/vulcanai/console/logger.py @@ -13,19 +13,20 @@ # limitations under the License. import re -from typing import Protocol, Optional +from typing import Optional, Protocol class LogSink(Protocol): """A default console that prints to standard output.""" - def write(self, msg: str, color: str = "") -> None: - ... + + def write(self, msg: str, color: str = "") -> None: ... class RichStdoutSink: def __init__(self, logger_theme) -> None: from rich.console import Console from rich.theme import Theme + self.console = Console(theme=Theme(logger_theme)) def write(self, msg: str, color: str = "") -> None: @@ -33,24 +34,23 @@ def write(self, msg: str, color: str = "") -> None: class VulcanAILogger: - """ Logger class for VulcanAI components. Provides methods to log messages with different tags and colors. 
""" vulcanai_theme = { - "registry": "#068399", - "manager": "#0d87c0", - "executor": "#15B606", - "vulcanai": "#56AA08", - "user": "#91DD16", - "validator": "#C49C00", - "tool": "#EB921E", - "error": "#FF0000", - "console": "#8F6296", - "warning": "#D8C412", - } + "registry": "#068399", + "manager": "#0d87c0", + "executor": "#15B606", + "vulcanai": "#56AA08", + "user": "#91DD16", + "validator": "#C49C00", + "tool": "#EB921E", + "error": "#FF0000", + "console": "#8F6296", + "warning": "#D8C412", + } _default_instance: Optional["VulcanAILogger"] = None _rich_markup = True @@ -84,7 +84,7 @@ def parse_color(self, msg): """ # Matches [tag] or [/tag] - pattern = re.compile(r'\[(\/?)([^\]]+)\]') + pattern = re.compile(r"\[(\/?)([^\]]+)\]") def replace_tag(match): slash, tag = match.groups() @@ -126,30 +126,30 @@ def process_msg(self, msg: str, prefix: str = "", color: str = "") -> str: def log_manager(self, msg: str, error: bool = False, color: str = ""): if error: - prefix = f"[error][MANAGER] [ERROR][/error] " + prefix = "[error][MANAGER] [ERROR][/error] " else: - prefix = f"[manager][MANAGER][/manager] " + prefix = "[manager][MANAGER][/manager] " processed_msg = self.process_msg(msg, prefix=prefix, color=color) self.sink.write(processed_msg) - def log_executor(self, msg: str, error: bool = False, tool: bool = False, tool_name: str = '', color: str = ""): + def log_executor(self, msg: str, error: bool = False, tool: bool = False, tool_name: str = "", color: str = ""): if error: - prefix = f"[error][EXECUTOR] [ERROR][/error] " + prefix = "[error][EXECUTOR] [ERROR][/error] " elif tool: self.log_tool(msg, tool_name=tool_name) return else: - prefix = f"[executor][EXECUTOR][/executor] " + prefix = "[executor][EXECUTOR][/executor] " processed_msg = self.process_msg(msg, prefix=prefix, color=color) self.sink.write(processed_msg) - def log_tool(self, msg: str, tool_name: str = '', error: bool = False, color: str = ""): + def log_tool(self, msg: str, tool_name: str = "", error: 
bool = False, color: str = ""): if tool_name: tag = f"[TOOL {tool_name}]" else: - tag = '[TOOL]' + tag = "[TOOL]" if error: prefix = f"[error]{tag} [ERROR][/error] " else: @@ -160,18 +160,18 @@ def log_tool(self, msg: str, tool_name: str = '', error: bool = False, color: st def log_registry(self, msg: str, error: bool = False, color: str = ""): if error: - prefix = f"[error][REGISTRY] [ERROR][/error] " + prefix = "[error][REGISTRY] [ERROR][/error] " else: - prefix = f"[registry][REGISTRY][/registry] " + prefix = "[registry][REGISTRY][/registry] " processed_msg = self.process_msg(msg, prefix=prefix, color=color) self.sink.write(processed_msg) def log_validator(self, msg: str, error: bool = False, color: str = ""): if error: - prefix = f"[error][VALIDATOR] [ERROR][/error] " + prefix = "[error][VALIDATOR] [ERROR][/error] " else: - prefix = f"[validator][VALIDATOR][/validator] " + prefix = "[validator][VALIDATOR][/validator] " processed_msg = self.process_msg(msg, prefix=prefix, color=color) self.sink.write(processed_msg) @@ -195,7 +195,7 @@ def log_msg(self, msg: str, error: bool = False, color: str = ""): self.sink.write(processed_msg) def log_user(self, msg: str): - prefix = f"[user][USER] >>>[/user] " + prefix = "[user][USER] >>>[/user] " processed_msg = self.process_msg(msg, prefix=prefix) self.sink.write(processed_msg) diff --git a/src/vulcanai/console/modal_screens.py b/src/vulcanai/console/modal_screens.py index a00f894..829ad13 100644 --- a/src/vulcanai/console/modal_screens.py +++ b/src/vulcanai/console/modal_screens.py @@ -14,10 +14,9 @@ from textual import events from textual.app import ComposeResult -from textual.containers import VerticalScroll, Horizontal, Vertical, Container +from textual.containers import Container, Horizontal, Vertical, VerticalScroll from textual.screen import ModalScreen -from textual.widgets import Input, Checkbox, Button, Label, RadioSet, RadioButton - +from textual.widgets import Button, Checkbox, Input, Label, RadioButton, 
RadioSet class ReverseSearchModal(ModalScreen[str | None]): @@ -143,9 +142,8 @@ async def on_key(self, event: events.Key) -> None: event.stop() return -class CheckListModal(ModalScreen[list[str] | None]): - +class CheckListModal(ModalScreen[list[str] | None]): CSS = """ CheckListModal { align: center middle; @@ -210,10 +208,7 @@ def on_mount(self) -> None: self.set_focus(first_cb) - - class RadioListModal(ModalScreen[str | None]): - CSS = """ RadioListModal { align: center middle; @@ -257,11 +252,7 @@ def compose(self) -> ComposeResult: with VerticalScroll(classes="radio-list"): with RadioSet(id="radio-set"): for i, line in enumerate(self.lines): - yield RadioButton( - line, - id=f"rb{i}", - value=(i == self.default_index) - ) + yield RadioButton(line, id=f"rb{i}", value=(i == self.default_index)) # Buttons with Horizontal(classes="btns"): @@ -276,7 +267,7 @@ def on_button_pressed(self, event: Button.Pressed) -> None: if event.button.id == "submit": radioset = self.query_one("#radio-set", RadioSet) selected = radioset.pressed_index - if selected != None: + if selected is not None: self.dismiss(selected) else: - self.dismiss(None) \ No newline at end of file + self.dismiss(None) diff --git a/src/vulcanai/console/utils.py b/src/vulcanai/console/utils.py index 6720f66..65a4929 100644 --- a/src/vulcanai/console/utils.py +++ b/src/vulcanai/console/utils.py @@ -19,7 +19,9 @@ import subprocess import sys import time -from textual.markup import escape # To remove potential errors in textual terminal + +from textual.markup import escape # To remove potential errors in textual terminal + class StreamToTextual: """ @@ -36,7 +38,7 @@ def write(self, data: str): if data.strip(): # Ensure update happens on the app thread - #self.app.call_from_thread(self.app.append_log_text, data) + # self.app.call_from_thread(self.app.append_log_text, data) self.app.call_from_thread(self.app.add_line, data) def flush(self): @@ -107,13 +109,10 @@ def common_prefix(strings: str) -> str: return 
common_prefix, commands -async def run_streaming_cmd_async(console, args: list[str], - max_duration: float = 60, - max_lines: int = 1000, - echo: bool = True, - tool_name="") -> str: - +async def run_streaming_cmd_async( + console, args: list[str], max_duration: float = 60, max_lines: int = 1000, echo: bool = True, tool_name="" +) -> str: # Unpack the command cmd, *cmd_args = args @@ -150,28 +149,29 @@ async def run_streaming_cmd_async(console, args: list[str], # Check duration if max_duration and (time.monotonic() - start_time) >= max_duration: - console.logger.log_tool(f"[tool]Stopping:[/tool] Exceeded max_duration = {max_duration}s", tool_name=tool_name) + console.logger.log_tool( + f"[tool]Stopping:[/tool] Exceeded max_duration = {max_duration}s", tool_name=tool_name + ) console.set_stream_task(None) process.terminate() break - except asyncio.CancelledError: # Task was cancelled → stop the subprocess - console.logger.log_tool(f"[tool]Cancellation received:[/tool] terminating subprocess...", tool_name=tool_name) + console.logger.log_tool("[tool]Cancellation received:[/tool] terminating subprocess...", tool_name=tool_name) process.terminate() raise # Not necessary, textual terminal get the keyboard input except KeyboardInterrupt: # Ctrl+C pressed → stop subprocess - console.logger.log_tool(f"[tool]Ctrl+C received:[/tool] terminating subprocess...", tool_name=tool_name) + console.logger.log_tool("[tool]Ctrl+C received:[/tool] terminating subprocess...", tool_name=tool_name) process.terminate() finally: try: await asyncio.wait_for(process.wait(), timeout=3.0) except asyncio.TimeoutError: - console.logger.log_tool(f"Subprocess didn't exit in time → killing it.", tool_name=tool_name, error=True) + console.logger.log_tool("Subprocess didn't exit in time → killing it.", tool_name=tool_name, error=True) process.kill() await process.wait() @@ -179,7 +179,6 @@ async def run_streaming_cmd_async(console, args: list[str], def execute_subprocess(console, tool_name, base_args, 
max_duration, max_lines): - stream_task = None def _launcher() -> None: @@ -192,12 +191,11 @@ def _launcher() -> None: base_args, max_duration=max_duration, max_lines=max_lines, - tool_name=tool_name#tool_header_str + tool_name=tool_name, # tool_header_str ) ) def _on_done(task: asyncio.Task) -> None: - if task.cancelled(): # Normal path → don't log as an error # If you want a message, call UI methods directly here, @@ -208,7 +206,7 @@ def _on_done(task: asyncio.Task) -> None: task.result() except Exception as e: console.logger.log_msg(f"Echo task error: {e!r}\n", error=True) - #result["output"] = False + # result["output"] = False return stream_task.add_done_callback(_on_done) @@ -228,20 +226,16 @@ def _on_done(task: asyncio.Task) -> None: console.set_stream_task(stream_task) console.logger.log_tool("[tool]Subprocess created![tool]", tool_name=tool_name) + def run_oneshot_cmd(args: list[str]) -> str: try: - return subprocess.check_output( - args, - stderr=subprocess.STDOUT, - text=True - ) + return subprocess.check_output(args, stderr=subprocess.STDOUT, text=True) except subprocess.CalledProcessError as e: raise Exception(f"Failed to run '{' '.join(args)}': {e.output}") def suggest_string(console, tool_name, string_name, input_string, real_string_list): - ret = None def _similarity(a: str, b: str) -> float: @@ -289,15 +283,14 @@ def _get_suggestions(real_string_list_comp: list[str], string_comp: str) -> tupl ret_list.append(most_topic_similar) while topic_list_pq: - _ , topic = heapq.heappop(topic_list_pq) + _, topic = heapq.heappop(topic_list_pq) ret_list.append(topic) return most_topic_similar, ret_list if input_string not in real_string_list: - - #console.add_line(f"{tool_header_str} {string_name}: \"{input_string}\" does not exists") - console.logger.log_tool(f"{string_name}: \"{input_string}\" does not exists", tool_name=tool_name) + # console.add_line(f"{tool_header_str} {string_name}: \"{input_string}\" does not exists") + 
console.logger.log_tool(f'{string_name}: "{input_string}" does not exists', tool_name=tool_name) # Get the suggestions list sorted by similitud value _, topic_sim_list = _get_suggestions(real_string_list, input_string) diff --git a/src/vulcanai/console/widget_custom_log_text_area.py b/src/vulcanai/console/widget_custom_log_text_area.py index 45ef422..dd60f4c 100644 --- a/src/vulcanai/console/widget_custom_log_text_area.py +++ b/src/vulcanai/console/widget_custom_log_text_area.py @@ -13,11 +13,11 @@ # limitations under the License. -from collections import defaultdict, deque -import pyperclip import re import threading +from collections import defaultdict, deque +import pyperclip from rich.style import Style from textual.widgets import TextArea @@ -44,7 +44,6 @@ class CustomLogTextArea(TextArea): # join tags TAG_TOKEN_RE = re.compile(r"]+>") - def __init__(self, **kwargs): super().__init__(read_only=True, **kwargs) @@ -123,7 +122,7 @@ def join_nested_tags(self, input_text: str) -> str: # Iterate over all tags in 'input_text' for m in self.TAG_TOKEN_RE.finditer(input_text): # Emit text between tags - text = input_text[pos:m.start()] + text = input_text[pos : m.start()] if text: if stack: combined = " ".join(stack) @@ -253,7 +252,7 @@ def append_line(self, text: str) -> bool: # TAG_RE.finditer finds all tags in 'text' for m in self.TAG_RE.finditer(text): # Append text before the current tag - plain += text[cursor:m.start()] + plain += text[cursor : m.start()] # Get tag and body tag = m.group("tag") body = m.group("body") @@ -280,17 +279,15 @@ def append_line(self, text: str) -> bool: # [EXECUTOR] Invoking 'move_turtle' with args: ... # [ROS] [INFO] Publishing message 1 to ... 
with self._lock: - # Append via document API to keep row tracking consistent # Only add a newline before the new line if there is already content insert_text = ("\n" if self.document.text else "") + plain self.insert(insert_text, location=self.document.end) # Track styles for the new line (always at the end) - row = self._line_count self._line_count += 1 - if (self._line_count > self.MAX_LINES): + if self._line_count > self.MAX_LINES: self._highlights.pop(self._line_count - self.MAX_LINES, None) # Store styles diff --git a/src/vulcanai/console/widget_spinner.py b/src/vulcanai/console/widget_spinner.py index fc56aed..a2019ed 100644 --- a/src/vulcanai/console/widget_spinner.py +++ b/src/vulcanai/console/widget_spinner.py @@ -24,6 +24,7 @@ class SpinnerStatus(Static): It implements rich's Spinner and manages its display state, starting and stopping it as needed through SpinnerHook. """ + def __init__(self, logcontent, **kwargs) -> None: super().__init__(**kwargs) self.logcontent = logcontent @@ -32,7 +33,7 @@ def __init__(self, logcontent, **kwargs) -> None: self._forced_compact = False def on_mount(self) -> None: - self._timer = self.set_interval(1/30, self._refresh, pause=True) + self._timer = self.set_interval(1 / 30, self._refresh, pause=True) self.display = False self.styles.height = 0 diff --git a/src/vulcanai/core/__init__.py b/src/vulcanai/core/__init__.py index d4ddf52..9cda0b9 100644 --- a/src/vulcanai/core/__init__.py +++ b/src/vulcanai/core/__init__.py @@ -31,17 +31,19 @@ __all__ = list(_EXPORTS.keys()) + def __getattr__(name: str): target = _EXPORTS.get(name) if not target: raise AttributeError(f"module '{__name__}' has no attribute '{name}'") - module_name, attr_name = target.split(':') + module_name, attr_name = target.split(":") if module_name.startswith("."): module = import_module(module_name, package=__name__) else: module = import_module(module_name) return getattr(module, attr_name) + def __dir__() -> list[str]: """Make dir() show the public 
API.""" return sorted(list(globals().keys()) + __all__) diff --git a/src/vulcanai/core/__init__.pyi b/src/vulcanai/core/__init__.pyi index 0aa6935..45e2d25 100644 --- a/src/vulcanai/core/__init__.pyi +++ b/src/vulcanai/core/__init__.pyi @@ -13,11 +13,11 @@ # limitations under the License. from .agent import Agent -from .executor import PlanExecutor, Blackboard +from .executor import Blackboard, PlanExecutor from .manager import ToolManager from .manager_iterator import IterativeManager, TimelineEvent from .manager_plan import PlanManager -from .plan_types import ArgValue, Step, PlanNode, GlobalPlan +from .plan_types import ArgValue, GlobalPlan, PlanNode, Step from .validator import PlanValidator __all__ = [ diff --git a/src/vulcanai/core/agent.py b/src/vulcanai/core/agent.py index febc67e..2d9886f 100644 --- a/src/vulcanai/core/agent.py +++ b/src/vulcanai/core/agent.py @@ -25,8 +25,8 @@ class Brand(str, Enum): class Agent: - """Interface to operate the LLM.""" + def __init__(self, model_name: str, logger=None): self.brand, name = self._detect_brand(model_name) self.model = None @@ -34,12 +34,8 @@ def __init__(self, model_name: str, logger=None): self._load_model(name) def inference_plan( - self, - system_context: str, - user_prompt: str, - images: list[str], - history: list[tuple[str, str]] - ) -> GlobalPlan: + self, system_context: str, user_prompt: str, images: list[str], history: list[tuple[str, str]] + ) -> GlobalPlan: """ Perform inference using the selected LLM model to generate a plan. 
@@ -53,20 +49,12 @@ def inference_plan( raise RuntimeError("LLM model was not loaded correctly.") plan: GlobalPlan = self.model.plan_inference( - system_prompt=system_context, - user_prompt=user_prompt, - images=images, - history=history + system_prompt=system_context, user_prompt=user_prompt, images=images, history=history ) return plan - def inference_goal( - self, - system_context: str, - user_prompt: str, - history: list[tuple[str, str]] - ) -> GoalSpec: + def inference_goal(self, system_context: str, user_prompt: str, history: list[tuple[str, str]]) -> GoalSpec: """ Perform inference using the selected LLM model to generate a goal. @@ -79,20 +67,18 @@ def inference_goal( raise RuntimeError("LLM model was not loaded correctly.") goal: GoalSpec = self.model.goal_inference( - system_prompt=system_context, - user_prompt=user_prompt, - history=history + system_prompt=system_context, user_prompt=user_prompt, history=history ) return goal def inference_validation( - self, - system_context: str, - user_prompt: str, - images: list[str], - history: list[tuple[str, str]], - ) -> AIValidation: + self, + system_context: str, + user_prompt: str, + images: list[str], + history: list[tuple[str, str]], + ) -> AIValidation: """ Perform inference using the selected LLM model to generate a validation. 
@@ -106,10 +92,7 @@ def inference_validation( raise RuntimeError("LLM model was not loaded correctly.") validation: AIValidation = self.model.validation_inference( - system_prompt=system_context, - user_prompt=user_prompt, - images=images, - history=history + system_prompt=system_context, user_prompt=user_prompt, images=images, history=history ) return validation @@ -123,7 +106,7 @@ def _detect_brand(self, model_name: str) -> tuple[Brand, str]: """ m = model_name.lower() if m.startswith("ollama-"): - return Brand.ollama, model_name[len("ollama-"):] + return Brand.ollama, model_name[len("ollama-") :] if m.startswith(("gpt-", "o")): return Brand.gpt, model_name if m.startswith(("gemini-", "gemma-")): @@ -134,20 +117,20 @@ def _detect_brand(self, model_name: str) -> tuple[Brand, str]: def _load_model(self, model_name: str): if self.brand == Brand.gpt: from vulcanai.models.openai import OpenAIModel - self.logger.log_manager(f"Using OpenAI API with model: " + \ - f"[manager]{model_name}[/manager]") + + self.logger.log_manager("Using OpenAI API with model: " + f"[manager]{model_name}[/manager]") self.model = OpenAIModel(model_name, self.logger) elif self.brand == Brand.gemini: from vulcanai.models.gemini import GeminiModel - self.logger.log_manager(f"Using Gemini API with model: " + \ - f"[manager]{model_name}[/manager]") + + self.logger.log_manager("Using Gemini API with model: " + f"[manager]{model_name}[/manager]") self.model = GeminiModel(model_name, self.logger) elif self.brand == Brand.ollama: from vulcanai.models.ollama_model import OllamaModel - self.logger.log_manager(f"Using Ollama API with model: " + \ - f"[manager]{model_name}[/manager]") + + self.logger.log_manager("Using Ollama API with model: " + f"[manager]{model_name}[/manager]") self.model = OllamaModel(model_name, self.logger) else: diff --git a/src/vulcanai/core/executor.py b/src/vulcanai/core/executor.py index 3cc41c3..2912aec 100644 --- a/src/vulcanai/core/executor.py +++ 
b/src/vulcanai/core/executor.py @@ -17,17 +17,15 @@ import io import re import time -from typing import Dict, Any, Optional, Set, Tuple, List +from typing import Any, Dict, List, Optional, Set, Tuple from vulcanai.console.logger import VulcanAILogger -from vulcanai.core.plan_types import GlobalPlan, PlanBase, Step, ArgValue - +from vulcanai.core.plan_types import ArgValue, GlobalPlan, PlanBase, Step TYPE_CAST = { "float": float, "int": int, - "bool": lambda v: v if isinstance(v, bool) - else str(v).strip().lower() in ("1","true","yes","on"), + "bool": lambda v: v if isinstance(v, bool) else str(v).strip().lower() in ("1", "true", "yes", "on"), "str": str, } @@ -54,6 +52,7 @@ def text_snapshot(self, keys: Optional[List[str]] = None) -> str: return str(snapshot) + class PlanExecutor: """Executes a validated GlobalPlan with blackboard and execution control parameters.""" @@ -80,22 +79,24 @@ def _run_plan_node(self, node: PlanBase, bb: Blackboard) -> bool: """Run a PlanNode with execution control parameters.""" # Evaluate PlanNode-level condition if node.condition and not self.safe_eval(node.condition, bb): - self.logger.log_executor(f"Skipping PlanNode {node.kind} due to not fulfilled " + \ - f"condition={node.condition}") + self.logger.log_executor( + f"Skipping PlanNode {node.kind} due to not fulfilled " + f"condition={node.condition}" + ) return True attempts = node.retry + 1 if node.retry else 1 for i in range(attempts): ok = self._execute_plan_node_with_timeout(node, bb) if ok and self._check_success(node, bb): - self.logger.log_executor(f"PlanNode [registry]{node.kind}[/registry] " + \ - f"[executor]succeeded[/executor] " + \ - f"on attempt {i+1}/{attempts}") + self.logger.log_executor( + f"PlanNode [registry]{node.kind}[/registry] " + + "[executor]succeeded[/executor] " + + f"on attempt {i + 1}/{attempts}" + ) return True - self.logger.log_executor(f"PlanNode {node.kind} failed on attempt {i+1}/{attempts}", error=True) + self.logger.log_executor(f"PlanNode 
{node.kind} failed on attempt {i + 1}/{attempts}", error=True) if node.on_fail: - self.logger.log_executor(f"Executing on_fail branch for PlanNode " + \ - f"[registry]{node.kind}[/registry]") + self.logger.log_executor("Executing on_fail branch for PlanNode " + f"[registry]{node.kind}[/registry]") # Execute the on_fail branch but ignore its result and return False self._run_plan_node(node.on_fail, bb) @@ -109,7 +110,9 @@ def _execute_plan_node_with_timeout(self, node: PlanBase, bb: Blackboard) -> boo future = executor.submit(self._execute_plan_node, node, bb) return future.result(timeout=node.timeout_ms / 1000.0) except concurrent.futures.TimeoutError: - self.logger.log_executor(f"PlanNode {node.kind} timed out after {node.timeout_ms} ms", error=True) + self.logger.log_executor( + f"PlanNode {node.kind} timed out after {node.timeout_ms} ms", error=True + ) return False else: return self._execute_plan_node(node, bb) @@ -136,8 +139,10 @@ def _execute_plan_node(self, node: PlanBase, bb: Blackboard) -> bool: def _run_step(self, step: Step, bb: Blackboard, parallel: bool = False) -> bool: # Evaluate Step-level condition if step.condition and not self.safe_eval(step.condition, bb): - self.logger.log_executor(f"Skipping step '{step.tool}' " + \ - f"due to condition=[executor]{step.condition}[/executor]") + self.logger.log_executor( + f"Skipping step '{step.tool}' " + + f"due to condition=[executor]{step.condition}[/executor]" + ) return True # Bind args with blackboard placeholders @@ -155,8 +160,9 @@ def _run_step(self, step: Step, bb: Blackboard, parallel: bool = False) -> bool: if ok and self._check_success(step, bb, is_step=True): return True else: - self.logger.log_executor(f"Step [executor]'{step.tool}'[/executor] " + \ - f"attempt {i+1}/{attempts} failed") + self.logger.log_executor( + f"Step [executor]'{step.tool}'[/executor] " + f"attempt {i + 1}/{attempts} failed" + ) return False @@ -167,12 +173,14 @@ def _check_success(self, entity: Step | PlanBase, bb: 
Blackboard, is_step: bool return True log_value = entity.tool if is_step else entity.kind if self.safe_eval(entity.success_criteria, bb): - self.logger.log_executor(f"Entity '{log_value}' [executor]succeeded[/executor] " + \ - f"with criteria={entity.success_criteria}") + self.logger.log_executor( + f"Entity '{log_value}' [executor]succeeded[/executor] " + f"with criteria={entity.success_criteria}" + ) return True else: - self.logger.log_executor(f"Entity '{log_value}' [error]failed[/error] " + \ - f"with criteria={entity.success_criteria}") + self.logger.log_executor( + f"Entity '{log_value}' [error]failed[/error] " + f"with criteria={entity.success_criteria}" + ) return False def safe_eval(self, expr: str, bb: Blackboard) -> bool: @@ -242,12 +250,14 @@ def _coerce_to_schema(self, schema: List[Tuple[str, str]], key: str, arg: str) - out = arg return out - def _call_tool(self, - tool_name: str, - args: List[ArgValue], - timeout_ms: int = None, - bb: Blackboard = None, - parallel: bool = False) -> Tuple[bool, Any]: + def _call_tool( + self, + tool_name: str, + args: List[ArgValue], + timeout_ms: int = None, + bb: Blackboard = None, + parallel: bool = False, + ) -> Tuple[bool, Any]: """Invoke a registered tool.""" tool = self.registry.tools.get(tool_name) @@ -265,13 +275,11 @@ def _call_tool(self, msg += "'{" for key, value in arg_dict.items(): if first: - msg += f"[validator]'{key}'[/validator]: " + \ - f"[registry]'{value}'[/registry]" + msg += f"[validator]'{key}'[/validator]: " + f"[registry]'{value}'[/registry]" else: - msg += f", [validator]'{key}'[/validator]: " + \ - f"[registry]'{value}'[/registry]" + msg += f", [validator]'{key}'[/validator]: " + f"[registry]'{value}'[/registry]" first = False - msg+="}'" + msg += "}'" self.logger.log_executor(msg) start = time.time() @@ -299,9 +307,11 @@ def _call_tool(self, if tool_log: self.logger.log_executor(f"{tool_log}: {tool_name}") elapsed = (time.time() - start) * 1000 - self.logger.log_executor(f"Executed 
[executor]'{tool_name}'[/executor] " + \ - f"in [registry]{elapsed:.1f} ms[/registry] " + \ - f"with result:") + self.logger.log_executor( + f"Executed [executor]'{tool_name}'[/executor] " + + f"in [registry]{elapsed:.1f} ms[/registry] " + + "with result:" + ) if isinstance(result, dict): for key, value in result.items(): @@ -314,11 +324,15 @@ def _call_tool(self, return True, result except concurrent.futures.TimeoutError: - self.logger.log_executor(f"Execution of [executor]'{tool_name}'[/executor] " + \ - f"[error]timed out[/error] " + \ - f"after [registry]{timeout_ms}[/registry] ms") + self.logger.log_executor( + f"Execution of [executor]'{tool_name}'[/executor] " + + "[error]timed out[/error] " + + f"after [registry]{timeout_ms}[/registry] ms" + ) return False, None except Exception as e: - self.logger.log_executor(f"Execution [error]failed[/error] for " + \ - f"[executor]'{tool_name}'[/executor]: {e}") + self.logger.log_executor( + "Execution [error]failed[/error] for " + + f"[executor]'{tool_name}'[/executor]: {e}" + ) return False, None diff --git a/src/vulcanai/core/manager.py b/src/vulcanai/core/manager.py index 93e094a..66936d3 100644 --- a/src/vulcanai/core/manager.py +++ b/src/vulcanai/core/manager.py @@ -15,16 +15,16 @@ from typing import Any, Dict, Optional, Tuple from vulcanai.console.logger import VulcanAILogger -from vulcanai.core.executor import Blackboard, PlanExecutor from vulcanai.core.agent import Agent +from vulcanai.core.executor import Blackboard, PlanExecutor from vulcanai.core.plan_types import GlobalPlan from vulcanai.core.validator import PlanValidator from vulcanai.tools.tool_registry import ToolRegistry - class ToolManager: """Manages the LLM Agent and calls the executor with the LLM output.""" + def __init__( self, model: str, @@ -32,7 +32,7 @@ def __init__( validator: Optional[PlanValidator] = None, k: int = 10, hist_depth: int = 3, - logger: Optional[VulcanAILogger] = None + logger: Optional[VulcanAILogger] = None, ): # Logger 
default to a stdout logger if none is provided (StdoutLogSink) self.logger = logger or VulcanAILogger.default() @@ -90,7 +90,6 @@ def handle_user_request(self, user_text: str, context: Dict[str, Any]) -> Dict[s :return: A dictionary with the execution result, including the plan used and the final blackboard state. """ try: - # Get plan from LLM plan = self.get_plan_from_user_request(user_text, context) if not plan: @@ -167,7 +166,7 @@ def _build_prompt(self, user_text: str, ctx: Dict[str, Any]) -> Tuple[str, str]: """ tools = self.registry.top_k(user_text, self.k) if not tools: - self.logger.log_manager(f"No tools available in the registry.", error=True) + self.logger.log_manager("No tools available in the registry.", error=True) return "", "" tool_descriptions = [] for tool in tools: @@ -190,7 +189,7 @@ def _add_to_history(self, user_text: str, plan_summary: str): if self.history_depth <= 0: self.history = [] else: - self.history = self.history[-self.history_depth:] + self.history = self.history[-self.history_depth :] def update_history_depth(self, new_depth: int): """ @@ -204,7 +203,7 @@ def update_history_depth(self, new_depth: int): if self.history_depth <= 0: self.history = [] else: - self.history = self.history[-self.history_depth:] + self.history = self.history[-self.history_depth :] def update_k_index(self, new_k: int): """ diff --git a/src/vulcanai/core/manager_iterator.py b/src/vulcanai/core/manager_iterator.py index 37aa960..8414792 100644 --- a/src/vulcanai/core/manager_iterator.py +++ b/src/vulcanai/core/manager_iterator.py @@ -16,8 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple from vulcanai.core.manager import ToolManager -from vulcanai.core.plan_types import ArgValue, GlobalPlan, PlanNode, Step -from vulcanai.core.plan_types import GoalSpec +from vulcanai.core.plan_types import ArgValue, GlobalPlan, GoalSpec, PlanNode, Step from vulcanai.core.validator import PlanValidator from vulcanai.tools.tool_registry import ToolRegistry @@ -33,17 
+32,18 @@ class TimelineEvent(Enum): class IterativeManager(ToolManager): """Manager that implements iterative planning to re-adapt plans.""" + def __init__( - self, - model: str, - registry: Optional[ToolRegistry]=None, - validator: Optional[PlanValidator]=None, - k: int=5, - hist_depth: int = 3, - logger=None, - max_iters: int = 5, - step_timeout_ms: Optional[int] = None - ): + self, + model: str, + registry: Optional[ToolRegistry] = None, + validator: Optional[PlanValidator] = None, + k: int = 5, + hist_depth: int = 3, + logger=None, + max_iters: int = 5, + step_timeout_ms: Optional[int] = None, + ): super().__init__(model, registry, validator, k, max(3, hist_depth), logger) self.iter: int = 0 @@ -96,8 +96,10 @@ def handle_user_request(self, user_text: str, context: Dict[str, Any]) -> Dict[s # Get plan from LLM plan = self.get_plan_from_user_request(user_text, context) if not plan: - self.logger.log_manager(f"Error getting plan from model", error=True) - self._timeline.append({"iteration": self.iter, "event": TimelineEvent.PLAN_GENERATION_FAILED.value}) + self.logger.log_manager("Error getting plan from model", error=True) + self._timeline.append( + {"iteration": self.iter, "event": TimelineEvent.PLAN_GENERATION_FAILED.value} + ) skip_verification_step = True continue plan_str = str(plan.plan) @@ -114,7 +116,9 @@ def handle_user_request(self, user_text: str, context: Dict[str, Any]) -> Dict[s self.validator.validate(plan) except Exception as e: self.logger.log_manager(f"Plan validation error. 
Asking for new plan: {e}") - self._timeline.append({"iteration": self.iter, "event": TimelineEvent.PLAN_NOT_VALID.value, "detail": str(e)}) + self._timeline.append( + {"iteration": self.iter, "event": TimelineEvent.PLAN_NOT_VALID.value, "detail": str(e)} + ) continue # Set step timeouts if not set or too high @@ -125,12 +129,14 @@ def handle_user_request(self, user_text: str, context: Dict[str, Any]) -> Dict[s # Execute plan ret = self.execute_plan(plan) - self._timeline.append({ - "iteration": self.iter, - "event": TimelineEvent.PLAN_EXECUTED.value, - "plan": plan.summary, - "result": ret.get("success", False), - }) + self._timeline.append( + { + "iteration": self.iter, + "event": TimelineEvent.PLAN_EXECUTED.value, + "plan": plan.summary, + "result": ret.get("success", False), + } + ) # If execution was successful, return the result if not ret.get("success", False): @@ -140,7 +146,12 @@ def handle_user_request(self, user_text: str, context: Dict[str, Any]) -> Dict[s self.logger.log_manager(f"Error handling user request: {e}", error=True) return {"error": str(e), "timeline": self._timeline} - return {"timeline": self._timeline, "success": self.bb.get("goal_achieved", False), "blackboard": self.bb.copy(), "plan": plan} + return { + "timeline": self._timeline, + "success": self.bb.get("goal_achieved", False), + "blackboard": self.bb.copy(), + "plan": plan, + } def _is_goal_achieved(self): """ @@ -158,7 +169,7 @@ def _get_iter_context(self) -> str: Get context about the current iteration to include in the prompt. Blackboard state is not included here, as it must be added separately when needed. 
""" - context_template="""\ + context_template = """\ - Iteration: {iteration} - Goal used: {goal} - Last timeline events: @@ -169,11 +180,16 @@ def _get_iter_context(self) -> str: timeline_events = "" if self._timeline: - for e in self._timeline[-self._timeline_events_printed:]: - if e.get('event', "") == TimelineEvent.PLAN_EXECUTED.value: - timeline_events += f" - Iteration {e['iteration']}: {e['event']} - Plan: {e.get('plan', 'N/A')} - Result: {'Success' if e.get('result', False) else 'Failure'}\n" - elif e.get('event', "") == TimelineEvent.PLAN_NOT_VALID.value: - timeline_events += f" - Iteration {e['iteration']}: {e['event']} - Detail: {e.get('detail', 'N/A')}\n" + for e in self._timeline[-self._timeline_events_printed :]: + if e.get("event", "") == TimelineEvent.PLAN_EXECUTED.value: + timeline_events += ( + f" - Iteration {e['iteration']}: {e['event']} - Plan: {e.get('plan', 'N/A')}" + f" - Result: {'Success' if e.get('result', False) else 'Failure'}\n" + ) + elif e.get("event", "") == TimelineEvent.PLAN_NOT_VALID.value: + timeline_events += ( + f" - Iteration {e['iteration']}: {e['event']} - Detail: {e.get('detail', 'N/A')}\n" + ) else: timeline_events += f" - Iteration {e['iteration']}: {e['event']}\n" else: @@ -184,7 +200,7 @@ def _get_iter_context(self) -> str: iteration=self.iter, goal=self.goal.summary if self.goal else "N/A", timeline_events=timeline_events, - bb_snapshot="{bb_snapshot}" # Placeholder for blackboard snapshot + bb_snapshot="{bb_snapshot}", # Placeholder for blackboard snapshot ) return context @@ -213,7 +229,9 @@ def _build_prompt(self, user_text: str, ctx: Dict[str, Any]) -> Tuple[str, str]: user_context=user_context, ) - user_prompt = "## User Request: " + user_text + "\nContext:\n" + self._get_iter_context().format(bb_snapshot=bb_snapshot) + user_prompt = ( + "## User Request: " + user_text + "\nContext:\n" + self._get_iter_context().format(bb_snapshot=bb_snapshot) + ) return system_prompt, user_prompt @@ -239,10 +257,7 @@ def 
_build_goal_prompt(self, user_text: str, ctx: Dict[str, Any]) -> Tuple[str, user_context = self._parse_user_context() system_prompt = self._get_goal_prompt_template() - system_prompt = system_prompt.format( - user_context=user_context, - tools_text=tools_text - ) + system_prompt = system_prompt.format(user_context=user_context, tools_text=tools_text) user_prompt = "User request:\n" + user_text return system_prompt, user_prompt @@ -270,18 +285,26 @@ def _get_goal_prompt_template(self) -> str: template = """ You are a goal generator of a robotic/agent system. Your objective is to produce the final goal to be verified during iterative execution from the user request. -You can access the blackboard (bb) to check the current state of the system, which is updated by the execution of verification tools. +You can access the blackboard (bb) to check the current state of the system, which is updated by the execution +of verification tools. Rules: - You have two modes to choose from for goal verification: 'perceptual' and 'objective'. - 1) Objective mode uses deterministic predicates based on validation tools results to verify if the goal has been achieved. - 2) Perceptual mode uses AI to verify if the goal has been achieved based on evidence (e.g., images) or blackboard data. -- Whenever possible, prefer 'objective' mode as it is more reliable and faster to evaluate. Rely on 'perceptual' mode when the task requires a higher level of abstraction and cannot be easily converted to objective predicates. -- Add to 'verify_tools' any tool needed to verify the goal. It will be executed to update the blackboard before checking the success predicates. This applies to both modes. + 1) Objective mode uses deterministic predicates based on validation tools results to verify if the goal has + been achieved. + 2) Perceptual mode uses AI to verify if the goal has been achieved based on evidence (e.g., images) or blackboard + data. 
+- Whenever possible, prefer 'objective' mode as it is more reliable and faster to evaluate. Rely on 'perceptual' mode + when the task requires a higher level of abstraction and cannot be easily converted to objective predicates. +- Add to 'verify_tools' any tool needed to verify the goal. It will be executed to update the blackboard before + checking the success predicates. This applies to both modes. - When using 'objective' mode, the 'success_predicates' must be simple boolean expressions over the blackboard. - If no tool is useful to verify the goal rely on 'perceptual' mode. -- When using 'perceptual' mode, provide a 'evidence_bb_keys' list containing blackboard keys with relevant evidence to consider. -- Use "{{{{bb.tool_name.key}}}}" to reference tools results. This MUST be used in 'success_predicates'. It can also be used in 'verify_tools' arguments. -For example, if tool 'get_pose' outputs {{"pose": [1.0, 2.0]}}, you can pass reference it as "{{{{bb.get_pose.pose}}}}". +- When using 'perceptual' mode, provide a 'evidence_bb_keys' list containing blackboard keys with relevant evidence + to consider. +- Use "{{{{bb.tool_name.key}}}}" to reference tools results. This MUST be used in 'success_predicates'. It can also + be used in 'verify_tools' arguments. +For example, if tool 'get_pose' outputs {{"pose": [1.0, 2.0]}}, you can pass reference it as +"{{{{bb.get_pose.pose}}}}". {user_context} ## Available tools: @@ -316,12 +339,14 @@ def _get_prompt_template(self) -> str: You are a iterative planner for a robotic/agent system. Your objective is to produce a plan that will make progress toward the user's goal in the current iteration. Rules: -- You will produce ONE actionable plan at a time. Prefer the simplest plan that makes tangible progress toward the goal. +- You will produce ONE actionable plan at a time. Prefer the simplest plan that makes tangible progress toward the + goal. - You must follow the logic execute, observe, and then re-plan if needed. 
- System context will be provided as a blackboard that contains the current state of the system as outputs of tools. - Keep the plan minimal and focused on the next immediate progress. - CONTEXT will be provided about the current state of the system and previous iterations. -- If the last step failed, propose a different approach or tool/arguments and include a brief rationale of what will be done differently (as a comment-like line in the Summary). +- If the last step failed, propose a different approach or tool/arguments and include a brief rationale of what will + be done differently (as a comment-like line in the Summary). - Add only optional execution control parameters if strictly necessary or requested by the user. - Use "{{{{bb.tool_name.key}}}}" to pass outputs from previous steps if relevant. For example, if tool 'detect_object' outputs {{"pose": [1.0, 2.0]}}, @@ -406,8 +431,8 @@ def _init_single_tool_plan(self): steps=[ Step( tool="name_of_tool", # Placeholder tool name - args=[], # Placeholder args - timeout_ms=self.step_timeout_ms + args=[], # Placeholder args + timeout_ms=self.step_timeout_ms, ), ], ), @@ -443,7 +468,10 @@ def _verify_progress(self) -> bool: elif mode == "perceptual": # Check if goal is already achieved using AI inference - instruction = f"Check if the user goal has been achieved based on the evidence provided. User goal: '{self.goal.summary}'" + instruction = ( + f"Check if the user goal has been achieved based on the evidence provided." 
+ f" User goal: '{self.goal.summary}'" + ) evidence_keys = self.goal.evidence_bb_keys # Gather evidence from blackboard evidence = {} @@ -453,12 +481,16 @@ def _verify_progress(self) -> bool: tool = self.registry.tools.get(verify_tool.tool) if tool: evidence[tool.name] = self.bb.get(tool.name, {}) - # Note that provide_images is only available for ValidationTool, needs to be checked after verifying its a validation tool + # Note that provide_images is only available for ValidationTool, needs to be checked after verifying + # its a validation tool if tool and tool.is_validation_tool and tool.provide_images: images_paths.extend(self.bb.get(tool.name, {}).get("images", [])) # Build prompt to check if goal is achieved system_prompt, user_prompt = self._build_validation_prompt(instruction, evidence) - self.logger.log_manager(f"Running perceptual verification with instruction: {instruction} - Evidence: {evidence_keys} - Images: {images_paths}") + self.logger.log_manager( + f"Running perceptual verification with instruction: {instruction} - Evidence: {evidence_keys}" + f" - Images: {images_paths}" + ) try: validation = self.llm.inference_validation(system_prompt, user_prompt, images_paths, self.history) @@ -482,7 +514,7 @@ def _run_verification_tools(self): it is the first iteration (to set initial state). """ if self._timeline and len(self._timeline) > 0: - last_event = self._timeline[-1].get('event', "N/A") + last_event = self._timeline[-1].get("event", "N/A") if last_event not in (TimelineEvent.PLAN_EXECUTED.value, TimelineEvent.GOAL_SET.value): self.logger.log_manager("Skipping verification tools as no plan has been executed in this iteration.") return @@ -500,7 +532,10 @@ def _run_verification_tools(self): result = self.execute_plan(self._single_tool_plan) if not result.get("success"): self.bb[tool_name] = {"error": "Validation tool execution failed"} - self.logger.log_manager(f"Error running verification tool '{tool_name}'. 
BB entry of this tool will be empty.", error=True) + self.logger.log_manager( + f"Error running verification tool '{tool_name}'. BB entry of this tool will be empty.", + error=True, + ) except Exception as e: self.logger.log_manager(f"Error running verification tool '{tool_name}': {e}", error=True) continue @@ -510,8 +545,8 @@ def _add_to_history(self, user_text: str, plan_summary: str): last_detail = "" last_event = "N/A" if self._timeline and len(self._timeline) > 0: - last_event = self._timeline[-1].get('event', "N/A") - if self._timeline[-1].get('detail', None): + last_event = self._timeline[-1].get("event", "N/A") + if self._timeline[-1].get("detail", None): last_detail = f" - Details: {self._timeline[-1].get('detail', '')}" history_plan = f"Plan summary: {plan_summary} - Iteration result: {last_event}{last_detail}" @@ -521,4 +556,4 @@ def _add_to_history(self, user_text: str, plan_summary: str): if self.history_depth <= 0: self.history = [] else: - self.history = self.history[-self.history_depth:] + self.history = self.history[-self.history_depth :] diff --git a/src/vulcanai/core/manager_plan.py b/src/vulcanai/core/manager_plan.py index 14ce755..78b3cbb 100644 --- a/src/vulcanai/core/manager_plan.py +++ b/src/vulcanai/core/manager_plan.py @@ -18,17 +18,19 @@ from vulcanai.core.validator import PlanValidator from vulcanai.tools.tool_registry import ToolRegistry + class PlanManager(ToolManager): """Manager to target complex plan generation, involving multiples plans and steps.""" + def __init__( - self, - model: str, - registry: Optional[ToolRegistry]=None, - validator: Optional[PlanValidator]=None, - k: int=5, - hist_depth: int = 3, - logger=None - ): + self, + model: str, + registry: Optional[ToolRegistry] = None, + validator: Optional[PlanValidator] = None, + k: int = 5, + hist_depth: int = 3, + logger=None, + ): super().__init__(model, registry=registry, validator=validator, k=k, hist_depth=hist_depth, logger=logger) def _get_prompt_template(self) -> str: 
diff --git a/src/vulcanai/core/plan_types.py b/src/vulcanai/core/plan_types.py index 678e019..bdf915f 100644 --- a/src/vulcanai/core/plan_types.py +++ b/src/vulcanai/core/plan_types.py @@ -12,22 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pydantic import BaseModel, Field from typing import List, Literal, Optional, Union +from pydantic import BaseModel, Field + from vulcanai.console.logger import VulcanAILogger -Kind = Literal["SEQUENCE","PARALLEL"] +Kind = Literal["SEQUENCE", "PARALLEL"] class ArgValue(BaseModel): """Key-value pair representing a tool argument.""" + key: str val: Union[str, int, float, bool] class StepBase(BaseModel): """Atomic execution unit bound to a ITool.""" + # Associated tool tool: str = None # Tool arguments @@ -39,6 +42,7 @@ class StepBase(BaseModel): class Step(StepBase): """Final atomic execution unit bound to a ITool with execution control.""" + # Execution control condition: Optional[str] = None success_criteria: Optional[str] = None @@ -51,6 +55,7 @@ class PlanBase(BaseModel): A base class that defines a plan which is composed of one or more steps and has execution control parameters. """ + # SEQUENCE or PARALLEL execution of children kind: Kind # Child nodes @@ -66,11 +71,13 @@ class PlanNode(PlanBase): """ Final plan node with optional failure handling. """ + on_fail: Optional["PlanBase"] = None class GlobalPlan(BaseModel): """GlobalPlan returned by the LLM with each step to be executed.""" + # Top-level plan structure. Always executed sequentially. 
plan: List[PlanNode] = Field(default_factory=list) # Brief summary of the plan @@ -88,59 +95,84 @@ def __str__(self) -> str: for i, node in enumerate(self.plan, 1): # - PlanNode : kind= - lines.append(f"- PlanNode {i}: <{color_variable}>kind=" + \ - f"<{color_value}>{node.kind}") + lines.append( + f"- PlanNode {i}: <{color_variable}>kind=" + + f"<{color_value}>{node.kind}" + ) if node.condition: - # Condition: + # Condition: lines.append(f"\tCondition: <{color_value}>{node.condition}") if node.retry: - # Retry: - lines.append(f"\t<{color_error}>Retry: " + \ - f"<{color_value}>{node.retry}") + # Retry: + lines.append( + f"\t<{color_error}>Retry: " + f"<{color_value}>{node.retry}" + ) if node.timeout_ms: - # Timeout: ms - lines.append(f"\t<{color_error}>Timeout: " + \ - f"<{color_value}>{node.timeout_ms} ms") + # Timeout: ms + lines.append( + f"\t<{color_error}>Timeout: " + + f"<{color_value}>{node.timeout_ms} ms" + ) if node.success_criteria: - # Succes Criteria: - lines.append(f"\<{color_tool}>tSuccess Criteria: " + \ - f"<{color_value}>{node.success_criteria}") + # Succes Criteria: + lines.append( + f"\<{color_tool}>tSuccess Criteria: " + + f"<{color_value}>{node.success_criteria}" + ) if node.on_fail: - # On Fail: with steps - lines.append(f"\tOn Fail: <{color_value}>{node.on_fail.kind} with " + \ - f"<{color_value}>{len(node.on_fail.steps)} steps") + # On Fail: with steps + lines.append( + f"\tOn Fail: <{color_value}>{node.on_fail.kind} with " + + f"<{color_value}>{len(node.on_fail.steps)} steps" + ) for j, step in enumerate(node.steps, 1): - #arg_str: =, ..., = - arg_str = ", ".join([f"<{color_variable}>{a.key}=" + \ - f"<{color_value}>{a.val}" for a in step.args]) \ - if step.args else f"<{color_value}>no args" - # Step : () + # arg_str: =, ..., = + arg_str = ( + ", ".join( + [ + f"<{color_variable}>{a.key}=" + + f"<{color_value}>{a.val}" + for a in step.args + ] + ) + if step.args + else f"<{color_value}>no args" + ) + # Step : () lines.append(f"\tStep {j}: 
<{color_tool}>{step.tool}({arg_str})") if step.condition: - # Condition: + # Condition: lines.append(f"\t Condition: <{color_value}>{step.condition}") if step.retry: - # Condition: - lines.append(f"\t <{color_error}>Retry: " + \ - f"<{color_value}>{step.retry}") + # Condition: + lines.append( + f"\t <{color_error}>Retry: " + f"<{color_value}>{step.retry}" + ) if step.timeout_ms: - # Timeout: ms - lines.append(f"\t <{color_error}>Timeout: " + \ - f"<{color_value}>{step.timeout_ms} ms") + # Timeout: ms + lines.append( + f"\t <{color_error}>Timeout: " + + f"<{color_value}>{step.timeout_ms} ms" + ) if step.success_criteria: - # Success Criteria: - lines.append(f"\t <{color_tool}>Success Criteria: " + \ - f"<{color_value}>{step.success_criteria}") + # Success Criteria: + lines.append( + f"\t <{color_tool}>Success Criteria: " + + f"<{color_value}>{step.success_criteria}" + ) return "\n".join(lines) class GoalSpec(BaseModel): """Specification defining the user goal to be achieved.""" + summary: str # Mode used for goal verification: - # - [Perceptual] mode uses AI to verify if the goal has been achieved based on evidence (e.g., images) or blackboard data. - # - [Objective] mode uses deterministic predicates based on validation tools results to verify if the goal has been achieved. + # - [Perceptual] mode uses AI to verify if the goal has been achieved based on evidence (e.g., images) or + # blackboard data. + # - [Objective] mode uses deterministic predicates based on validation tools results to verify if the goal + # has been achieved. 
mode: Literal["perceptual", "objective"] = "objective" # List of simple boolean predicates over the blackboard (e.g., "{{bb.navigation.at_target}} == true") success_predicates: List[str] = Field(default_factory=list) @@ -167,6 +199,7 @@ def __str__(self) -> str: class AIValidation(BaseModel): """AI-based validation result.""" + success: bool confidence: float explanation: Optional[str] = None diff --git a/src/vulcanai/core/validator.py b/src/vulcanai/core/validator.py index 8ed4954..9cf5122 100644 --- a/src/vulcanai/core/validator.py +++ b/src/vulcanai/core/validator.py @@ -13,20 +13,15 @@ # limitations under the License. import re + from vulcanai.core.plan_types import GlobalPlan, PlanNode, Step -TYPE_ALIAS = { - "int": int, - "integer": int, - "float": float, - "bool": bool, - "boolean": bool, - "str": str, - "string": str -} +TYPE_ALIAS = {"int": int, "integer": int, "float": float, "bool": bool, "boolean": bool, "str": str, "string": str} + class PlanValidator: """Validates and optionally augments a plan before execution.""" + def __init__(self, registry): self.registry = registry @@ -60,7 +55,10 @@ def _validate_step(self, step: Step): tool = self.registry.tools.get(step.tool) if tool.input_schema: if len(step.args) != len(tool.input_schema): - raise ValueError(f"Tool '{tool.name}' expects {len(tool.input_schema)} arguments, but {len(step.args)} were provided.") + raise ValueError( + f"Tool '{tool.name}' expects {len(tool.input_schema)} arguments," + f" but {len(step.args)} were provided." 
+ ) for arg in step.args: if arg.key not in {k for d in tool.input_schema for k in d}: raise ValueError(f"Argument '{arg.key}' not defined in tool '{tool.name}' input schema.") @@ -71,14 +69,21 @@ def _validate_step(self, step: Step): if bb_items: bb_correct_format = re.findall(r"\{\{(bb\..*?)\}\}", arg.val) if len(bb_items) != len(bb_correct_format): - raise ValueError(f"Blackboard reference in argument '{arg.key}' of tool '{tool.name}' is incorrectly formatted: '{arg.val}'") + raise ValueError( + f"Blackboard reference in argument '{arg.key}' of tool '{tool.name}'" + f" is incorrectly formatted: '{arg.val}'" + ) else: - # If there are no blackboard references, ensure the argument is of string type, as it must adhere to the schema + # If there are no blackboard references, ensure the argument is of string type, as it must + # adhere to the schema for schema in tool.input_schema: if arg.key in schema: is_string_type = schema[1] in ["str", "string"] if not is_string_type: - raise ValueError(f"Argument '{arg.key}' of tool '{tool.name}' expects type '{TYPE_ALIAS.get(schema[1])}', but got '{type(arg.val).__name__}'.") + raise ValueError( + f"Argument '{arg.key}' of tool '{tool.name}' expects type" + f" '{TYPE_ALIAS.get(schema[1])}', but got '{type(arg.val).__name__}'." 
+ ) # Check type if static value for non-string types else: expected_type = None @@ -87,13 +92,18 @@ def _validate_step(self, step: Step): print(f"Schema for arg '{arg.key}' in tool '{tool.name}': {schema}") # Debug print # Use TYPE_ALIAS to map string type names to actual types expected_type = TYPE_ALIAS.get(schema[1]) - print(f"Expected type for arg '{arg.key}' in tool '{tool.name}': {expected_type}") # Debug print + print( + f"Expected type for arg '{arg.key}' in tool '{tool.name}': {expected_type}" + ) # Debug print break if expected_type and not isinstance(arg.val, expected_type): - if expected_type == float and isinstance(arg.val, int): + if expected_type is float and isinstance(arg.val, int): # Allow int to float conversion continue - raise ValueError(f"Argument '{arg.key}' of tool '{tool.name}' expects type '{expected_type.__name__}', but got '{type(arg.val).__name__}'.") + raise ValueError( + f"Argument '{arg.key}' of tool '{tool.name}' expects type '{expected_type.__name__}'," + f" but got '{type(arg.val).__name__}'." 
+ ) else: if step.args: raise ValueError(f"Tool '{tool.name}' does not accept arguments, but arguments were provided.") diff --git a/src/vulcanai/models/__init__.py b/src/vulcanai/models/__init__.py index 62724bb..7da8ce0 100644 --- a/src/vulcanai/models/__init__.py +++ b/src/vulcanai/models/__init__.py @@ -22,17 +22,19 @@ __all__ = list(_EXPORTS.keys()) + def __getattr__(name: str): target = _EXPORTS.get(name) if not target: raise AttributeError(f"module '{__name__}' has no attribute '{name}'") - module_name, attr_name = target.split(':') + module_name, attr_name = target.split(":") if module_name.startswith("."): module = import_module(module_name, package=__name__) else: module = import_module(module_name) return getattr(module, attr_name) + def __dir__() -> list[str]: """Make dir() show the public API.""" return sorted(list(globals().keys()) + __all__) diff --git a/src/vulcanai/models/gemini.py b/src/vulcanai/models/gemini.py index c7baef6..abef3d5 100644 --- a/src/vulcanai/models/gemini.py +++ b/src/vulcanai/models/gemini.py @@ -12,23 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from google import genai -from google.genai import types as gtypes -from typing import Any, Dict, Iterable, Optional, Type, TypeVar import mimetypes import os import time +from typing import Iterable, Optional, Type, TypeVar + +from google import genai +from google.genai import types as gtypes from vulcanai.core.plan_types import AIValidation, GlobalPlan, GoalSpec from vulcanai.models.model import IModel, IModelHooks -T = TypeVar('T', GlobalPlan, GoalSpec, AIValidation) +T = TypeVar("T", GlobalPlan, GoalSpec, AIValidation) class GeminiModel(IModel): + """Wrapper for most of Google models, Gemini mainly.""" - """ Wrapper for most of Google models, Gemini mainly. 
""" - def __init__(self, model_name:str, logger=None, hooks: Optional[IModelHooks] = None): + def __init__(self, model_name: str, logger=None, hooks: Optional[IModelHooks] = None): super().__init__() self.logger = logger self.model_name = model_name @@ -76,7 +77,7 @@ def _inference( # Notify hooks of request start try: self.hooks.on_request_start() - except Exception as e: + except Exception: pass response = self.model.models.generate_content( @@ -90,13 +91,14 @@ def _inference( try: parsed_response = response.parsed except Exception as e: - self.logger.log_manager(f"ERROR. Failed to get parsed goal from Gemini response, " + \ - f"falling back to text: {e}", error=True) + self.logger.log_manager( + "ERROR. Failed to get parsed goal from Gemini response, " + f"falling back to text: {e}", error=True + ) finally: # Notify hooks of request end try: self.hooks.on_request_end() - except Exception as e: + except Exception: pass # Fallback to get GoalSpec from text if the parsed field is not available @@ -109,25 +111,28 @@ def _inference( try: parsed_response = GoalSpec.model_validate_json(raw) - except Exception as e: + except Exception: try: import json + parsed_response = GoalSpec(**json.loads(raw)) except Exception as e: - self.logger.log_manager(f"ERROR. Failed to parse raw {response_cls.__name__} JSON: {e}", - error=True) + self.logger.log_manager( + f"ERROR. 
Failed to parse raw {response_cls.__name__} JSON: {e}", error=True + ) end = time.time() self.logger.log_manager(f"Gemini response time: {end - start:.3f} seconds") usage = getattr(response, "usage_metadata", None) if usage: input_tokens = usage.prompt_token_count output_tokens = usage.candidates_token_count - self.logger.log_manager(f"Prompt tokens: [manager]{input_tokens}[/manager], " + \ - f"Completion tokens: [manager]{output_tokens}[/manager]") + self.logger.log_manager( + f"Prompt tokens: [manager]{input_tokens}[/manager], " + + f"Completion tokens: [manager]{output_tokens}[/manager]" + ) return parsed_response - def _build_user_content(self, user_text: str, images: Optional[Iterable[str]]) -> list[gtypes.Part]: """Compose user content list with text first and optional images as image_url parts.""" content: list[gtypes.Part] = [gtypes.Part.from_text(text=user_text)] @@ -135,28 +140,34 @@ def _build_user_content(self, user_text: str, images: Optional[Iterable[str]]) - for image_path in images: if isinstance(image_path, str) and image_path.startswith("http"): import requests + img = requests.get(image_path) if img.status_code != 200: - self.logger.log_manager(f"ERROR. Failed to fetch image from URL '{image_path}' ", - error=True) + self.logger.log_manager(f"ERROR. 
Failed to fetch image from URL '{image_path}' ", error=True) continue - content.append(gtypes.Part.from_bytes(data=img.content, mime_type=img.headers.get("Content-Type", "image/png"))) + content.append( + gtypes.Part.from_bytes( + data=img.content, mime_type=img.headers.get("Content-Type", "image/png") + ) + ) else: try: img_bytes = self._read_image(image_path) mime = mimetypes.guess_type(image_path)[0] or "image/png" - content.append(gtypes.Part.from_bytes( - data=img_bytes, - mime_type=mime, - )) + content.append( + gtypes.Part.from_bytes( + data=img_bytes, + mime_type=mime, + ) + ) except Exception as e: # Fail soft on a single bad image but continue with others - self.logger.log_manager(f"Fail soft. Image '{image_path}' could not be encoded: {e}", - error=True) + self.logger.log_manager( + f"Fail soft. Image '{image_path}' could not be encoded: {e}", error=True + ) return content - def _build_messages( self, user_content: list[gtypes.Part], @@ -169,7 +180,11 @@ def _build_messages( if history: for user_text, plan_summary in history: messages.append(gtypes.Content(role="user", parts=[gtypes.Part.from_text(text=user_text)])) - messages.append(gtypes.Content(role="assistant", parts=[gtypes.Part.from_text(text=f"Action plan: {plan_summary}")])) + messages.append( + gtypes.Content( + role="assistant", parts=[gtypes.Part.from_text(text=f"Action plan: {plan_summary}")] + ) + ) # Append current user turn (text + images) messages.append(gtypes.Content(role="user", parts=user_content)) diff --git a/src/vulcanai/models/model.py b/src/vulcanai/models/model.py index 6db20f2..410b190 100644 --- a/src/vulcanai/models/model.py +++ b/src/vulcanai/models/model.py @@ -13,17 +13,17 @@ # limitations under the License. 
import base64 -from typing import Any, Optional, TypeVar from abc import ABC, abstractmethod +from typing import Any, Optional, TypeVar from vulcanai.core.plan_types import AIValidation, GlobalPlan, GoalSpec - -T = TypeVar('T', GlobalPlan, GoalSpec, AIValidation) +T = TypeVar("T", GlobalPlan, GoalSpec, AIValidation) class IModel(ABC): """Abstract class for models.""" + # Model instance model: Any = None # Model name @@ -34,11 +34,7 @@ class IModel(ABC): hooks: Any = None def plan_inference( - self, - system_prompt: str, - user_prompt: str, - images: list[str], - history: list[tuple[str, str]] + self, system_prompt: str, user_prompt: str, images: list[str], history: list[tuple[str, str]] ) -> Optional[GlobalPlan]: """ Call the generic inference with GlobalPlan as response type. @@ -54,14 +50,11 @@ def plan_inference( user_prompt=user_prompt, response_cls=GlobalPlan, images=images, - history=history + history=history, ) def goal_inference( - self, - system_prompt: str, - user_prompt: str, - history: list[tuple[str, str]] + self, system_prompt: str, user_prompt: str, history: list[tuple[str, str]] ) -> Optional[GoalSpec]: """ Call the generic inference with GoalSpec as response type (no images). @@ -72,19 +65,11 @@ def goal_inference( :return: Parsed response object of type GoalSpec, or None on error. """ return self._inference( - system_prompt=system_prompt, - user_prompt=user_prompt, - response_cls=GoalSpec, - images=None, - history=history + system_prompt=system_prompt, user_prompt=user_prompt, response_cls=GoalSpec, images=None, history=history ) def validation_inference( - self, - system_prompt: str, - user_prompt: str, - images: list[str], - history: list[tuple[str, str]] + self, system_prompt: str, user_prompt: str, images: list[str], history: list[tuple[str, str]] ) -> Optional[AIValidation]: """ Call the generic inference with AIValidation as response type (no history). 
@@ -99,7 +84,7 @@ def validation_inference( user_prompt=user_prompt, response_cls=AIValidation, images=images, - history=history + history=history, ) @abstractmethod @@ -117,7 +102,9 @@ def _encode_image(self, image_path: str) -> str: class IModelHooks(ABC): """No-op base hooks for LLM activity.""" + def on_request_start(self) -> None: pass + def on_request_end(self) -> None: pass diff --git a/src/vulcanai/models/ollama_model.py b/src/vulcanai/models/ollama_model.py index 7e9877f..0f1a525 100644 --- a/src/vulcanai/models/ollama_model.py +++ b/src/vulcanai/models/ollama_model.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import ollama -from typing import Any, Dict, Iterable, Optional,Type, TypeVar import time +from typing import Any, Dict, Iterable, Optional, Type, TypeVar + +import ollama from vulcanai.core.plan_types import AIValidation, GlobalPlan, GoalSpec from vulcanai.models.model import IModel, IModelHooks # Generic type variable for response classes -T = TypeVar('T', GlobalPlan, GoalSpec, AIValidation) +T = TypeVar("T", GlobalPlan, GoalSpec, AIValidation) class OllamaModel(IModel): + """Wrapper for Ollama models.""" - """ Wrapper for Ollama models. """ def __init__(self, model_name: str, logger=None, hooks: Optional[IModelHooks] = None): super().__init__() self.logger = logger @@ -66,7 +67,7 @@ def _inference( # Notify hooks of request start try: self.hooks.on_request_start() - except Exception as e: + except Exception: pass # Call Ollama with response_format bound to the desired schema/class @@ -75,7 +76,7 @@ def _inference( model=self.model_name, messages=messages, format=response_cls.model_json_schema(), - options={"temperature": 0.1} + options={"temperature": 0.1}, ) except Exception as e: self.logger.log_manager(f"ERROR. 
Ollama API: {e}", error=True) @@ -84,7 +85,7 @@ def _inference( # Notify hooks of request end try: self.hooks.on_request_end() - except Exception as e: + except Exception: pass # Extract parsed object safely @@ -92,16 +93,17 @@ def _inference( try: parsed = response_cls.model_validate_json(completion.message.content) except Exception as e: - self.logger.log_manager(f"ERROR. Failed to parse response into {response_cls.__name__}: {e}", - error=True) + self.logger.log_manager(f"ERROR. Failed to parse response into {response_cls.__name__}: {e}", error=True) end = time.time() self.logger.log_manager(f"Ollama response time: {end - start:.3f} seconds") try: input_tokens = completion.prompt_eval_count output_tokens = completion.eval_count - self.logger.log_manager(f"Prompt tokens: [manager]{input_tokens}[/manager], " + \ - f"Completion tokens: [manager]{output_tokens}[/manager]") + self.logger.log_manager( + f"Prompt tokens: [manager]{input_tokens}[/manager], " + + f"Completion tokens: [manager]{output_tokens}[/manager]" + ) except Exception: pass @@ -119,8 +121,7 @@ def _build_user_content(self, user_text: str, images: Optional[Iterable[str]]) - except Exception as e: # Fail soft on a single bad image but continue with others - self.logger.log_manager(f"Fail soft. Image '{image_path}' could not be encoded: {e}", - error=True) + self.logger.log_manager(f"Fail soft. Image '{image_path}' could not be encoded: {e}", error=True) if encoded_images: content["images"] = encoded_images return content diff --git a/src/vulcanai/models/openai.py b/src/vulcanai/models/openai.py index 172e434..34f3f32 100644 --- a/src/vulcanai/models/openai.py +++ b/src/vulcanai/models/openai.py @@ -12,20 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from openai import OpenAI -from typing import Any, Dict, Iterable, Optional, Type, TypeVar import mimetypes import time +from typing import Any, Dict, Iterable, Optional, Type, TypeVar + +from openai import OpenAI from vulcanai.core.plan_types import AIValidation, GlobalPlan, GoalSpec from vulcanai.models.model import IModel, IModelHooks # Generic type variable for response classes -T = TypeVar('T', GlobalPlan, GoalSpec, AIValidation) +T = TypeVar("T", GlobalPlan, GoalSpec, AIValidation) + class OpenAIModel(IModel): + """Wrapper for OpenAI models.""" - """ Wrapper for OpenAI models. """ def __init__(self, model_name: str, logger=None, hooks: Optional[IModelHooks] = None): super().__init__() self.logger = logger @@ -66,7 +68,7 @@ def _inference( # Notify hooks of request start try: self.hooks.on_request_start() - except Exception as e: + except Exception: pass # Call OpenAI with response_format bound to the desired schema/class @@ -83,7 +85,7 @@ def _inference( # Notify hooks of request end try: self.hooks.on_request_end() - except Exception as e: + except Exception: pass # Extract parsed object safely @@ -98,8 +100,10 @@ def _inference( try: input_tokens = completion.usage.prompt_tokens output_tokens = completion.usage.completion_tokens - self.logger.log_manager(f"Prompt tokens: [manager]{input_tokens}[/manager], " + \ - f"Completion tokens: [manager]{output_tokens}[/manager]") + self.logger.log_manager( + f"Prompt tokens: [manager]{input_tokens}[/manager], " + + f"Completion tokens: [manager]{output_tokens}[/manager]" + ) except Exception: pass @@ -116,14 +120,18 @@ def _build_user_content(self, user_text: str, images: Optional[Iterable[str]]) - try: base64_image = self._encode_image(image_path) mime = mimetypes.guess_type(image_path)[0] or "image/png" - content.append({ - "type": "image_url", - "image_url": {"url": f"data:{mime};base64,{base64_image}"}, - }) + content.append( + { + "type": "image_url", + "image_url": {"url": 
f"data:{mime};base64,{base64_image}"}, + } + ) except Exception as e: # Fail soft on a single bad image but continue with others - self.logger.log_manager(f"Fail soft. Image '{image_path}' could not be encoded: {e}", error=True) + self.logger.log_manager( + f"Fail soft. Image '{image_path}' could not be encoded: {e}", error=True + ) return content def _build_messages( diff --git a/src/vulcanai/tools/__init__.py b/src/vulcanai/tools/__init__.py index 6e23f43..49b6754 100644 --- a/src/vulcanai/tools/__init__.py +++ b/src/vulcanai/tools/__init__.py @@ -24,17 +24,19 @@ __all__ = list(_EXPORTS.keys()) + def __getattr__(name: str): target = _EXPORTS.get(name) if not target: raise AttributeError(f"module '{__name__}' has no attribute '{name}'") - module_name, attr_name = target.split(':') + module_name, attr_name = target.split(":") if module_name.startswith("."): module = import_module(module_name, package=__name__) else: module = import_module(module_name) return getattr(module, attr_name) + def __dir__() -> list[str]: """Make dir() show the public API.""" return sorted(list(globals().keys()) + __all__) diff --git a/src/vulcanai/tools/embedder.py b/src/vulcanai/tools/embedder.py index 8328a6d..d8c4e2c 100644 --- a/src/vulcanai/tools/embedder.py +++ b/src/vulcanai/tools/embedder.py @@ -15,6 +15,7 @@ import numpy as np from sentence_transformers import SentenceTransformer + class SBERTEmbedder: def __init__(self, model_name="all-MiniLM-L6-v2"): self.model = SentenceTransformer(model_name) @@ -27,5 +28,5 @@ def raw_embed(self, text: str): return self.model.encode(text, convert_to_numpy=False) def similarity(self, vecs1, vecs2): - """ Compute cosine similarity between two sets of vectors. 
""" + """Compute cosine similarity between two sets of vectors.""" return self.model.similarity(vecs1, vecs2) diff --git a/src/vulcanai/tools/tool_registry.py b/src/vulcanai/tools/tool_registry.py index 0807deb..9a0cf4e 100644 --- a/src/vulcanai/tools/tool_registry.py +++ b/src/vulcanai/tools/tool_registry.py @@ -13,16 +13,17 @@ # limitations under the License. import importlib -import numpy as np import sys from importlib.metadata import entry_points from pathlib import Path from types import ModuleType from typing import Dict, List, Tuple, Type -from vulcanai.tools.embedder import SBERTEmbedder +import numpy as np + from vulcanai.console.logger import VulcanAILogger -from vulcanai.tools.tools import ITool, CompositeTool +from vulcanai.tools.embedder import SBERTEmbedder +from vulcanai.tools.tools import CompositeTool, ITool def vulcanai_tool(cls: Type[ITool]): @@ -35,9 +36,12 @@ def vulcanai_tool(cls: Type[ITool]): class HelpTool(ITool): """A tool that provides help information.""" + name = "help" - description = "Provides help information for using the library. It can list all available tools or" \ - " give info about the usage of a specific tool if 'tool_name' is provided as an argument." + description = ( + "Provides help information for using the library. It can list all available tools or" + " give info about the usage of a specific tool if 'tool_name' is provided as an argument." 
+ ) tags = ["help", "info", "documentation", "usage", "developer", "manual", "available tools"] input_schema = [("tool", "string")] output_schema = {"info": "str"} @@ -72,8 +76,8 @@ def run(self, **kwargs): class ToolRegistry: - """Holds all known tools and performs vector search over metadata.""" + def __init__(self, embedder=None, logger=None): # Logging function from the class VulcanConsole self.logger = logger or VulcanAILogger.default() @@ -118,8 +122,9 @@ def activate_tool(self, tool_name) -> bool: return False # Check if the tool is deactivated if tool_name not in self.deactivated_tools: - self.logger.log_registry(f"Tool [registry]'{tool_name}'[/registry] " + \ - f"not found in the deactivated tools list.", error=True) + self.logger.log_registry( + f"Tool [registry]'{tool_name}'[/registry] " + "not found in the deactivated tools list.", error=True + ) return False # Add the tool to the active tools @@ -137,8 +142,9 @@ def deactivate_tool(self, tool_name) -> bool: return False # Check if the tool is active if tool_name not in self.tools: - self.logger.log_registry(f"Tool [registry]'{tool_name}'[/registry] "+ \ - f"not found in the active tools list.", error=True) + self.logger.log_registry( + f"Tool [registry]'{tool_name}'[/registry] " + "not found in the active tools list.", error=True + ) return False # Add the tool to the deactivated tools @@ -171,11 +177,12 @@ def _resolve_dependencies(self, tool: CompositeTool): for dep_name in tool.dependencies: dep_tool = self.tools.get(dep_name) if dep_tool is None: - self.logger.log_registry(f"ERROR. Dependency '{dep_name}' for tool '{tool.name}' not found.", error=True) + self.logger.log_registry( + f"ERROR. 
Dependency '{dep_name}' for tool '{tool.name}' not found.", error=True + ) else: tool.resolved_deps[dep_name] = dep_tool - def _load_tools_from_file(self, path: str): """Dynamically load a Python file with @vulcanai_tool classes.""" try: @@ -191,6 +198,7 @@ def _load_tools_from_file(self, path: str): self._loaded_modules.append(module) except Exception as e: self.logger.log_registry(f"Could not load tools from {path}: {e}", error=True) + def discover_tools_from_file(self, path: str): """Load tools from a Python file and register them.""" self._load_tools_from_file(path) @@ -227,8 +235,7 @@ def top_k(self, query: str, k: int = 5, validation: bool = False) -> list[ITool] if not active_names: # If there is no tool for the requested category, be explicit and return [] self.logger.log_registry( - f"No matching tools for the requested mode ({'validation' if validation else 'action'}).", - error=True + f"No matching tools for the requested mode ({'validation' if validation else 'action'}).", error=True ) return [] # If k > number of ALL tools, return required tools @@ -239,8 +246,7 @@ def top_k(self, query: str, k: int = 5, validation: bool = False) -> list[ITool] if not filtered_index: # Index might be stale; log and return [] - self.logger.log_registry("Index has no entries for the selected tool subset.", - error=True) + self.logger.log_registry("Index has no entries for the selected tool subset.", error=True) return [] # If k > number of required tools, return required tools if k > len(filtered_index): diff --git a/src/vulcanai/tools/tools.py b/src/vulcanai/tools/tools.py index aaa1c10..dc1b325 100644 --- a/src/vulcanai/tools/tools.py +++ b/src/vulcanai/tools/tools.py @@ -21,6 +21,7 @@ class ITool(ABC): Abstract class containing base metadata every tool must provide. All tools must inherit from this interface to ensure consistency during LLMs calls. 
""" + # Name given to the tool name: str # Brief description of the tool's purpose @@ -49,13 +50,16 @@ def run(self, **kwargs) -> Dict[str, Any]: class AtomicTool(ITool): """Atomic tool with a single capability.""" + pass + class CompositeTool(ITool): """ Composite tool used to define more complex actions. It reuses existing tools and their capabilities, which must be listed as dependencies. """ + # Names of tools this composite tool depends on dependencies: List[str] = [] # Resolved tool instances (injected at execution time) @@ -65,15 +69,18 @@ def __init__(self): super().__init__() self.resolved_deps = {} + class ValidationTool(ITool): """ Atomic Validation tool with a single capability. - As a general rule, Validation tools are responsible of the feedback system used to check if the final goal has been achieved. + As a general rule, Validation tools are responsible of the feedback system used to check if the final goal + has been achieved. They must return retrieved data following the output schema, to ensure the LLM can parse it correctly. If data is not returned, the framework will not be able to update the blackboard with the new information and the generated prompt might contain outdated information. 
""" + is_validation_tool: bool = True # If True, the tool must provide images as its output under the key 'images' in format List[str] provide_images: bool = False diff --git a/tests/integration_tests/test_iterative_manager.py b/tests/integration_tests/test_iterative_manager.py index f1d548d..b3470b1 100644 --- a/tests/integration_tests/test_iterative_manager.py +++ b/tests/integration_tests/test_iterative_manager.py @@ -59,10 +59,12 @@ def change_output(self, new_output): class FakeStepBase: """Replica of StepBase class with tool name and args.""" + def __init__(self, tool: str, args: list[ArgValue]): self.tool = tool self.args = args + class DummyImageValidationTool(ValidationTool): def __init__(self, name, description="desc", input_schema=None, output_schema=None, output="result"): self.name = name @@ -81,20 +83,14 @@ def change_output(self, new_output): class MockGoalSpec: def __init__( - self, - summary="goal", - mode="objective", - success_predicates=None, - verify_tools=None, - evidence_bb_keys=None - ): + self, summary="goal", mode="objective", success_predicates=None, verify_tools=None, evidence_bb_keys=None + ): self.summary = summary self.mode = mode self.success_predicates = success_predicates or [] self.verify_tools: list[FakeStepBase] = verify_tools or [] self.evidence_bb_keys = evidence_bb_keys or [] - def __str__(self) -> str: lines = [] if self.summary: @@ -120,6 +116,7 @@ def __init__(self, success=True, confidence=1.0, explanation=""): class FakeRegistry: """Mimics ToolRegistry enough for top_k() and tools{} lookups.""" + def __init__(self, tools=None): self.tools = {t.name: t for t in (tools or [])} @@ -130,6 +127,7 @@ def top_k(self, user_text, k, validation=False): class MockAgent: """Mock agent that records prompts passed to inference_plan() and inference_goal().""" + def __init__(self, plans=[], goal=None, validation=None, success_validation=0, logger=None): """ :param plans: list[GlobalPlan] to return on successive inference_plan() calls @@ 
-138,10 +136,14 @@ def __init__(self, plans=[], goal=None, validation=None, success_validation=0, l self.plans = list(plans) self.goal = goal self.validation = validation - self.inference_calls = [] # List of dicts: {"system": str, "user": str, "images": list, "history": list} - self.goal_calls = [] # Same for inference_goal - self.validation_calls = [] # Same for inference_validation. Inference validation is only called if goal is 'perceptual' - # If >0, number of successive validation calls needed to return success=True. Any call before that returns success=False. First call is 0. + # List of dicts: {"system": str, "user": str, "images": list, "history": list} + self.inference_calls = [] + # Same for inference_goal + self.goal_calls = [] + # Same for inference_validation. Inference validation is only called if goal is 'perceptual' + self.validation_calls = [] + # If >0, number of successive validation calls needed to return success=True. + # Any call before that returns success = False. First call is 0. 
self.success_validation = success_validation def inference_plan(self, system_prompt, user_prompt, images, history): @@ -158,7 +160,9 @@ def inference_goal(self, system_prompt, user_prompt, history): return self.goal def inference_validation(self, system_prompt, user_prompt, images, history): - self.validation_calls.append({"system": system_prompt, "user": user_prompt, "images": list(images), "history": list(history)}) + self.validation_calls.append( + {"system": system_prompt, "user": user_prompt, "images": list(images), "history": list(history)} + ) if self.success_validation > 0: self.success_validation -= 1 self.validation.success = False @@ -180,10 +184,12 @@ def make_single_step_plan(summary="plan", tool="dummy_tool", key="arg", val="x", ], ) + ################# ### Fixtures ################# + class ListSink: def __init__(self): self.lines = [] # list[str] @@ -191,6 +197,7 @@ def __init__(self): def write(self, msg: str, color: str = "") -> None: self.lines.append(msg) + @pytest.fixture def logger(): sink = ListSink() @@ -198,6 +205,7 @@ def logger(): log.set_sink(sink) return log + @pytest.fixture(autouse=True) def patch_core_symbols(monkeypatch): """Patch the exact Agent to avoid doing real inferences.""" @@ -206,6 +214,7 @@ def patch_core_symbols(monkeypatch): # Patch Agent so ToolManager(self.llm = Agent(...)) doesn't spin a real model monkeypatch.setattr(f"{target_mod}.Agent", MockAgent, raising=True) + @pytest.fixture def base_manager(logger, monkeypatch): # Build a minimal IterativeManager with mocked dependencies. 
The agent will be pathed @@ -228,6 +237,7 @@ def base_manager(logger, monkeypatch): # Initialize blackboard iteration counter mgr.bb["iteration"] = 0 original_execute_plan = mgr.execute_plan + # Each plan execution advances iteration by 1 and "fails" to force another iteration def exec_wrapper(_plan): ret = original_execute_plan(_plan) @@ -236,13 +246,16 @@ def exec_wrapper(_plan): for h in mgr._after_execute_hooks: h(_plan, ret) return ret + monkeypatch.setattr(mgr, "execute_plan", exec_wrapper, raising=True) return mgr + ################# ### Tests ################# + def test_prompts_reflect_bb_updates_across_iterations(base_manager): """ On iteration N+1, the user prompt must include the updated blackboard snapshot @@ -285,7 +298,7 @@ def test_prompts_reflect_bb_updates_across_iterations(base_manager): # After iter 1 execution, BB should contain dummy_tool result and must show up in iter 2 prompt assert "dummy_tool_1" not in first_user_prompt assert "dummy_tool_1" in second_user_prompt - assert "out1" in second_user_prompt # Output from dummy_tool_1 run in iter 1 + assert "out1" in second_user_prompt # Output from dummy_tool_1 run in iter 1 assert "out2" not in second_user_prompt # Output from dummy_tool_2 @@ -299,7 +312,9 @@ def test_validation_tools_are_called_before_each_iteration(base_manager): goal = MockGoalSpec( summary="test validation tools are called", mode="objective", - success_predicates=["{{bb.iteration}} >= 4"], # As we have validation tool, two bb.iterations are called per plan iteration (validation and plan) + success_predicates=[ + "{{bb.iteration}} >= 4" + ], # As we have validation tool, two bb.iterations are called per plan iteration (validation and plan) verify_tools=[FakeStepBase("dummy_validation", args=[])], ) @@ -313,9 +328,12 @@ def test_validation_tools_are_called_before_each_iteration(base_manager): # Capture plans passed to execute_plan to check if the timeout was applied captured_plans = [] + def capture_plan_hook(plan, 
exec_result): import copy + captured_plans.append(copy.deepcopy(plan)) + mgr._after_execute_hooks.append(capture_plan_hook) # Run @@ -346,7 +364,8 @@ def test_validation_tools_providing_images_are_handled_in_perceptual(base_manage Test that validation tools that provide images are correctly handled and their images are passed to the next inference_plan() call if the goal mode is 'perceptual'. 1) The validation tool must be called before each iteration. - 2) If the goal mode is 'perceptual', the images provided by the validation tool must be passed to the next inference_plan() call. + 2) If the goal mode is 'perceptual', the images provided by the validation tool must be passed to the next + inference_plan() call. """ mgr = base_manager @@ -354,7 +373,9 @@ def test_validation_tools_providing_images_are_handled_in_perceptual(base_manage goal = MockGoalSpec( summary="test validation tools providing images are correctly handled", mode="perceptual", - success_predicates=["{{bb.iteration}} >= 4"], # As we have validation tool, two bb.iterations are called per plan iteration (validation and plan) + success_predicates=[ + "{{bb.iteration}} >= 4" + ], # As we have validation tool, two bb.iterations are called per plan iteration (validation and plan) verify_tools=[FakeStepBase("dummy_image_validation", args=[])], ) @@ -414,9 +435,12 @@ def test_iterative_manager_applies_timeout_to_all_steps(base_manager): # Capture plans passed to execute_plan to check if the timeout was applied captured_plans = [] + def capture_plan_hook(plan, exec_result): import copy + captured_plans.append(copy.deepcopy(plan)) + mgr._after_execute_hooks.append(capture_plan_hook) # Run @@ -466,4 +490,6 @@ def test_repeated_plan_is_detected(base_manager): assert "timeline" in result # Ensure a PLAN_REPEATED event exists - assert any(e["event"] == TimelineEvent.PLAN_REPEATED.value for e in result["timeline"]), f"Timeline: {result['timeline']}" + assert any(e["event"] == TimelineEvent.PLAN_REPEATED.value for 
e in result["timeline"]), ( + f"Timeline: {result['timeline']}" + ) diff --git a/tests/resources/test_composite_tool.py b/tests/resources/test_composite_tool.py index 06baa73..4ee84f8 100644 --- a/tests/resources/test_composite_tool.py +++ b/tests/resources/test_composite_tool.py @@ -13,6 +13,8 @@ # limitations under the License. from vulcanai import CompositeTool, vulcanai_tool + + # Register dummy composite tool @vulcanai_tool class ComplexFileTool(CompositeTool): @@ -23,5 +25,6 @@ class ComplexFileTool(CompositeTool): output_schema = {"result": "bool"} version = "0.1" dependencies = ["file_tool", "new_file_tool", "other_file_tool"] + def run(self, **kwargs): return {"result": True} diff --git a/tests/resources/test_tools.py b/tests/resources/test_tools.py index e1b8392..53f10e4 100644 --- a/tests/resources/test_tools.py +++ b/tests/resources/test_tools.py @@ -14,6 +14,7 @@ from vulcanai import AtomicTool, ValidationTool, vulcanai_tool + # Register dummy tools @vulcanai_tool class FileTool(AtomicTool): @@ -23,7 +24,10 @@ class FileTool(AtomicTool): input_schema = [("x", "float"), ("y", "float"), ("z", "float")] output_schema = {"arrived": "bool"} version = "0.1" - def run(self, **kwargs): return {"arrived": True} + + def run(self, **kwargs): + return {"arrived": True} + @vulcanai_tool class NewFileTool(AtomicTool): @@ -33,7 +37,10 @@ class NewFileTool(AtomicTool): input_schema = [("label", "string")] output_schema = {"found": "bool"} version = "0.1" - def run(self, **kwargs): return {"found": True} + + def run(self, **kwargs): + return {"found": True} + @vulcanai_tool class OtherFileTool(AtomicTool): @@ -43,7 +50,10 @@ class OtherFileTool(AtomicTool): input_schema = [("text", "string")] output_schema = {"spoken": "bool"} version = "0.1" - def run(self, **kwargs): return {"spoken": True} + + def run(self, **kwargs): + return {"spoken": True} + @vulcanai_tool class AnotherValidationTool(ValidationTool): @@ -53,7 +63,10 @@ class 
AnotherValidationTool(ValidationTool): input_schema = [] output_schema = {"valid": "bool"} version = "0.1" - def run(self, **kwargs): return {"valid": True} + + def run(self, **kwargs): + return {"valid": True} + class NoDecoratorTool(AtomicTool): name = "no_decorator_tool" @@ -62,4 +75,6 @@ class NoDecoratorTool(AtomicTool): input_schema = [("text", "string")] output_schema = {"spoken": "bool"} version = "0.1" - def run(self, **kwargs): return {"spoken": True} + + def run(self, **kwargs): + return {"spoken": True} diff --git a/tests/unittest/test_executor.py b/tests/unittest/test_executor.py index c4d22c1..aeaa08d 100644 --- a/tests/unittest/test_executor.py +++ b/tests/unittest/test_executor.py @@ -14,23 +14,28 @@ import hashlib import importlib -import numpy as np import os import sys import time import types import unittest +import numpy as np + # Stub sentence_transformers to avoid heavy dependency during tests class _DummySentenceTransformer: def __init__(self, *args, **kwargs): pass + def encode(self, text, convert_to_numpy=True): return None + def similarity(self, a, b): return None -sys.modules.setdefault('sentence_transformers', types.SimpleNamespace(SentenceTransformer=_DummySentenceTransformer)) + + +sys.modules.setdefault("sentence_transformers", types.SimpleNamespace(SentenceTransformer=_DummySentenceTransformer)) # Make src-layout importable @@ -67,8 +72,10 @@ def embed(self, text: str) -> np.ndarray: vec = np.frombuffer(h, dtype=np.uint8).astype(np.float32)[:64] norm = np.linalg.norm(vec) or 1.0 return vec / norm + def similarity(self, a: np.ndarray, b: np.ndarray) -> float: return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8)) + self.Embedder = LocalDummyEmbedder() # Build registry and executor @@ -82,9 +89,11 @@ class NavTool(self.AtomicTool): input_schema = [("x", "float"), ("y", "float"), ("z", "float")] output_schema = {"arrived": "bool"} version = "0.1" + def run(self, **kwargs): print(f"Run method of NavTool called with 
args: {kwargs}") return {"arrived": True} + class DetectTool(self.AtomicTool): name = "detect_object" description = "Detect an object in the environment" @@ -92,8 +101,10 @@ class DetectTool(self.AtomicTool): input_schema = [("label", "string")] output_schema = {"found": "bool", "pose": "dict(x: float, y: float, z: float)"} version = "0.1" + def run(self, **kwargs): return {"found": True, "pose": {"x": 4.0, "y": 2.0, "z": 0.0}} + class SpeakTool(self.AtomicTool): name = "speak" description = "Speak a text string" @@ -101,8 +112,10 @@ class SpeakTool(self.AtomicTool): input_schema = [("text", "string")] output_schema = {"spoken": "bool"} version = "0.1" + def run(self, **kwargs): return {"spoken": True, "spoken_text": kwargs.get("text", "")} + class ListTool(self.AtomicTool): name = "output_list" description = "Output a list of items" @@ -110,8 +123,10 @@ class ListTool(self.AtomicTool): input_schema = [] output_schema = {"output": "list"} version = "0.1" + def run(self, **kwargs): return {"output": ["apple", "banana", "cherry"]} + class SleepTool(self.AtomicTool): name = "sleep" description = "Sleep for a specified duration" @@ -119,9 +134,11 @@ class SleepTool(self.AtomicTool): input_schema = [("duration", "int")] output_schema = {"slept": "bool"} version = "0.1" + def run(self, **kwargs): time.sleep(kwargs.get("duration", 1)) return {"slept": True} + class FlakyTool(self.AtomicTool): name = "flaky" description = "A tool that fails a few times before succeeding" @@ -129,14 +146,17 @@ class FlakyTool(self.AtomicTool): input_schema = [] output_schema = {"succeeded": "bool"} version = "0.1" + def __init__(self): super().__init__() self.attempts = 0 + def run(self, **kwargs): self.attempts += 1 if self.attempts < 3: raise RuntimeError("Simulated failure") return {"succeeded": True} + # This tool can be instantiated directly to allow resetting attempts self.FlakyTool = FlakyTool @@ -147,13 +167,17 @@ class CriteriaTool(self.AtomicTool): input_schema = [] output_schema 
= {"value": "int"} version = "0.1" + def __init__(self, return_value): super().__init__() self.return_value = return_value + def run(self, **kwargs): return {"value": self.return_value} + # This tool can be instantiated directly to allow modifying returned value self.CriteriaTool = CriteriaTool + class AddTool(self.AtomicTool): name = "add" description = "Adds two numbers together." @@ -185,9 +209,19 @@ def test_simple_plan_executes(self): self.PlanNode( kind="SEQUENCE", steps=[ - self.Step(tool="go_to_pose", args=[self.Arg(key="x", val=1.0), self.Arg(key="y", val=2.0), self.Arg(key="z", val=0.0)]), + self.Step( + tool="go_to_pose", + args=[self.Arg(key="x", val=1.0), self.Arg(key="y", val=2.0), self.Arg(key="z", val=0.0)], + ), self.Step(tool="detect_object", args=[self.Arg(key="label", val="mug")]), - self.Step(tool="go_to_pose", args=[self.Arg(key="x", val="{{bb.detect_object.pose.x}}"), self.Arg(key="y", val="{{bb.detect_object.pose.y}}"), self.Arg(key="z", val="{{bb.detect_object.pose.z}}")]), + self.Step( + tool="go_to_pose", + args=[ + self.Arg(key="x", val="{{bb.detect_object.pose.x}}"), + self.Arg(key="y", val="{{bb.detect_object.pose.y}}"), + self.Arg(key="z", val="{{bb.detect_object.pose.z}}"), + ], + ), self.Step(tool="speak", args=[self.Arg(key="text", val="I have arrived and detected a mug.")]), ], ) @@ -205,13 +239,20 @@ def test_simple_plan_executes(self): self.assertEqual(bb["detect_object"]["pose"], {"x": 4.0, "y": 2.0, "z": 0.0}) def test_false_condition_skips_step(self): - """Test that a step with a False condition is skipped and a non-existing blackboard entry does not return error.""" + """ + Test that a step with a False condition is skipped and a non-existing blackboard entry + does not return error. 
+ """ plan = self.GlobalPlan( plan=[ self.PlanNode( kind="SEQUENCE", steps=[ - self.Step(tool="speak", args=[self.Arg(key="text", val="Hello")], condition="{{bb.missing_flag}} == True"), + self.Step( + tool="speak", + args=[self.Arg(key="text", val="Hello")], + condition="{{bb.missing_flag}} == True", + ), ], ) ] @@ -266,9 +307,21 @@ def test_bb_condition_evaluation(self): kind="SEQUENCE", steps=[ self.Step(tool="detect_object", args=[self.Arg(key="label", val="book")]), - self.Step(tool="go_to_pose", args=[self.Arg(key="x", val="{{bb.detect_object.pose.x}}"), self.Arg(key="y", val="{{bb.detect_object.pose.y}}"), self.Arg(key="z", val="{{bb.detect_object.pose.z}}")], condition="{{bb.detect_object.found}} == True"), + self.Step( + tool="go_to_pose", + args=[ + self.Arg(key="x", val="{{bb.detect_object.pose.x}}"), + self.Arg(key="y", val="{{bb.detect_object.pose.y}}"), + self.Arg(key="z", val="{{bb.detect_object.pose.z}}"), + ], + condition="{{bb.detect_object.found}} == True", + ), # Skip next step by making condition False - self.Step(tool="speak", args=[self.Arg(key="text", val="I have arrived.")], condition="{{bb.go_to_pose.arrived}} != True"), + self.Step( + tool="speak", + args=[self.Arg(key="text", val="I have arrived.")], + condition="{{bb.go_to_pose.arrived}} != True", + ), ], ) ] @@ -290,7 +343,15 @@ def test_bb_condition_evaluation(self): kind="SEQUENCE", steps=[ self.Step(tool="detect_object", args=[self.Arg(key="label", val="book")]), - self.Step(tool="go_to_pose", args=[self.Arg(key="x", val="{{bb.detect_object.pose.x}}"), self.Arg(key="y", val="{{bb.detect_object.pose.y}}"), self.Arg(key="z", val="{{bb.detect_object.pose.z}}")], condition="True == {{bb.detect_object.found}}"), + self.Step( + tool="go_to_pose", + args=[ + self.Arg(key="x", val="{{bb.detect_object.pose.x}}"), + self.Arg(key="y", val="{{bb.detect_object.pose.y}}"), + self.Arg(key="z", val="{{bb.detect_object.pose.z}}"), + ], + condition="True == {{bb.detect_object.found}}", + ), ], ) ] @@ 
-311,7 +372,15 @@ def test_bb_condition_evaluation(self): steps=[ self.Step(tool="detect_object", args=[self.Arg(key="label", val="book")]), self.Step(tool="speak", args=[self.Arg(key="text", val="I have arrived.")]), - self.Step(tool="go_to_pose", args=[self.Arg(key="x", val="{{bb.detect_object.pose.x}}"), self.Arg(key="y", val="{{bb.detect_object.pose.y}}"), self.Arg(key="z", val="{{bb.detect_object.pose.z}}")], condition="{{bb.speak.spoken}} == {{bb.detect_object.found}}"), + self.Step( + tool="go_to_pose", + args=[ + self.Arg(key="x", val="{{bb.detect_object.pose.x}}"), + self.Arg(key="y", val="{{bb.detect_object.pose.y}}"), + self.Arg(key="z", val="{{bb.detect_object.pose.z}}"), + ], + condition="{{bb.speak.spoken}} == {{bb.detect_object.found}}", + ), ], ) ] @@ -333,7 +402,8 @@ def test_bb_condition_evaluation(self): kind="SEQUENCE", steps=[ self.Step(tool="output_list", args=[]), - self.Step(tool="speak", + self.Step( + tool="speak", args=[self.Arg(key="text", val="First item is {{bb.output_list.output[0]}}")], ), ], @@ -621,7 +691,8 @@ def test_plan_node_with_unknown_kind(self): ], ) self.fail("PlanNode creation should have failed with ValueError") - except Exception as e: + self.assertFalse(plan) + except Exception: pass def test_plan_node_with_on_fail(self): @@ -712,7 +783,10 @@ def test_types_substitution_is_respected(self): kind="SEQUENCE", steps=[ self.Step(tool="criteria_tool", args=[]), # Outputs 5 - self.Step(tool="add", args=[self.Arg(key="a", val='{{bb.criteria_tool.value}}'), self.Arg(key="b", val=3.5)]), + self.Step( + tool="add", + args=[self.Arg(key="a", val="{{bb.criteria_tool.value}}"), self.Arg(key="b", val=3.5)], + ), ], ) ], diff --git a/tests/unittest/test_tool_registry.py b/tests/unittest/test_tool_registry.py index 0e45814..9dc09ca 100644 --- a/tests/unittest/test_tool_registry.py +++ b/tests/unittest/test_tool_registry.py @@ -15,23 +15,28 @@ import hashlib import importlib import io -import numpy as np import os import re import sys 
import types import unittest +import numpy as np + # Stub sentence_transformers to avoid heavy dependency during tests class _DummySentenceTransformer: def __init__(self, *args, **kwargs): pass + def encode(self, text, convert_to_numpy=True): return None + def similarity(self, a, b): return None -sys.modules.setdefault('sentence_transformers', types.SimpleNamespace(SentenceTransformer=_DummySentenceTransformer)) + + +sys.modules.setdefault("sentence_transformers", types.SimpleNamespace(SentenceTransformer=_DummySentenceTransformer)) # Add src/ to sys.path for src-layout imports @@ -63,8 +68,10 @@ def embed(self, text: str) -> np.ndarray: vec = np.frombuffer(h, dtype=np.uint8).astype(np.float32)[:64] norm = np.linalg.norm(vec) or 1.0 return vec / norm + def similarity(self, a: np.ndarray, b: np.ndarray) -> float: return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8)) + self.Embedder = LocalDummyEmbedder # Define dummy tools inside setUp to use imported base class @@ -75,8 +82,10 @@ class NavTool(self.AtomicTool): input_schema = [("x", "float"), ("y", "float"), ("z", "float")] output_schema = {"arrived": "bool"} version = "0.1" + def run(self, **kwargs): return {"arrived": True} + class DetectTool(self.AtomicTool): name = "detect_object" description = "Detect an object in the environment" @@ -84,8 +93,10 @@ class DetectTool(self.AtomicTool): input_schema = [("label", "string")] output_schema = {"found": "bool"} version = "0.1" + def run(self, **kwargs): return {"found": True} + class SpeakTool(self.AtomicTool): name = "speak" description = "Speak a text string" @@ -93,8 +104,10 @@ class SpeakTool(self.AtomicTool): input_schema = [("text", "string")] output_schema = {"spoken": "bool"} version = "0.1" + def run(self, **kwargs): return {"spoken": True} + class ComplexTool(self.CompositeTool): name = "complex_action" description = "A complex action using multiple tools" @@ -103,6 +116,7 @@ class ComplexTool(self.CompositeTool): output_schema = 
{"result": "bool"} version = "0.1" dependencies = ["go_to_pose", "detect_object", "speak"] + def run(self, **kwargs): return {"result": True} @@ -113,6 +127,7 @@ class TestValidationTool(self.ValidationTool): input_schema = [("param", "string")] output_schema = {"result": "bool"} version = "0.1" + def run(self, **kwargs): return {"result": True} @@ -132,7 +147,7 @@ def run(self, **kwargs): def test_top_k_returns_expected_count(self): """Test that top_k returns up to k tools.""" res = self.registry.top_k("go to the room", k=2) - self.assertEqual(len(res), 2+1) # +1 for help tool + self.assertEqual(len(res), 2 + 1) # +1 for help tool def test_top_k_returns_expected_count_validation(self): """Test that top_k returns up to k tools even if there are more non-validation tools registered.""" @@ -166,66 +181,69 @@ def test_empty_registry(self): def test_k_greater_than_tools(self): """Test that k greater than number of tools returns all tools skipping embedding.""" + # Mock embedder to track calls class MockEmbedder(self.Embedder): def __init__(self): super().__init__() self.embed_call_count = 0 + def embed(self, text: str) -> np.ndarray: self.embed_call_count += 1 return super().embed(text) + r = self.ToolRegistry(embedder=MockEmbedder()) r.register_tool(self.NavTool()) r.register_tool(self.DetectTool()) r.register_tool(self.SpeakTool()) self.assertEqual(r.embedder.embed_call_count, 3) res = r.top_k("go to the room", k=10) - self.assertEqual(len(res), 3+1) # +1 for help tool + self.assertEqual(len(res), 3 + 1) # +1 for help tool # Check that self.embedder.embed has not been called during top_k self.assertEqual(r.embedder.embed_call_count, 3) res = r.top_k("detect", k=2) - self.assertEqual(len(res), 2+1) # +1 for help tool + self.assertEqual(len(res), 2 + 1) # +1 for help tool self.assertEqual(r.embedder.embed_call_count, 4) # Test registers def test_register_tool(self): """Test that tools can be registered and are present in the registry.""" r = 
self.ToolRegistry(embedder=self.Embedder()) - self.assertEqual(len(r.tools), 0+1) # +1 for help tool + self.assertEqual(len(r.tools), 0 + 1) # +1 for help tool r.register_tool(self.NavTool()) - self.assertEqual(len(r.tools), 1+1) # +1 for help tool + self.assertEqual(len(r.tools), 1 + 1) # +1 for help tool r.register_tool(self.DetectTool()) - self.assertEqual(len(r.tools), 2+1) # +1 for help tool + self.assertEqual(len(r.tools), 2 + 1) # +1 for help tool r.register_tool(self.SpeakTool()) - self.assertEqual(len(r.tools), 3+1) # +1 for help tool + self.assertEqual(len(r.tools), 3 + 1) # +1 for help tool self.assertIn("go_to_pose", r.tools) self.assertIn("detect_object", r.tools) self.assertIn("speak", r.tools) self.assertNotIn("nonexistent_tool", r.tools) r.register_tool(self.TestValidationTool()) - self.assertEqual(len(r.tools), 4+1) + self.assertEqual(len(r.tools), 4 + 1) self.assertIn("test_validation", r.tools) def test_register_tool_from_file(self): # Register tools from test_tools.py self.registry.discover_tools_from_file(os.path.join(RESOURCES_DIR, "test_tools.py")) - self.assertEqual(len(self.registry.tools), 7+1) # +1 for help tool (3 existing + 4 new) - self.assertEqual(len(self.registry._index), 7) # (3 existing + 4 new) + self.assertEqual(len(self.registry.tools), 7 + 1) # +1 for help tool (3 existing + 4 new) + self.assertEqual(len(self.registry._index), 7) # (3 existing + 4 new) def test_register_tool_from_nonexistent_file(self): r = self.ToolRegistry(embedder=self.Embedder()) r.discover_tools_from_file("/path/does/not/exist.py") - self.assertEqual(len(r.tools), 0+1) # +1 for help tool + self.assertEqual(len(r.tools), 0 + 1) # +1 for help tool def test_register_composite_tool_solves_deps(self): """Test that registering a composite tool correctly resolves its dependencies.""" # Check that the composite tool has no resolved deps before registration - self.assertEqual(len(self.registry.tools), 3+1) # +1 for help tool + 
self.assertEqual(len(self.registry.tools), 3 + 1) # +1 for help tool self.assertEqual(len(self.ComplexTool().dependencies), 3) self.assertEqual(len(self.ComplexTool().resolved_deps), 0) # Register the composite tool self.registry.register_tool(self.ComplexTool()) - self.assertEqual(len(self.registry.tools), 4+1) # +1 for help tool + self.assertEqual(len(self.registry.tools), 4 + 1) # +1 for help tool self.assertIn("complex_action", self.registry.tools) self.assertEqual(len(self.registry.tools.get("complex_action").resolved_deps), 3) @@ -235,12 +253,12 @@ def test_register_composite_tool_fails_reports_error(self): sys.stdout = buf # Check that the composite tool has no resolved deps before registration r = self.ToolRegistry(embedder=self.Embedder()) - self.assertEqual(len(r.tools), 0+1) # +1 for help tool + self.assertEqual(len(r.tools), 0 + 1) # +1 for help tool self.assertEqual(len(self.ComplexTool().dependencies), 3) self.assertEqual(len(self.ComplexTool().resolved_deps), 0) # Register the composite tool r.register_tool(self.ComplexTool()) - self.assertEqual(len(r.tools), 1+1) # +1 for help tool + self.assertEqual(len(r.tools), 1 + 1) # +1 for help tool self.assertIn("complex_action", r.tools) self.assertEqual(len(self.ComplexTool().resolved_deps), 0) # Capture output to check for error messages @@ -264,7 +282,7 @@ def test_register_composite_tool_solves_deps_from_file(self): self.assertNotIn("Dependency 'file_tool' for tool 'complex_file_action' not found", output) self.assertNotIn("Dependency 'new_file_tool' for tool 'complex_file_action' not found", output) self.assertNotIn("Dependency 'other_file_tool' for tool 'complex_file_action' not found", output) - self.assertEqual(len(r.tools), 5+1) # +1 for help tool + self.assertEqual(len(r.tools), 5 + 1) # +1 for help tool def test_register_composite_tool_fails_from_file(self): """Test that registering a composite tool from file reports an error if there are no dependencies.""" @@ -279,15 +297,15 @@ def 
test_register_composite_tool_fails_from_file(self): self.assertIn("Dependency 'file_tool' for tool 'complex_file_action' not found", output) self.assertIn("Dependency 'new_file_tool' for tool 'complex_file_action' not found", output) self.assertIn("Dependency 'other_file_tool' for tool 'complex_file_action' not found", output) - self.assertEqual(len(r.tools), 1+1) # +1 for help tool + self.assertEqual(len(r.tools), 1 + 1) # +1 for help tool def test_register_validation_tool(self): """Test that validation tools can be registered and are present in the registry.""" r = self.ToolRegistry(embedder=self.Embedder()) - self.assertEqual(len(r.tools), 0+1) # +1 for help tool + self.assertEqual(len(r.tools), 0 + 1) # +1 for help tool self.assertEqual(len(r._index), 0) r.register_tool(self.TestValidationTool()) - self.assertEqual(len(r.tools), 1+1) # +1 for help tool + self.assertEqual(len(r.tools), 1 + 1) # +1 for help tool self.assertEqual(len(r._index), 1) self.assertIn("test_validation", r.validation_tools) diff --git a/tests/unittest/test_validator.py b/tests/unittest/test_validator.py index 0c90d07..611ac4f 100644 --- a/tests/unittest/test_validator.py +++ b/tests/unittest/test_validator.py @@ -14,22 +14,27 @@ import hashlib import importlib -import numpy as np import os import sys import types import unittest +import numpy as np + # Stub sentence_transformers to avoid heavy dependency during tests class _DummySentenceTransformer: def __init__(self, *args, **kwargs): pass + def encode(self, text, convert_to_numpy=True): return None + def similarity(self, a, b): return None -sys.modules.setdefault('sentence_transformers', types.SimpleNamespace(SentenceTransformer=_DummySentenceTransformer)) + + +sys.modules.setdefault("sentence_transformers", types.SimpleNamespace(SentenceTransformer=_DummySentenceTransformer)) # Make src-layout importable @@ -66,8 +71,10 @@ def embed(self, text: str) -> np.ndarray: vec = np.frombuffer(h, dtype=np.uint8).astype(np.float32)[:64] norm = 
np.linalg.norm(vec) or 1.0 return vec / norm + def similarity(self, a: np.ndarray, b: np.ndarray) -> float: return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8)) + self.Embedder = LocalDummyEmbedder() # Build registry and executor @@ -82,6 +89,7 @@ class EmptyTool(self.AtomicTool): def run(self): return {"result": True} + class DetectTool(self.AtomicTool): name = "detect_object" description = "Detect an object in the environment" @@ -89,8 +97,10 @@ class DetectTool(self.AtomicTool): input_schema = [("label", "string")] output_schema = {"found": "bool", "pose": "dict(x: float, y: float, z: float)"} version = "0.1" + def run(self, **kwargs): return {"found": True, "pose": {"x": 4.0, "y": 2.0, "z": 0.0}} + class NavTool(self.AtomicTool): name = "go_to_pose" description = "Navigate robot to a target location" @@ -98,9 +108,11 @@ class NavTool(self.AtomicTool): input_schema = [("x", "float"), ("y", "float"), ("z", "float")] output_schema = {"arrived": "bool"} version = "0.1" + def run(self, **kwargs): print(f"Run method of NavTool called with args: {kwargs}") return {"arrived": True} + class SpeakTool(self.AtomicTool): name = "speak" description = "Speak a text string" @@ -108,6 +120,7 @@ class SpeakTool(self.AtomicTool): input_schema = [("text", "string")] output_schema = {"spoken": "bool"} version = "0.1" + def run(self, **kwargs): return {"spoken": True, "spoken_text": kwargs.get("text", "")} @@ -124,6 +137,7 @@ class TypesTool(self.AtomicTool): ] output_schema = {"types": "bool"} version = "0.1" + def run(self, **kwargs): return {"types": True} @@ -142,9 +156,15 @@ def test_validator_correct_plan(self): self.PlanNode( kind="SEQUENCE", steps=[ - self.Step(tool="go_to_pose", args=[self.Arg(key="x", val=1.0), self.Arg(key="y", val=2.0), self.Arg(key="z", val=0.0)]), + self.Step( + tool="go_to_pose", + args=[self.Arg(key="x", val=1.0), self.Arg(key="y", val=2.0), self.Arg(key="z", val=0.0)], + ), self.Step(tool="detect_object", 
args=[self.Arg(key="label", val="mug")]), - self.Step(tool="go_to_pose", args=[self.Arg(key="x", val=3.0), self.Arg(key="y", val=4.0), self.Arg(key="z", val=0.0)]), + self.Step( + tool="go_to_pose", + args=[self.Arg(key="x", val=3.0), self.Arg(key="y", val=4.0), self.Arg(key="z", val=0.0)], + ), self.Step(tool="speak", args=[self.Arg(key="text", val="I have arrived and detected a mug.")]), ], ) @@ -182,7 +202,14 @@ def test_validator_non_existing_key(self): self.PlanNode( kind="SEQUENCE", steps=[ - self.Step(tool="go_to_pose", args=[self.Arg(key="non_existing_key", val=1.0), self.Arg(key="y", val=2.0), self.Arg(key="z", val=0.0)]), + self.Step( + tool="go_to_pose", + args=[ + self.Arg(key="non_existing_key", val=1.0), + self.Arg(key="y", val=2.0), + self.Arg(key="z", val=0.0), + ], + ), ], ) ], @@ -222,7 +249,15 @@ def test_validator_extra_key(self): self.PlanNode( kind="SEQUENCE", steps=[ - self.Step(tool="go_to_pose", args=[self.Arg(key="w", val=1.0), self.Arg(key="x", val=1.0), self.Arg(key="y", val=2.0), self.Arg(key="z", val=0.0)]), + self.Step( + tool="go_to_pose", + args=[ + self.Arg(key="w", val=1.0), + self.Arg(key="x", val=1.0), + self.Arg(key="y", val=2.0), + self.Arg(key="z", val=0.0), + ], + ), ], ) ], @@ -262,7 +297,10 @@ def test_validator_wrong_last_bracket_missing(self): self.PlanNode( kind="SEQUENCE", steps=[ - self.Step(tool="go_to_pose", args=[self.Arg(key="x", val=1.0), self.Arg(key="y", val=2.0), self.Arg(key="z", val=0.0)]), + self.Step( + tool="go_to_pose", + args=[self.Arg(key="x", val=1.0), self.Arg(key="y", val=2.0), self.Arg(key="z", val=0.0)], + ), self.Step(tool="speak", args=[self.Arg(key="text", val="Arrived at {{bb.missing_brace")]), ], ) @@ -283,7 +321,10 @@ def test_validator_wrong_first_bracket_missing(self): self.PlanNode( kind="SEQUENCE", steps=[ - self.Step(tool="go_to_pose", args=[self.Arg(key="x", val=1.0), self.Arg(key="y", val=2.0), self.Arg(key="z", val=0.0)]), + self.Step( + tool="go_to_pose", + args=[self.Arg(key="x", 
val=1.0), self.Arg(key="y", val=2.0), self.Arg(key="z", val=0.0)], + ), self.Step(tool="speak", args=[self.Arg(key="text", val="Arrived at bb.missing_brace}}")]), ], ) @@ -304,7 +345,10 @@ def test_validator_wrong_all_brackets_missing(self): self.PlanNode( kind="SEQUENCE", steps=[ - self.Step(tool="go_to_pose", args=[self.Arg(key="x", val=1.0), self.Arg(key="y", val=2.0), self.Arg(key="z", val=0.0)]), + self.Step( + tool="go_to_pose", + args=[self.Arg(key="x", val=1.0), self.Arg(key="y", val=2.0), self.Arg(key="z", val=0.0)], + ), self.Step(tool="speak", args=[self.Arg(key="text", val="Arrived at bb.missing_brace")]), ], ) @@ -325,13 +369,15 @@ def test_validator_correct_types(self): self.PlanNode( kind="SEQUENCE", steps=[ - self.Step(tool="types", args=[ - self.Arg(key="int", val=4), - self.Arg(key="integer", val=2), - self.Arg(key="float", val=4.2), - self.Arg(key="bool", val=True), - self.Arg(key="boolean", val=False), - ] + self.Step( + tool="types", + args=[ + self.Arg(key="int", val=4), + self.Arg(key="integer", val=2), + self.Arg(key="float", val=4.2), + self.Arg(key="bool", val=True), + self.Arg(key="boolean", val=False), + ], ), ], ) @@ -339,21 +385,21 @@ def test_validator_correct_types(self): ) try: self.validator.validate(plan) - except Exception as e: + except Exception: self.fail("Test 'test_validator_correct_types' should not fail") # Cast of int to float is accepted, not the other way around plan.plan[0].steps[0].args[2] = self.Arg(key="float", val=4) # Pass a int instead of float try: self.validator.validate(plan) - except Exception as e: + except Exception: self.fail("Test 'test_validator_correct_types' should not fail for int->float cast") # Non-string types can accept string bb references plan.plan[0].steps[0].args[2] = self.Arg(key="float", val="{{bb.tool.float_value}}") try: self.validator.validate(plan) - except Exception as e: + except Exception: self.fail("Test 'test_validator_correct_types' should not fail for bb reference in string 
format") def test_validator_wrong_types(self): @@ -363,14 +409,16 @@ def test_validator_wrong_types(self): self.PlanNode( kind="SEQUENCE", steps=[ - self.Step(tool="types", args=[ - self.Arg(key="int", val="4"), - # Add only one wrong type on each test, rest should be correct - self.Arg(key="integer", val=2), - self.Arg(key="float", val=4.2), - self.Arg(key="bool", val=True), - self.Arg(key="boolean", val=False), - ] + self.Step( + tool="types", + args=[ + self.Arg(key="int", val="4"), + # Add only one wrong type on each test, rest should be correct + self.Arg(key="integer", val=2), + self.Arg(key="float", val=4.2), + self.Arg(key="bool", val=True), + self.Arg(key="boolean", val=False), + ], ), ], ) @@ -381,7 +429,7 @@ def test_validator_wrong_types(self): self.validator.validate(plan) except Exception as e: fail = True - self.assertIn(f"Argument 'int' of tool 'types' expects type", str(e)) + self.assertIn("Argument 'int' of tool 'types' expects type", str(e)) self.assertTrue(fail, "Validator did not catch non-existing key error") plan.plan[0].steps[0].args[0] = self.Arg(key="int", val=4.2) @@ -390,63 +438,79 @@ def test_validator_wrong_types(self): self.validator.validate(plan) except Exception as e: fail = True - self.assertIn(f"Argument 'int' of tool 'types' expects type", str(e)) + self.assertIn("Argument 'int' of tool 'types' expects type", str(e)) self.assertTrue(fail, "Validator did not catch non-existing key error") - plan.plan[0].steps[0] = self.Step(tool="types", args=[ - self.Arg(key="int", val=4), - self.Arg(key="integer", val="2"), - self.Arg(key="float", val=4.2), - self.Arg(key="bool", val=True), - self.Arg(key="boolean", val=False)]) + plan.plan[0].steps[0] = self.Step( + tool="types", + args=[ + self.Arg(key="int", val=4), + self.Arg(key="integer", val="2"), + self.Arg(key="float", val=4.2), + self.Arg(key="bool", val=True), + self.Arg(key="boolean", val=False), + ], + ) try: fail = False self.validator.validate(plan) except Exception as e: fail = 
True - self.assertIn(f"Argument 'integer' of tool 'types' expects type", str(e)) + self.assertIn("Argument 'integer' of tool 'types' expects type", str(e)) self.assertTrue(fail, "Validator did not catch non-existing key error") - plan.plan[0].steps[0] = self.Step(tool="types", args=[ - self.Arg(key="int", val=4), - self.Arg(key="integer", val=2), - self.Arg(key="float", val="42.0"), - self.Arg(key="bool", val=True), - self.Arg(key="boolean", val=False)]) + plan.plan[0].steps[0] = self.Step( + tool="types", + args=[ + self.Arg(key="int", val=4), + self.Arg(key="integer", val=2), + self.Arg(key="float", val="42.0"), + self.Arg(key="bool", val=True), + self.Arg(key="boolean", val=False), + ], + ) try: fail = False self.validator.validate(plan) except Exception as e: fail = True - self.assertIn(f"Argument 'float' of tool 'types' expects type", str(e)) + self.assertIn("Argument 'float' of tool 'types' expects type", str(e)) self.assertTrue(fail, "Validator did not catch non-existing key error") - plan.plan[0].steps[0] = self.Step(tool="types", args=[ - self.Arg(key="int", val=4), - self.Arg(key="integer", val=2), - self.Arg(key="float", val=42.0), - self.Arg(key="bool", val="true"), - self.Arg(key="boolean", val=False)]) + plan.plan[0].steps[0] = self.Step( + tool="types", + args=[ + self.Arg(key="int", val=4), + self.Arg(key="integer", val=2), + self.Arg(key="float", val=42.0), + self.Arg(key="bool", val="true"), + self.Arg(key="boolean", val=False), + ], + ) try: fail = False self.validator.validate(plan) except Exception as e: fail = True - self.assertIn(f"Argument 'bool' of tool 'types' expects type", str(e)) + self.assertIn("Argument 'bool' of tool 'types' expects type", str(e)) self.assertTrue(fail, "Validator did not catch non-existing key error") - plan.plan[0].steps[0] = self.Step(tool="types", args=[ - self.Arg(key="int", val=4), - self.Arg(key="integer", val=2), - self.Arg(key="float", val=42.0), - self.Arg(key="bool", val=True), - self.Arg(key="boolean", 
val="false")]) + plan.plan[0].steps[0] = self.Step( + tool="types", + args=[ + self.Arg(key="int", val=4), + self.Arg(key="integer", val=2), + self.Arg(key="float", val=42.0), + self.Arg(key="bool", val=True), + self.Arg(key="boolean", val="false"), + ], + ) try: fail = False self.validator.validate(plan) except Exception as e: fail = True - self.assertIn(f"Argument 'boolean' of tool 'types' expects type", str(e)) + self.assertIn("Argument 'boolean' of tool 'types' expects type", str(e)) self.assertTrue(fail, "Validator did not catch non-existing key error")