Skip to content

Commit 6cd26af

Browse files
committed
Previous actions are now a sequence
1 parent 97636c6 commit 6cd26af

1 file changed

Lines changed: 23 additions & 5 deletions

File tree

src/generalagents/agent.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,23 @@
99

1010

1111
class Session:
12-
def __init__(self, model: str, api_key: str, instruction: str, temperature: float):
12+
def __init__(
13+
self,
14+
model: str,
15+
api_key: str,
16+
instruction: str,
17+
temperature: float,
18+
max_previous_actions: int,
19+
):
20+
""""""
1321
self.model = model
1422
self.instruction = instruction
23+
self.max_previous_actions = max_previous_actions
1524
self.client = httpx.Client(
1625
base_url="https://api.generalagents.com",
1726
headers={"Authorization": f"Bearer {api_key}"},
1827
)
19-
self.previous_action = None
28+
self.previous_actions = []
2029
self.temperature = temperature
2130

2231
def plan(self, observation: Image.Image) -> Action:
@@ -28,29 +37,38 @@ def plan(self, observation: Image.Image) -> Action:
2837
"model": self.model,
2938
"instruction": self.instruction,
3039
"image_url": image_url,
31-
"previous_action": self.previous_action,
40+
"previous_actions": self.previous_actions[-self.max_previous_actions :],
3241
"temperature": self.temperature,
3342
}
3443

3544
res = self.client.post("/v1/control/predict", json=data)
3645
res.raise_for_status()
3746

3847
action = res.json()["action"]
39-
self.previous_action = action
48+
self.previous_actions.append(action)
4049
print(f"Received action {action}")
4150
return cattrs.structure(action, Action) # pyright: ignore [reportArgumentType] https://peps.python.org/pep-0747
4251

4352

4453
class Agent:
45-
def __init__(self, model: str, api_key: str, temperature: float = 0.3):
54+
def __init__(
55+
self,
56+
model: str,
57+
api_key: str,
58+
temperature: float = 0.3,
59+
max_previous_actions: int = 20,
60+
):
61+
""""""
4662
self.model = model
4763
self.api_key = api_key
4864
self.temperature = temperature
65+
self.max_previous_actions = max_previous_actions
4966

5067
def start(self, instruction: str) -> Session:
5168
return Session(
5269
self.model,
5370
api_key=self.api_key,
5471
instruction=instruction,
5572
temperature=self.temperature,
73+
max_previous_actions=self.max_previous_actions,
5674
)

0 commit comments

Comments
 (0)