Skip to content

Commit 734bbbb

Browse files
committed
feat: Anthropic
1 parent 73e9867 commit 734bbbb

File tree

13 files changed

+199
-115
lines changed

13 files changed

+199
-115
lines changed

src/LLM/AbstractLLM.php

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
use Cortex\Contracts\OutputParser;
1919
use Cortex\Support\Traits\CanPipe;
2020
use Cortex\Exceptions\LLMException;
21+
use Cortex\LLM\Data\ChatGeneration;
2122
use Cortex\JsonSchema\SchemaFactory;
2223
use Cortex\ModelInfo\Data\ModelInfo;
2324
use Cortex\LLM\Data\ChatStreamResult;
@@ -50,6 +51,8 @@ abstract class AbstractLLM implements LLM
5051

5152
protected ?StructuredOutputConfig $structuredOutputConfig = null;
5253

54+
protected StructuredOutputMode $structuredOutputMode = StructuredOutputMode::Auto;
55+
5356
protected ?OutputParser $outputParser = null;
5457

5558
protected ?string $outputParserError = null;
@@ -64,6 +67,8 @@ abstract class AbstractLLM implements LLM
6467

6568
protected bool $shouldApplyFormatInstructions = false;
6669

70+
protected bool $shouldParseOutput = true;
71+
6772
/**
6873
* @var array<\Cortex\ModelInfo\Enums\ModelFeature>
6974
*/
@@ -116,15 +121,19 @@ public function output(OutputParser $parser): Pipeline
116121
/**
117122
* @param array<int, \Cortex\LLM\Contracts\Tool|\Cortex\JsonSchema\Contracts\Schema|\Closure|string> $tools
118123
*/
119-
public function withTools(array $tools, ToolChoice|string $toolChoice = ToolChoice::Auto): static
120-
{
124+
public function withTools(
125+
array $tools,
126+
ToolChoice|string $toolChoice = ToolChoice::Auto,
127+
bool $allowParallelToolCalls = true,
128+
): static {
121129
$this->supportsFeatureOrFail(ModelFeature::ToolCalling);
122130

123131
$this->toolConfig = $tools === []
124132
? null
125133
: new ToolConfig(
126134
Utils::toToolCollection($tools)->all(),
127135
$toolChoice,
136+
$allowParallelToolCalls,
128137
);
129138

130139
return $this;
@@ -133,9 +142,15 @@ public function withTools(array $tools, ToolChoice|string $toolChoice = ToolChoi
133142
/**
134143
* Add a tool to the LLM.
135144
*/
136-
public function addTool(Tool|Closure|string $tool, ToolChoice|string $toolChoice = ToolChoice::Auto): static
137-
{
138-
return $this->withTools([...($this->toolConfig->tools ?? []), $tool], $toolChoice);
145+
public function addTool(
146+
Tool|Closure|string $tool,
147+
ToolChoice|string $toolChoice = ToolChoice::Auto,
148+
bool $allowParallelToolCalls = true,
149+
): static {
150+
return $this->withTools([
151+
...($this->toolConfig->tools ?? []),
152+
$tool,
153+
], $toolChoice, $allowParallelToolCalls);
139154
}
140155

141156
/**
@@ -166,6 +181,7 @@ public function withStructuredOutput(
166181
bool $strict = true,
167182
StructuredOutputMode $outputMode = StructuredOutputMode::Auto,
168183
): static {
184+
$this->structuredOutputMode = $outputMode;
169185
[$schema, $outputParser] = $this->resolveSchemaAndOutputParser($output, $strict);
170186

171187
$this->withOutputParser($outputParser);
@@ -386,6 +402,27 @@ protected static function applyFormatInstructions(
386402
return $messages;
387403
}
388404

405+
public function shouldParseOutput(bool $shouldParseOutput = true): static
406+
{
407+
$this->shouldParseOutput = $shouldParseOutput;
408+
409+
return $this;
410+
}
411+
412+
protected function applyOutputParserIfApplicable(ChatGeneration $generation): ChatGeneration
413+
{
414+
if ($this->shouldParseOutput && $this->outputParser !== null) {
415+
try {
416+
$parsedOutput = $this->outputParser->parse($generation);
417+
$generation = $generation->cloneWithParsedOutput($parsedOutput);
418+
} catch (OutputParserException $e) {
419+
$this->outputParserError = $e->getMessage();
420+
}
421+
}
422+
423+
return $generation;
424+
}
425+
389426
/**
390427
* Resolve the schema and output parser from the given output type.
391428
*

src/LLM/CacheDecorator.php

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ public function ignoreFeatures(bool $ignoreModelFeatures = true): static
128128
return $this;
129129
}
130130

131+
public function shouldParseOutput(bool $shouldParseOutput = true): static
132+
{
133+
$this->llm = $this->llm->shouldParseOutput($shouldParseOutput);
134+
135+
return $this;
136+
}
137+
131138
public function withModel(string $model): static
132139
{
133140
$this->llm = $this->llm->withModel($model);

src/LLM/Contracts/LLM.php

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,4 +150,11 @@ public function getFeatures(): array;
150150
* Get the model info for the LLM.
151151
*/
152152
public function getModelInfo(): ?ModelInfo;
153+
154+
/**
155+
* Set whether the output should be parsed.
156+
* This may be set to false when called in a pipeline context and output parsing
157+
* is done as part of the next pipeable.
158+
*/
159+
public function shouldParseOutput(bool $shouldParseOutput = true): static;
153160
}

src/LLM/Contracts/Tool.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
namespace Cortex\LLM\Contracts;
66

77
use Cortex\LLM\Data\ToolCall;
8-
use Cortex\JsonSchema\Contracts\Schema;
8+
use Cortex\JsonSchema\Types\ObjectSchema;
99
use Cortex\LLM\Data\Messages\ToolMessage;
1010

1111
interface Tool
@@ -23,7 +23,7 @@ public function description(): string;
2323
/**
2424
* Get the schema of the tool.
2525
*/
26-
public function schema(): Schema;
26+
public function schema(): ObjectSchema;
2727

2828
/**
2929
* Get the formatted output of the tool.

src/LLM/Data/ToolConfig.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@
1414
public function __construct(
1515
public array $tools,
1616
public ToolChoice|string $toolChoice = ToolChoice::Auto,
17+
public bool $allowParallelToolCalls = true,
1718
) {}
1819
}

src/LLM/Drivers/AnthropicChat.php

Lines changed: 69 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
use Cortex\LLM\Data\Usage;
1212
use Cortex\LLM\AbstractLLM;
1313
use Illuminate\Support\Arr;
14+
use Cortex\LLM\Data\ToolCall;
1415
use Cortex\LLM\Contracts\Tool;
1516
use Cortex\Events\ChatModelEnd;
1617
use Cortex\LLM\Data\ChatResult;
@@ -19,26 +20,26 @@
1920
use Cortex\Events\ChatModelError;
2021
use Cortex\Events\ChatModelStart;
2122
use Cortex\LLM\Contracts\Message;
23+
use Cortex\LLM\Data\FunctionCall;
2224
use Cortex\LLM\Enums\MessageRole;
2325
use Cortex\LLM\Enums\FinishReason;
2426
use Cortex\Exceptions\LLMException;
2527
use Cortex\LLM\Data\ChatGeneration;
2628
use Cortex\LLM\Data\ChatStreamResult;
2729
use Cortex\LLM\Data\ResponseMetadata;
2830
use Anthropic\Contracts\ClientContract;
31+
use Cortex\LLM\Data\ToolCallCollection;
2932
use Cortex\LLM\Data\ChatGenerationChunk;
30-
use Cortex\JsonSchema\Types\ObjectSchema;
3133
use Cortex\LLM\Data\Messages\ToolMessage;
3234
use Cortex\ModelInfo\Enums\ModelProvider;
3335
use Cortex\LLM\Data\Messages\SystemMessage;
36+
use Cortex\Tasks\Enums\StructuredOutputMode;
3437
use Cortex\LLM\Data\Messages\AssistantMessage;
3538
use Cortex\LLM\Data\Messages\MessageCollection;
3639
use Anthropic\Responses\Messages\CreateResponse;
3740
use Anthropic\Responses\Messages\StreamResponse;
3841
use Cortex\LLM\Data\Messages\Content\TextContent;
39-
use Cortex\LLM\Data\Messages\Content\ImageContent;
4042
use Anthropic\Responses\Messages\CreateResponseUsage;
41-
use Cortex\LLM\Data\Messages\Content\DocumentContent;
4243
use Anthropic\Responses\Messages\CreateResponseContent;
4344
use Anthropic\Responses\Messages\CreateStreamedResponseUsage;
4445
use Anthropic\Testing\Responses\Fixtures\Messages\CreateResponseFixture;
@@ -106,40 +107,47 @@ public function invoke(
106107
*/
107108
protected function mapResponse(CreateResponse $response): ChatResult
108109
{
109-
$content = $response->content[0];
110-
// $toolCalls = $content->toolCalls === [] ? null : new ToolCallCollection(
111-
// collect($content->toolCalls)
112-
// ->map(fn(CreateResponseToolCall $toolCall): ToolCall => new ToolCall(
113-
// $toolCall->id,
114-
// new FunctionCall(
115-
// $toolCall->function->name,
116-
// json_decode($toolCall->function->arguments, true, flags: JSON_THROW_ON_ERROR),
117-
// ),
118-
// ))
119-
// ->values()
120-
// ->all(),
121-
// );
110+
$toolCalls = array_filter(
111+
$response->content,
112+
fn(CreateResponseContent $content): bool => $content->type === 'tool_use',
113+
);
114+
115+
$toolCalls = collect($toolCalls)
116+
->map(fn(CreateResponseContent $content): ToolCall => new ToolCall(
117+
$content->id,
118+
new FunctionCall($content->name, $content->input ?? []),
119+
))
120+
->values()
121+
->all();
122+
123+
$toolCalls = $toolCalls !== []
124+
? new ToolCallCollection($toolCalls)
125+
: null;
122126

123127
$usage = $this->mapUsage($response->usage);
124128
$finishReason = static::mapFinishReason($response->stop_reason ?? null);
125129

126130
$generations = collect($response->content)
127-
->map(fn(CreateResponseContent $content): ChatGeneration => new ChatGeneration(
128-
message: new AssistantMessage(
129-
content: $content->text,
130-
// toolCalls: $toolCalls,
131-
metadata: new ResponseMetadata(
132-
id: $response->id,
133-
model: $response->model,
134-
provider: $this->modelProvider,
135-
finishReason: $finishReason,
136-
usage: $usage,
131+
->map(function (CreateResponseContent $content) use ($toolCalls, $finishReason, $usage, $response): ChatGeneration {
132+
$generation = new ChatGeneration(
133+
message: new AssistantMessage(
134+
content: $content->text,
135+
toolCalls: $toolCalls,
136+
metadata: new ResponseMetadata(
137+
id: $response->id,
138+
model: $response->model,
139+
provider: $this->modelProvider,
140+
finishReason: $finishReason,
141+
usage: $usage,
142+
),
137143
),
138-
),
139-
index: 0,
140-
createdAt: new DateTimeImmutable(),
141-
finishReason: $finishReason,
142-
))
144+
index: 0,
145+
createdAt: new DateTimeImmutable(),
146+
finishReason: $finishReason,
147+
);
148+
149+
return $this->applyOutputParserIfApplicable($generation);
150+
})
143151
->all();
144152

145153
$result = new ChatResult(
@@ -319,16 +327,16 @@ protected static function mapMessagesForInput(MessageCollection $messages): arra
319327
'type' => 'text',
320328
'text' => $content->text,
321329
],
322-
$content instanceof ImageContent => [
323-
'type' => 'image_url',
324-
'image_url' => [
325-
'url' => $content->urlOrBase64,
326-
],
327-
],
328-
$content instanceof DocumentContent => [
329-
'type' => 'document',
330-
'document' => $content->data,
331-
],
330+
// $content instanceof ImageContent => [
331+
// 'type' => 'image_url',
332+
// 'image_url' => [
333+
// 'url' => $content->urlOrBase64,
334+
// ],
335+
// ],
336+
// $content instanceof DocumentContent => [
337+
// 'type' => 'document',
338+
// 'document' => $content->data,
339+
// ],
332340
default => $content,
333341
};
334342
}, $formattedMessage['content']);
@@ -367,56 +375,39 @@ protected function buildParams(array $additionalParameters): array
367375
];
368376

369377
if ($this->structuredOutputConfig !== null) {
370-
$schema = $this->structuredOutputConfig->schema;
371-
$params['response_format'] = [
372-
'type' => 'json_schema',
373-
'json_schema' => [
374-
'name' => $this->structuredOutputConfig->name,
375-
'description' => $this->structuredOutputConfig->description ?? $schema->getDescription(),
376-
'schema' => $schema instanceof ObjectSchema
377-
? $schema->additionalProperties(false)->toArray()
378-
: $schema,
379-
'strict' => $this->structuredOutputConfig->strict,
380-
],
381-
];
378+
$this->structuredOutputMode = StructuredOutputMode::Tool;
382379
} elseif ($this->forceJsonOutput) {
383-
$params['response_format'] = [
384-
'type' => 'json_object',
385-
];
380+
$this->structuredOutputMode = StructuredOutputMode::Json;
386381
}
387382

388383
if ($this->toolConfig !== null) {
389-
if (is_string($this->toolConfig->toolChoice)) {
390-
$toolChoice = [
391-
'type' => 'function',
392-
'function' => [
393-
'name' => $this->toolConfig->toolChoice,
394-
],
395-
];
396-
} else {
397-
$toolChoice = $this->toolConfig->toolChoice->value;
398-
}
384+
$choice = $this->toolConfig->toolChoice;
399385

400-
$params['tool_choice'] = match ($toolChoice) {
401-
ToolChoice::Required->value => 'any',
402-
default => $toolChoice,
386+
$params['tool_choice'] = match (true) {
387+
is_string($choice) => [
388+
'type' => 'tool',
389+
'name' => $choice,
390+
],
391+
default => [
392+
'type' => match ($choice) {
393+
ToolChoice::Required => 'any',
394+
default => $choice,
395+
},
396+
'disable_parallel_tool_use' => ! $this->toolConfig->allowParallelToolCalls,
397+
],
403398
};
404399

400+
// TODO: add ProviderTool support for Anthropic e.g. web_search, etc.
405401
$params['tools'] = collect($this->toolConfig->tools)
406402
->map(fn(Tool $tool): array => [
407-
'type' => 'function',
408-
'function' => $tool->format(),
403+
'type' => 'custom',
404+
'name' => $tool->name(),
405+
'description' => $tool->description(),
406+
'input_schema' => $tool->schema()->additionalProperties(false)->toArray(),
409407
])
410408
->toArray();
411409
}
412410

413-
// Ensure the usage information is returned when streaming
414-
if ($this->streaming) {
415-
$params['stream_options'] = [
416-
'include_usage' => true,
417-
];
418-
}
419-
420411
return [
421412
...$params,
422413
...$this->parameters,

src/LLM/Drivers/OpenAIChat.php

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,7 @@ protected function mapResponse(CreateResponse $response): ChatResult
130130
finishReason: $finishReason,
131131
);
132132

133-
if ($this->outputParser !== null) {
134-
try {
135-
$parsedOutput = $this->outputParser->parse($generation);
136-
$generation = $generation->cloneWithParsedOutput($parsedOutput);
137-
} catch (OutputParserException $e) {
138-
$this->outputParserError = $e->getMessage();
139-
}
140-
}
141-
142-
return $generation;
133+
return $this->applyOutputParserIfApplicable($generation);
143134
})
144135
->all();
145136

@@ -422,6 +413,7 @@ protected function buildParams(array $additionalParameters): array
422413
}
423414

424415
$params['tool_choice'] = $toolChoice;
416+
$params['parallel_tool_calls'] = $this->toolConfig->allowParallelToolCalls;
425417
$params['tools'] = collect($this->toolConfig->tools)
426418
->map(fn(Tool $tool): array => [
427419
'type' => 'function',

0 commit comments

Comments
 (0)