Skip to content

Commit 3e1a004

Browse files
committed
feat: agents
1 parent 11af30f commit 3e1a004

File tree

9 files changed

+411
-6
lines changed

9 files changed

+411
-6
lines changed

src/Agents/Agent.php

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Cortex\Agents;
6+
7+
use Closure;
8+
use Cortex\Pipeline;
9+
use Cortex\Facades\LLM;
10+
use Cortex\Support\Utils;
11+
use Cortex\LLM\Data\Usage;
12+
use Cortex\Prompts\Prompt;
13+
use Cortex\Memory\ChatMemory;
14+
use InvalidArgumentException;
15+
use Cortex\Contracts\Pipeable;
16+
use Cortex\LLM\Enums\ToolChoice;
17+
use Cortex\LLM\Contracts\Message;
18+
use Cortex\Tasks\Stages\AppendUsage;
19+
use Cortex\LLM\Data\ChatStreamResult;
20+
use Cortex\Memory\Stores\InMemoryStore;
21+
use Cortex\Exceptions\PipelineException;
22+
use Cortex\Tasks\Stages\HandleToolCalls;
23+
use Cortex\JsonSchema\Types\ObjectSchema;
24+
use Cortex\LLM\Data\Messages\SystemMessage;
25+
use Cortex\Tasks\Stages\AddMessageToMemory;
26+
use Illuminate\Contracts\Support\Arrayable;
27+
use Cortex\LLM\Contracts\LLM as LLMContract;
28+
use Cortex\Prompts\Builders\ChatPromptBuilder;
29+
use Cortex\LLM\Data\Messages\MessagePlaceholder;
30+
use Cortex\Prompts\Templates\ChatPromptTemplate;
31+
32+
class Agent implements Pipeable
33+
{
34+
protected LLMContract $llm;
35+
36+
protected ChatPromptTemplate $prompt;
37+
38+
protected ChatMemory $memory;
39+
40+
protected Usage $usage;
41+
42+
/**
43+
* @param class-string|\Cortex\JsonSchema\Types\ObjectSchema $output
44+
* @param array<int, \Cortex\LLM\Contracts\Tool|\Closure|string> $tools
45+
* @param array<string, mixed> $initialPromptVariables
46+
*/
47+
public function __construct(
48+
protected string $name,
49+
ChatPromptTemplate|ChatPromptBuilder|string|null $prompt = null,
50+
?LLMContract $llm = null,
51+
protected ?string $description = null,
52+
protected array $tools = [],
53+
protected ToolChoice|string $toolChoice = ToolChoice::Auto,
54+
protected ObjectSchema|string|null $output = null,
55+
protected array $initialPromptVariables = [],
56+
protected int $maxSteps = 1,
57+
protected bool $strict = true,
58+
) {
59+
if ($prompt !== null) {
60+
$this->prompt = match (true) {
61+
is_string($prompt) => Prompt::builder('chat')
62+
->messages([
63+
new SystemMessage($prompt),
64+
])
65+
->strict($this->strict)
66+
->initialVariables($this->initialPromptVariables)
67+
->build(),
68+
$prompt instanceof ChatPromptBuilder => $prompt->build(),
69+
$prompt instanceof ChatPromptTemplate => $prompt,
70+
default => throw new InvalidArgumentException('Invalid prompt type.'),
71+
};
72+
73+
$this->prompt->addMessage(new MessagePlaceholder('messages'));
74+
} else {
75+
$this->prompt = new ChatPromptTemplate([
76+
new MessagePlaceholder('messages'),
77+
], $this->initialPromptVariables);
78+
}
79+
80+
$this->memory = new ChatMemory(new InMemoryStore($this->prompt->messages->withoutPlaceholders()));
81+
$this->usage = Usage::empty();
82+
83+
$this->llm = $llm ?? LLM::provider();
84+
85+
if ($this->tools !== []) {
86+
$this->llm->withTools($this->tools, $this->toolChoice);
87+
}
88+
89+
if ($this->output !== null) {
90+
$this->llm->withStructuredOutput(
91+
output: $this->output,
92+
name: $this->name,
93+
strict: $this->strict,
94+
);
95+
}
96+
}
97+
98+
public function pipeline(bool $shouldParseOutput = true): Pipeline
99+
{
100+
$tools = Utils::toToolCollection($this->getTools());
101+
102+
return $this->executionPipeline($shouldParseOutput)
103+
->when(
104+
$tools->isNotEmpty(),
105+
fn(Pipeline $pipeline): Pipeline => $pipeline->pipe(
106+
new HandleToolCalls(
107+
$tools,
108+
$this->memory,
109+
$this->executionPipeline($shouldParseOutput),
110+
$this->maxSteps,
111+
),
112+
),
113+
);
114+
}
115+
116+
/**
117+
* This is the main pipeline that will be used to generate the output.
118+
*/
119+
public function executionPipeline(bool $shouldParseOutput = true): Pipeline
120+
{
121+
return $this->prompt
122+
->pipe($this->llm->shouldParseOutput($shouldParseOutput))
123+
->pipe(new AddMessageToMemory($this->memory))
124+
->pipe(new AppendUsage($this->usage));
125+
}
126+
127+
/**
128+
* @param array<int, \Cortex\LLM\Contracts\Message> $messages
129+
* @param array<string, mixed> $input
130+
*/
131+
public function invoke(array $messages = [], array $input = []): mixed
132+
{
133+
// $this->id ??= $this->generateId();
134+
$this->memory->setVariables([
135+
...$this->initialPromptVariables,
136+
...$input,
137+
]);
138+
139+
$messages = $this->memory->getMessages()->merge($messages);
140+
$this->memory->setMessages($messages);
141+
142+
return $this->pipeline()->invoke([
143+
...$input,
144+
'messages' => $this->memory->getMessages(),
145+
]);
146+
}
147+
148+
/**
149+
* @param array<string, mixed> $input
150+
*/
151+
public function stream(array $messages = [], array $input = []): ChatStreamResult
152+
{
153+
// $this->id ??= $this->generateId();
154+
$this->memory->setVariables([
155+
...$this->initialPromptVariables,
156+
...$input,
157+
]);
158+
159+
$messages = $this->memory->getMessages()->merge($messages);
160+
$this->memory->setMessages($messages);
161+
162+
return $this->pipeline()->stream([
163+
...$input,
164+
'messages' => $this->memory->getMessages(),
165+
]);
166+
}
167+
168+
public function pipe(Pipeable|callable $pipeable): Pipeline
169+
{
170+
return $this->pipeline()->pipe($pipeable);
171+
}
172+
173+
public function handlePipeable(mixed $payload, Closure $next): mixed
174+
{
175+
$payload = match (true) {
176+
$payload === null => [],
177+
is_array($payload) => $payload,
178+
$payload instanceof Arrayable => $payload->toArray(),
179+
is_object($payload) => get_object_vars($payload),
180+
default => throw new PipelineException('Invalid input for agent.'),
181+
};
182+
183+
return $next($this->invoke($payload));
184+
}
185+
186+
public function getName(): string
187+
{
188+
return $this->name;
189+
}
190+
191+
public function getDescription(): ?string
192+
{
193+
return $this->description;
194+
}
195+
196+
public function getPrompt(): ChatPromptTemplate
197+
{
198+
return $this->prompt;
199+
}
200+
201+
/**
202+
* @return array<int, \Cortex\LLM\Contracts\Tool|\Closure|string>
203+
*/
204+
public function getTools(): array
205+
{
206+
return $this->tools;
207+
}
208+
209+
public function getLLM(): LLMContract
210+
{
211+
return $this->llm;
212+
}
213+
214+
public function getMemory(): ChatMemory
215+
{
216+
return $this->memory;
217+
}
218+
219+
public function getUsage(): Usage
220+
{
221+
return $this->usage;
222+
}
223+
}

src/LLM/Drivers/OpenAIChat.php

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,10 @@ protected function mapResponse(CreateResponse $response): ChatResult
105105
->map(function (CreateResponseChoice $choice) use ($toolCalls, $finishReason, $usage, $response): ChatGeneration {
106106
$generation = new ChatGeneration(
107107
message: new AssistantMessage(
108-
content: [
109-
new TextContent($choice->message->content),
110-
],
108+
content: $choice->message->content,
109+
// content: [
110+
// new TextContent($choice->message->content),
111+
// ],
111112
toolCalls: $toolCalls,
112113
metadata: new ResponseMetadata(
113114
id: $response->id,

src/Memory/ChatMemory.php

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ public function getMessages(): MessageCollection
5858
return $messages;
5959
}
6060

61+
public function setMessages(MessageCollection $messages): static
62+
{
63+
$this->store->setMessages($messages);
64+
65+
return $this;
66+
}
67+
6168
/**
6269
* @param array<string, mixed> $variables
6370
*/

src/Memory/Contracts/Store.php

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ public function addMessage(Message $message): void;
2626
*/
2727
public function addMessages(MessageCollection|array $messages): void;
2828

29+
/**
30+
* Set the messages in the store.
31+
*/
32+
public function setMessages(MessageCollection $messages): void;
33+
2934
/**
3035
* Reset the store.
3136
*/

src/Memory/Stores/InMemoryStore.php

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ public function addMessages(MessageCollection|array $messages): void
2929
$this->messages->merge($messages);
3030
}
3131

32+
public function setMessages(MessageCollection $messages): void
33+
{
34+
$this->messages = $messages;
35+
}
36+
3237
public function reset(): void
3338
{
3439
$this->messages = new MessageCollection();

src/Prompts/Templates/ChatPromptTemplate.php

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
use Override;
88
use Cortex\Support\Utils;
9+
use Cortex\LLM\Contracts\Message;
910
use Illuminate\Support\Collection;
1011
use Cortex\JsonSchema\SchemaFactory;
1112
use Cortex\Exceptions\PromptException;
@@ -14,6 +15,7 @@
1415
use Cortex\JsonSchema\Types\UnionSchema;
1516
use Cortex\JsonSchema\Types\ObjectSchema;
1617
use Cortex\LLM\Data\Messages\MessageCollection;
18+
use Cortex\LLM\Data\Messages\MessagePlaceholder;
1719

1820
class ChatPromptTemplate extends AbstractPromptTemplate
1921
{
@@ -57,6 +59,13 @@ public function variables(): Collection
5759
->unique();
5860
}
5961

62+
public function addMessage(Message|MessagePlaceholder $message): self
63+
{
64+
$this->messages->add($message);
65+
66+
return $this;
67+
}
68+
6069
#[Override]
6170
public function defaultInputSchema(): ObjectSchema
6271
{

src/Tools/AbstractTool.php

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,19 @@ abstract class AbstractTool implements Tool, Pipeable
2121
*/
2222
public function format(): array
2323
{
24-
return [
24+
$output = [
2525
'name' => $this->name(),
2626
'description' => $this->description(),
27-
'parameters' => $this->schema()->toArray(includeSchemaRef: false, includeTitle: false),
2827
];
28+
29+
$schema = $this->schema();
30+
31+
// If the schema has no properties, then we don't need to include the parameters.
32+
if (! empty($schema->getPropertyKeys())) {
33+
$output['parameters'] = $schema->toArray(includeSchemaRef: false, includeTitle: false);
34+
}
35+
36+
return $output;
2937
}
3038

3139
public function handlePipeable(mixed $payload, Closure $next): mixed

src/Tools/ClosureTool.php

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ public function invoke(ToolCall|array $toolCall = []): mixed
5151
$arguments = $this->getArguments($toolCall);
5252

5353
// Ensure arguments are valid as per the tool's schema.
54-
$this->schema->validate($arguments);
54+
if ($arguments !== []) {
55+
$this->schema->validate($arguments);
56+
}
5557

5658
// Invoke the closure with the arguments.
5759
return $this->reflection->invokeArgs($arguments);

0 commit comments

Comments
 (0)