Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 107 additions & 75 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,84 +115,114 @@ def __init__(
],
)

self._action_handlers = {
"open_web_browser": self._handle_open_web_browser,
"click_at": self._handle_click_at,
"hover_at": self._handle_hover_at,
"type_text_at": self._handle_type_text_at,
"scroll_document": self._handle_scroll_document,
"scroll_at": self._handle_scroll_at,
"wait_5_seconds": self._handle_wait_5_seconds,
"go_back": self._handle_go_back,
"go_forward": self._handle_go_forward,
"search": self._handle_search,
"navigate": self._handle_navigate,
"key_combination": self._handle_key_combination,
"drag_and_drop": self._handle_drag_and_drop,
multiply_numbers.__name__: self._handle_multiply_numbers,
}

def handle_action(self, action: types.FunctionCall) -> FunctionResponseT:
    """Dispatch ``action`` to its registered handler and return the result.

    The handler is looked up by ``action.name`` in ``self._action_handlers``
    and called with the whole ``action`` object, so each handler extracts
    the arguments it needs itself. The table lookup replaces the former
    ``if``/``elif`` chain, which duplicated every handler body inline.

    Args:
        action: The function call to execute; ``action.name`` selects the
            handler.

    Returns:
        Whatever the selected handler returns.

    Raises:
        ValueError: If no handler is registered for ``action.name``.
    """
    handler = self._action_handlers.get(action.name)
    if handler is None:
        raise ValueError(f"Unsupported function: {action}")
    return handler(action)

def _handle_open_web_browser(self, action: types.FunctionCall) -> FunctionResponseT:
return self._browser_computer.open_web_browser()

def _handle_click_at(self, action: types.FunctionCall) -> FunctionResponseT:
x = self.denormalize_x(action.args["x"])
y = self.denormalize_y(action.args["y"])
return self._browser_computer.click_at(x=x, y=y)

def _handle_hover_at(self, action: types.FunctionCall) -> FunctionResponseT:
x = self.denormalize_x(action.args["x"])
y = self.denormalize_y(action.args["y"])
return self._browser_computer.hover_at(x=x, y=y)

def _handle_type_text_at(self, action: types.FunctionCall) -> FunctionResponseT:
x = self.denormalize_x(action.args["x"])
y = self.denormalize_y(action.args["y"])
press_enter = action.args.get("press_enter", False)
clear_before_typing = action.args.get("clear_before_typing", True)
return self._browser_computer.type_text_at(
x=x,
y=y,
text=action.args["text"],
press_enter=press_enter,
clear_before_typing=clear_before_typing,
)

def _handle_scroll_document(self, action: types.FunctionCall) -> FunctionResponseT:
return self._browser_computer.scroll_document(action.args["direction"])

def _handle_scroll_at(self, action: types.FunctionCall) -> FunctionResponseT:
x = self.denormalize_x(action.args["x"])
y = self.denormalize_y(action.args["y"])
magnitude = action.args.get("magnitude", 800)
direction = action.args["direction"]

if direction in ("up", "down"):
magnitude = self.denormalize_y(magnitude)
elif direction in ("left", "right"):
magnitude = self.denormalize_x(magnitude)
else:
raise ValueError(f"Unknown direction: {direction}")
return self._browser_computer.scroll_at(
x=x, y=y, direction=direction, magnitude=magnitude
)

def _handle_wait_5_seconds(
self, action: types.FunctionCall
) -> FunctionResponseT:
return self._browser_computer.wait_5_seconds()

def _handle_go_back(self, action: types.FunctionCall) -> FunctionResponseT:
return self._browser_computer.go_back()

def _handle_go_forward(self, action: types.FunctionCall) -> FunctionResponseT:
return self._browser_computer.go_forward()

def _handle_search(self, action: types.FunctionCall) -> FunctionResponseT:
return self._browser_computer.search()

def _handle_navigate(self, action: types.FunctionCall) -> FunctionResponseT:
return self._browser_computer.navigate(action.args["url"])

def _handle_key_combination(
self, action: types.FunctionCall
) -> FunctionResponseT:
return self._browser_computer.key_combination(action.args["keys"].split("+"))

def _handle_drag_and_drop(self, action: types.FunctionCall) -> FunctionResponseT:
x = self.denormalize_x(action.args["x"])
y = self.denormalize_y(action.args["y"])
destination_x = self.denormalize_x(action.args["destination_x"])
destination_y = self.denormalize_y(action.args["destination_y"])
return self._browser_computer.drag_and_drop(
x=x,
y=y,
destination_x=destination_x,
destination_y=destination_y,
)

def _handle_multiply_numbers(
    self, action: types.FunctionCall
) -> FunctionResponseT:
    """Handle the custom ``multiply_numbers`` function call.

    Unlike the browser handlers, this one does not use
    ``self._browser_computer``; it forwards ``x`` and ``y`` from
    ``action.args`` to the module-level ``multiply_numbers`` function.

    Returns:
        The result of ``multiply_numbers(x=..., y=...)``.

    Raises:
        KeyError: If ``x`` or ``y`` is missing from ``action.args``.
    """
    args = action.args
    return multiply_numbers(x=args["x"], y=args["y"])

def get_model_response(
self, max_retries=5, base_delay_s=1
) -> types.GenerateContentResponse:
Expand Down Expand Up @@ -253,11 +283,13 @@ def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]:
try:
response = self.get_model_response()
except Exception as e:
print(e)
return "COMPLETE"
else:
try:
response = self.get_model_response()
except Exception as e:
print(e)
return "COMPLETE"

if not response.candidates:
Expand Down Expand Up @@ -292,7 +324,7 @@ def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]:
# Print the function call and any reasoning.
function_call_str = f"Name: {function_call.name}"
if function_call.args:
function_call_str += f"\nArgs:"
function_call_str += "\nArgs:"
for key, value in function_call.args.items():
function_call_str += f"\n {key}: {value}"
function_call_strs.append(function_call_str)
Expand Down Expand Up @@ -390,7 +422,7 @@ def _get_safety_confirmation(
self, safety: dict[str, Any]
) -> Literal["CONTINUE", "TERMINATE"]:
if safety["decision"] != "require_confirmation":
raise ValueError(f"Unknown safety decision: safety['decision']")
raise ValueError(f"Unknown safety decision: {safety['decision']}")
termcolor.cprint(
"Safety service requires explicit confirmation!",
color="yellow",
Expand Down