|
1 | 1 | import { generateObject } from "ai"; |
2 | | -import { createAmazonBedrock } from "@ai-sdk/amazon-bedrock"; |
3 | 2 | import { z } from "zod"; |
4 | 3 | import { internal } from "../../_generated/api"; |
5 | 4 | import type { RunnerCtx, RunnerResult } from "../shared/types"; |
6 | | -import { MANAGER_MODEL } from "../models"; |
| 5 | +import { mistral, MANAGER_MODEL } from "../models"; |
7 | 6 |
|
8 | 7 | type TaskRecord = { title: string; description?: string }; |
9 | 8 |
|
10 | 9 | const MAX_ITERATIONS = 200; |
11 | 10 | const ACTION_DELAY_MS = 1000; |
12 | 11 |
|
13 | | -// Structured action schema — the model picks one action per step |
14 | | -const ActionSchema = z.discriminatedUnion("action", [ |
15 | | - z.object({ |
16 | | - action: z.literal("click"), |
17 | | - x: z.number().describe("X coordinate to click"), |
18 | | - y: z.number().describe("Y coordinate to click"), |
19 | | - button: z.enum(["left", "right"]).default("left").describe("Mouse button"), |
20 | | - reasoning: z.string().describe("Why you are clicking here"), |
21 | | - }), |
22 | | - z.object({ |
23 | | - action: z.literal("double_click"), |
24 | | - x: z.number().describe("X coordinate to double-click"), |
25 | | - y: z.number().describe("Y coordinate to double-click"), |
26 | | - reasoning: z.string().describe("Why you are double-clicking here"), |
27 | | - }), |
28 | | - z.object({ |
29 | | - action: z.literal("type"), |
30 | | - text: z.string().describe("Text to type"), |
31 | | - reasoning: z.string().describe("Why you are typing this"), |
32 | | - }), |
33 | | - z.object({ |
34 | | - action: z.literal("key"), |
35 | | - key: z.string().describe("Key to press (e.g. Enter, Tab, Escape)"), |
36 | | - modifiers: z.array(z.string()).optional().describe("Modifier keys (e.g. ctrl, alt, shift)"), |
37 | | - reasoning: z.string().describe("Why you are pressing this key"), |
38 | | - }), |
39 | | - z.object({ |
40 | | - action: z.literal("hotkey"), |
41 | | - keys: z.string().describe("Key combo (e.g. ctrl+c, ctrl+l, alt+tab)"), |
42 | | - reasoning: z.string().describe("Why you are pressing this hotkey"), |
43 | | - }), |
44 | | - z.object({ |
45 | | - action: z.literal("scroll"), |
46 | | - x: z.number().describe("X coordinate for scroll position"), |
47 | | - y: z.number().describe("Y coordinate for scroll position"), |
48 | | - direction: z.enum(["up", "down"]).describe("Scroll direction"), |
49 | | - amount: z.number().optional().describe("Scroll amount (default 3)"), |
50 | | - reasoning: z.string().describe("Why you are scrolling"), |
51 | | - }), |
52 | | - z.object({ |
53 | | - action: z.literal("wait"), |
54 | | - seconds: z.number().min(1).max(5).describe("Seconds to wait for page to load"), |
55 | | - reasoning: z.string().describe("Why you are waiting"), |
56 | | - }), |
57 | | - z.object({ |
58 | | - action: z.literal("done"), |
59 | | - result: z.string().describe("Summary of what was accomplished"), |
60 | | - }), |
61 | | -]); |
62 | | - |
63 | | -type Action = z.infer<typeof ActionSchema>; |
| 12 | +// Flat action schema — all fields on one object; "action" acts as discriminator. |
| 13 | +const ActionSchema = z.object({ |
| 14 | + action: z |
| 15 | + .enum(["click", "double_click", "type", "key", "hotkey", "scroll", "wait", "done"]) |
| 16 | + .describe("The action to perform"), |
| 17 | + reasoning: z.string().optional().describe("Why you are taking this action"), |
| 18 | + x: z.number().optional().describe("X coordinate (click, double_click, scroll)"), |
| 19 | + y: z.number().optional().describe("Y coordinate (click, double_click, scroll)"), |
| 20 | + button: z.enum(["left", "right"]).optional().describe("Mouse button for click (default: left)"), |
| 21 | + text: z.string().optional().describe("Text to type (for type action)"), |
| 22 | + key: z.string().optional().describe("Key to press, e.g. Enter, Tab, Escape (for key action)"), |
| 23 | + modifiers: z |
| 24 | + .array(z.string()) |
| 25 | + .optional() |
| 26 | + .describe("Modifier keys e.g. ctrl, alt, shift (for key action)"), |
| 27 | + keys: z.string().optional().describe("Key combo e.g. ctrl+c, alt+tab (for hotkey action)"), |
| 28 | + direction: z.enum(["up", "down"]).optional().describe("Scroll direction (for scroll action)"), |
| 29 | + amount: z.number().optional().describe("Scroll amount, default 3 (for scroll action)"), |
| 30 | + seconds: z.number().optional().describe("Seconds to wait 1-5 (for wait action)"), |
| 31 | + result: z.string().optional().describe("Summary of what was accomplished (for done action)"), |
| 32 | +}); |
| 33 | + |
| 34 | +type FlatAction = z.infer<typeof ActionSchema>; |
| 35 | + |
| 36 | +// Typed action variants for executeAction/formatAction (narrow from flat schema) |
| 37 | +type Action = |
| 38 | + | { action: "click"; x: number; y: number; button: string; reasoning?: string } |
| 39 | + | { action: "double_click"; x: number; y: number; reasoning?: string } |
| 40 | + | { action: "type"; text: string; reasoning?: string } |
| 41 | + | { action: "key"; key: string; modifiers?: string[]; reasoning?: string } |
| 42 | + | { action: "hotkey"; keys: string; reasoning?: string } |
| 43 | + | { |
| 44 | + action: "scroll"; |
| 45 | + x: number; |
| 46 | + y: number; |
| 47 | + direction: "up" | "down"; |
| 48 | + amount?: number; |
| 49 | + reasoning?: string; |
| 50 | + } |
| 51 | + | { action: "wait"; seconds: number; reasoning?: string } |
| 52 | + | { action: "done"; result: string; reasoning?: string }; |
| 53 | + |
| 54 | +function toAction(raw: FlatAction): Action { |
| 55 | + switch (raw.action) { |
| 56 | + case "click": { |
| 57 | + return { |
| 58 | + action: "click", |
| 59 | + x: raw.x ?? 0, |
| 60 | + y: raw.y ?? 0, |
| 61 | + button: raw.button ?? "left", |
| 62 | + reasoning: raw.reasoning, |
| 63 | + }; |
| 64 | + } |
| 65 | + case "double_click": { |
| 66 | + return { action: "double_click", x: raw.x ?? 0, y: raw.y ?? 0, reasoning: raw.reasoning }; |
| 67 | + } |
| 68 | + case "type": { |
| 69 | + return { action: "type", text: raw.text ?? "", reasoning: raw.reasoning }; |
| 70 | + } |
| 71 | + case "key": { |
| 72 | + return { |
| 73 | + action: "key", |
| 74 | + key: raw.key ?? "Enter", |
| 75 | + modifiers: raw.modifiers, |
| 76 | + reasoning: raw.reasoning, |
| 77 | + }; |
| 78 | + } |
| 79 | + case "hotkey": { |
| 80 | + return { action: "hotkey", keys: raw.keys ?? "", reasoning: raw.reasoning }; |
| 81 | + } |
| 82 | + case "scroll": { |
| 83 | + return { |
| 84 | + action: "scroll", |
| 85 | + x: raw.x ?? 0, |
| 86 | + y: raw.y ?? 0, |
| 87 | + direction: raw.direction ?? "down", |
| 88 | + amount: raw.amount, |
| 89 | + reasoning: raw.reasoning, |
| 90 | + }; |
| 91 | + } |
| 92 | + case "wait": { |
| 93 | + return { |
| 94 | + action: "wait", |
| 95 | + seconds: Math.min(5, Math.max(1, raw.seconds ?? 2)), |
| 96 | + reasoning: raw.reasoning, |
| 97 | + }; |
| 98 | + } |
| 99 | + case "done": { |
| 100 | + return { action: "done", result: raw.result ?? "Task completed.", reasoning: raw.reasoning }; |
| 101 | + } |
| 102 | + } |
| 103 | +} |
64 | 104 |
|
65 | 105 | // Run a Computer Use task: start desktop → vision loop → return result |
66 | 106 | export async function runComputerUseTask( |
67 | 107 | ctx: RunnerCtx, |
68 | 108 | agentId: string, |
69 | 109 | task: TaskRecord, |
70 | 110 | ): Promise<RunnerResult> { |
71 | | - const bedrock = createAmazonBedrock({ region: "us-west-2" }); |
72 | | - const model = bedrock(MANAGER_MODEL); |
| 111 | + const model = mistral(MANAGER_MODEL); |
73 | 112 |
|
74 | 113 | // 1. Ensure Computer Use environment is started (Xvfb + xfce4 + VNC) |
75 | 114 | await ctx.runAction(internal.sandbox.lifecycle.ensureComputerUseStarted, { agentId }); |
@@ -183,7 +222,7 @@ export async function runComputerUseTask( |
183 | 222 | ); |
184 | 223 |
|
185 | 224 | // Ask Mistral Large to decide next action |
186 | | - const { object: nextAction, usage: stepUsage } = await generateObject({ |
| 225 | + const { object, usage: stepUsage } = await generateObject({ |
187 | 226 | model, |
188 | 227 | schema: ActionSchema, |
189 | 228 | messages: [ |
@@ -228,6 +267,8 @@ Rules: |
228 | 267 | ], |
229 | 268 | }); |
230 | 269 |
|
| 270 | + const nextAction = toAction(object); |
| 271 | + |
231 | 272 | // Log the action + usage |
232 | 273 | const actionDesc = formatAction(nextAction); |
233 | 274 | actionLog.push(actionDesc); |
|
0 commit comments