Skip to content

Commit 8dd63eb

Browse files
jurgisppsarka
authored andcommitted
Set allowed_action_kinds
1 parent 5345345 commit 8dd63eb

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

src/generalagents/agent.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import httpx
77
from PIL import Image
88

9-
from generalagents.action import Action
9+
from generalagents.action import Action, ActionKind
1010

1111

1212
class Session:
@@ -30,7 +30,12 @@ def __init__(
3030
self.previous_actions = []
3131
self.temperature = temperature
3232

33-
def plan(self, observation: Image.Image) -> Action:
33+
def plan(
34+
self,
35+
observation: Image.Image,
36+
*,
37+
allowed_action_kinds: list[ActionKind] | None = None,
38+
) -> Action:
3439
buffer = BytesIO()
3540
observation.save(buffer, format="WEBP")
3641
image_url = f"data:image/webp;base64,{base64.b64encode(buffer.getvalue()).decode('utf8')}"
@@ -41,6 +46,7 @@ def plan(self, observation: Image.Image) -> Action:
4146
"image_url": image_url,
4247
"previous_actions": self.previous_actions[-self.max_previous_actions :],
4348
"temperature": self.temperature,
49+
"allowed_action_kinds": allowed_action_kinds,
4450
}
4551

4652
res = self.client.post("/v1/control/predict", json=data)

0 commit comments

Comments
 (0)