diff --git a/examples/openai/toolcall-stream-with-usage.php b/examples/openai/toolcall-stream-with-usage.php new file mode 100644 index 000000000..cea8d5ed5 --- /dev/null +++ b/examples/openai/toolcall-stream-with-usage.php @@ -0,0 +1,55 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +use Symfony\AI\Agent\Agent; +use Symfony\AI\Agent\Bridge\Clock\Clock; +use Symfony\AI\Agent\Bridge\OpenMeteo\OpenMeteo; +use Symfony\AI\Agent\Toolbox\AgentProcessor; +use Symfony\AI\Agent\Toolbox\Toolbox; +use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Result\TextChunk; + +require_once dirname(__DIR__).'/bootstrap.php'; + +$platform = PlatformFactory::create(env('OPENAI_API_KEY'), http_client()); + +$clock = new Clock(); +$openMeteo = new OpenMeteo(http_client()); +$toolbox = new Toolbox([$clock, $openMeteo], logger: logger()); +$processor = new AgentProcessor($toolbox); + +$agent = new Agent($platform, 'gpt-4o-mini', [$processor], [$processor]); +$messages = new MessageBag(Message::ofUser('Tell me the time and the weather in Dublin.')); + +$result = $agent->call($messages, [ + 'stream' => true, // enable streaming of response text + 'stream_options' => [ + 'include_usage' => true, // include usage in the response + ], +]); + +/** @var TextChunk $textChunk */ +foreach ($result->getContent() as $textChunk) { + echo $textChunk->getContent(); +} + +foreach ($result->getMetadata()->get('calls', []) as $call) { + echo \PHP_EOL.sprintf( + '%s: %d tokens - Finish reason: [%s]', + $call['id'], + $call['usage']['total_tokens'], + $call['finish_reason'] + ); +} + +echo \PHP_EOL; diff --git a/src/agent/src/Toolbox/AgentProcessor.php b/src/agent/src/Toolbox/AgentProcessor.php index e3b029b60..b3fa0a961 100644 --- a/src/agent/src/Toolbox/AgentProcessor.php +++ b/src/agent/src/Toolbox/AgentProcessor.php @@ -80,7 +80,7 @@ public function processOutput(Output $output): void if ($result instanceof GenericStreamResponse) { $output->setResult( - new ToolboxStreamResponse($result->getContent(), $this->handleToolCallsCallback($output)) + new ToolboxStreamResponse($result, $this->handleToolCallsCallback($output)) ); return; diff --git a/src/agent/src/Toolbox/StreamResult.php b/src/agent/src/Toolbox/StreamResult.php index afc10378a..44c17ed13 100644 --- a/src/agent/src/Toolbox/StreamResult.php +++ b/src/agent/src/Toolbox/StreamResult.php @@ -13,6 +13,7 @@ use Symfony\AI\Platform\Message\Message; use Symfony\AI\Platform\Result\BaseResult; +use Symfony\AI\Platform\Result\StreamResult as PlatformStreamResult; use Symfony\AI\Platform\Result\ToolCallResult; /** @@ -21,7 +22,7 @@ final class StreamResult extends BaseResult { public function __construct( - private readonly \Generator $generator, + private readonly PlatformStreamResult $sourceStreamResult, private readonly \Closure $handleToolCallsCallback, ) { } @@ -29,7 +30,7 @@ public function __construct( public function getContent(): \Generator { $streamedResult = ''; - foreach ($this->generator as $value) { + foreach ($this->sourceStreamResult->getContent() as $value) { if ($value instanceof ToolCallResult) { $innerResult = ($this->handleToolCallsCallback)($value, Message::ofAssistant($streamedResult)); @@ -48,12 +49,30 @@ public function getContent(): \Generator yield from $content; } - break; + if ($innerResult->getMetadata()->has('calls')) { + $innerCalls = $innerResult->getMetadata()->get('calls'); + $previousCalls = $this->getMetadata()->get('calls', []); + $calls = array_merge($previousCalls, $innerCalls); + } else { + $calls[] = $innerResult->getMetadata()->all(); + } + + if ($calls !== ['calls' => []]) { + $this->getMetadata()->add('calls', $calls); + } + + continue; } $streamedResult .= $value; yield $value; } + + // Attach the metadata from the platform stream to the agent after the stream has been fully processed + // and the post-result metadata, such as usage, has been received. + $calls = $this->getMetadata()->get('calls', []); + $calls[] = $this->sourceStreamResult->getMetadata()->all(); + $this->getMetadata()->add('calls', $calls); } } diff --git a/src/ai-bundle/src/Profiler/TraceablePlatform.php b/src/ai-bundle/src/Profiler/TraceablePlatform.php index 5bfe6b58b..bf21f85fa 100644 --- a/src/ai-bundle/src/Profiler/TraceablePlatform.php +++ b/src/ai-bundle/src/Profiler/TraceablePlatform.php @@ -56,8 +56,7 @@ public function invoke(string $model, array|string|object $input, array $options } if ($options['stream'] ?? false) { - $originalStream = $deferredResult->asStream(); - $deferredResult = new DeferredResult(new PlainConverter($this->createTraceableStreamResult($originalStream)), $deferredResult->getRawResult(), $options); + $deferredResult = new DeferredResult(new PlainConverter($this->createTraceableStreamResult($deferredResult)), $deferredResult->getRawResult(), $options); } $this->calls[] = [ @@ -75,16 +74,20 @@ public function getModelCatalog(): ModelCatalogInterface return $this->platform->getModelCatalog(); } - private function createTraceableStreamResult(\Generator $originalStream): StreamResult + private function createTraceableStreamResult(DeferredResult $sourceResult): StreamResult { - return $result = new StreamResult((function () use (&$result, $originalStream) { + return $result = new StreamResult((function () use (&$result, $sourceResult) { $this->resultCache[$result] = ''; - foreach ($originalStream as $chunk) { + foreach ($sourceResult->asStream() as $chunk) { yield $chunk; if (\is_string($chunk)) { $this->resultCache[$result] .= $chunk; } } + + foreach ($sourceResult->getResult()->getMetadata() as $key => $value) { + $result->getMetadata()->add($key, $value); + } })()); } } diff --git a/src/platform/src/Bridge/OpenAi/Gpt/ResultConverter.php b/src/platform/src/Bridge/OpenAi/Gpt/ResultConverter.php index 29765a4a4..2495fab9c 100644 --- a/src/platform/src/Bridge/OpenAi/Gpt/ResultConverter.php +++ b/src/platform/src/Bridge/OpenAi/Gpt/ResultConverter.php @@ -17,12 +17,14 @@ use Symfony\AI\Platform\Exception\ContentFilterException; use Symfony\AI\Platform\Exception\RateLimitExceededException; use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Metadata\Metadata; use Symfony\AI\Platform\Model; use Symfony\AI\Platform\Result\ChoiceResult; use Symfony\AI\Platform\Result\RawHttpResult; use Symfony\AI\Platform\Result\RawResultInterface; use Symfony\AI\Platform\Result\ResultInterface; use Symfony\AI\Platform\Result\StreamResult; +use Symfony\AI\Platform\Result\TextChunk; use Symfony\AI\Platform\Result\TextResult; use Symfony\AI\Platform\Result\ToolCall; use Symfony\AI\Platform\Result\ToolCallResult; @@ -88,21 +90,43 @@ public function convert(RawResultInterface|RawHttpResult $result, array $options private function convertStream(RawResultInterface|RawHttpResult $result): \Generator { $toolCalls = []; + $metadata = []; foreach ($result->getDataStream() as $data) { + if (!$metadata && isset($data['id'])) { + $metadata['id'] = $data['id']; + } + + if (isset($data['usage'])) { + $metadata['usage'] = $data['usage']; + } + + if (isset($data['choices'][0]['finish_reason'])) { + $metadata['finish_reason'] = $data['choices'][0]['finish_reason']; + } + if ($this->streamIsToolCall($data)) { $toolCalls = $this->convertStreamToToolCalls($toolCalls, $data); } if ([] !== $toolCalls && $this->isToolCallsStreamFinished($data)) { - yield new ToolCallResult(...array_map($this->convertToolCall(...), $toolCalls)); + $toolCallResult = new ToolCallResult(...array_map($this->convertToolCall(...), $toolCalls)); + $metadata['tool_calls'] = $toolCalls; + $toolCallResult->getMetadata()->set($metadata); + yield $toolCallResult; } if (!isset($data['choices'][0]['delta']['content'])) { continue; } - yield $data['choices'][0]['delta']['content']; + $textChunk = new TextChunk($data['choices'][0]['delta']['content']); + $textChunk->getMetadata()->set($metadata); + $textChunk->setRawResult($result); + + yield $textChunk; } + + yield new Metadata($metadata); } /** diff --git a/src/platform/src/Result/DeferredResult.php b/src/platform/src/Result/DeferredResult.php index ea9ce05cd..815c71bbd 100644 --- a/src/platform/src/Result/DeferredResult.php +++ b/src/platform/src/Result/DeferredResult.php @@ -119,7 +119,13 @@ public function asVectors(): array */ public function asStream(): \Generator { - yield from $this->as(StreamResult::class)->getContent(); + $streamResult = $this->as(StreamResult::class); + + yield from $streamResult->getContent(); + + foreach ($streamResult->getMetadata() as $key => $value) { + $this->getResult()->getMetadata()->add($key, $value); + } } /** diff --git a/src/platform/src/Result/StreamResult.php b/src/platform/src/Result/StreamResult.php index ef253ec3c..6b36af7d1 100644 --- a/src/platform/src/Result/StreamResult.php +++ b/src/platform/src/Result/StreamResult.php @@ -11,6 +11,8 @@ namespace Symfony\AI\Platform\Result; +use Symfony\AI\Platform\Metadata\Metadata; + /** * @author Christopher Hertel */ @@ -23,6 +25,15 @@ public function __construct( public function getContent(): \Generator { - yield from $this->generator; + foreach ($this->generator as $content) { + if ($content instanceof Metadata) { + foreach ($content as $key => $value) { + $this->getMetadata()->add($key, $value); + } + continue; + } + + yield $content; + } } } diff --git a/src/platform/src/Result/TextChunk.php b/src/platform/src/Result/TextChunk.php new file mode 100644 index 000000000..818f3462e --- /dev/null +++ b/src/platform/src/Result/TextChunk.php @@ -0,0 +1,33 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Result; + +/** + * @author Oscar Esteve + */ +final class TextChunk extends BaseResult implements \Stringable +{ + public function __construct( + private readonly string $content, + ) { + } + + public function __toString(): string + { + return $this->content; + } + + public function getContent(): string + { + return $this->content; + } +} diff --git a/src/platform/tests/Bridge/OpenAi/Gpt/ResultConverterStreamTest.php b/src/platform/tests/Bridge/OpenAi/Gpt/ResultConverterStreamTest.php new file mode 100644 index 000000000..306f9cfdf --- /dev/null +++ b/src/platform/tests/Bridge/OpenAi/Gpt/ResultConverterStreamTest.php @@ -0,0 +1,167 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\OpenAi\Gpt; + +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\OpenAi\Gpt\ResultConverter; +use Symfony\AI\Platform\Result\RawHttpResult; +use Symfony\AI\Platform\Result\StreamResult; +use Symfony\AI\Platform\Result\TextChunk; +use Symfony\AI\Platform\Result\ToolCallResult; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Component\HttpClient\MockHttpClient; +use Symfony\Component\HttpClient\Response\MockResponse; + +final class ResultConverterStreamTest extends TestCase +{ + public function testStreamTextDeltas() + { + $sseBody = '' + ."data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"},\"index\":0}]}\n\n" + ."data: {\"choices\":[{\"delta\":{\"content\":\"Hello \"},\"index\":0}]}\n\n" + ."data: {\"choices\":[{\"delta\":{\"content\":\"world\"},\"index\":0}]}\n\n" + ."data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n" + ."data: [DONE]\n\n"; + + $mockClient = new MockHttpClient([ + new MockResponse($sseBody, [ + 'http_code' => 200, + 'response_headers' => [ + 'content-type' => 'text/event-stream', + ], + ]), + ]); + $esClient = new EventSourceHttpClient($mockClient); + $asyncResponse = $esClient->request('GET', 'http://localhost/stream'); + + $converter = new ResultConverter(); + $result = $converter->convert(new RawHttpResult($asyncResponse), ['stream' => true]); + + $this->assertInstanceOf(StreamResult::class, $result); + + /** @var TextChunk[] $chunks */ + $chunks = []; + $content = ''; + + foreach ($result->getContent() as $chunk) { + $chunks[] = $chunk; + $content .= $chunk; + } + + // Only text deltas are yielded; role and finish chunks are ignored + $this->assertSame('Hello world', $content); + $this->assertCount(2, $chunks); + $this->assertSame('Hello ', $chunks[0]->getContent()); + $this->assertEquals('http://localhost/stream', $chunks[0]->getRawResult()->getObject()->getInfo()['url']); + } + + public function testStreamToolCallsAreAssembledAndYielded() + { + // Simulate a tool call that is streamed in multiple argument parts + $sseBody = '' + ."data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"},\"index\":0}]}\n\n" + ."data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"id\":\"call_123\",\"type\":\"function\",\"function\":{\"name\":\"test_function\",\"arguments\":\"{\\\"arg1\\\": \\\"value1\\\"}\"}}]},\"index\":0}]}\n\n" + ."data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"tool_calls\"}]}\n\n" + ."data: {\"usage\":{\"prompt_tokens\":1039,\"completion_tokens\":10,\"total_tokens\":1049,\"prompt_tokens_details\":{\"cached_tokens\":0,\"audio_tokens\":0},\"completion_tokens_details\":{\"reasoning_tokens\":0,\"audio_tokens\":0,\"accepted_prediction_tokens\":0,\"rejected_prediction_tokens\":0}}}\n\n" + ."data: [DONE]\n\n"; + + $mockClient = new MockHttpClient([ + new MockResponse($sseBody, [ + 'http_code' => 200, + 'response_headers' => [ + 'content-type' => 'text/event-stream', + ], + ]), + ]); + $esClient = new EventSourceHttpClient($mockClient); + $asyncResponse = $esClient->request('GET', 'http://localhost/stream'); + + $converter = new ResultConverter(); + $result = $converter->convert(new RawHttpResult($asyncResponse), ['stream' => true]); + + $this->assertInstanceOf(StreamResult::class, $result); + + $yielded = []; + foreach ($result->getContent() as $delta) { + $yielded[] = $delta; + } + + // Expect only one yielded item and it should be a ToolCallResult + $this->assertCount(1, $yielded); + $this->assertInstanceOf(ToolCallResult::class, $yielded[0]); + /** @var ToolCallResult $toolCallResult */ + $toolCallResult = $yielded[0]; + $toolCalls = $toolCallResult->getContent(); + + $this->assertCount(1, $toolCalls); + $this->assertSame('call_123', $toolCalls[0]->getId()); + $this->assertSame('test_function', $toolCalls[0]->getName()); + $this->assertSame(['arg1' => 'value1'], $toolCalls[0]->getArguments()); + + // Get the token usage metadata from the result + $this->assertSame( + [ + 'prompt_tokens' => 1039, + 'completion_tokens' => 10, + 'total_tokens' => 1049, + 'prompt_tokens_details' => [ + 'cached_tokens' => 0, + 'audio_tokens' => 0, + ], + 'completion_tokens_details' => [ + 'reasoning_tokens' => 0, + 'audio_tokens' => 0, + 'accepted_prediction_tokens' => 0, + 'rejected_prediction_tokens' => 0, + ], + ], + $result->getMetadata()->get('usage') + ); + } + + public function testStreamTokenUsage() + { + $sseBody = '' + ."data: {\"id\":\"chatcmpl-123\",\"choices\":[{\"delta\":{\"role\":\"assistant\"},\"index\":0}]}\n\n" + ."data: {\"id\":\"chatcmpl-123\",\"choices\":[{\"delta\":{\"content\":\"Hello \"},\"index\":0}]}\n\n" + ."data: {\"id\":\"chatcmpl-123\",\"choices\":[{\"delta\":{\"content\":\"world\"},\"index\":0}]}\n\n" + ."data: {\"id\":\"chatcmpl-123\",\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n" + ."data: {\"id\":\"chatcmpl-123\",\"usage\":{\"prompt_tokens\":1039,\"completion_tokens\":10,\"total_tokens\":1049,\"prompt_tokens_details\":{\"cached_tokens\":0,\"audio_tokens\":0},\"completion_tokens_details\":{\"reasoning_tokens\":0,\"audio_tokens\":0,\"accepted_prediction_tokens\":0,\"rejected_prediction_tokens\":0}}}\n\n" + ."data: [DONE]\n\n"; + + $mockClient = new MockHttpClient([ + new MockResponse($sseBody, [ + 'http_code' => 200, + 'response_headers' => [ + 'content-type' => 'text/event-stream', + ], + ]), + ]); + $esClient = new EventSourceHttpClient($mockClient); + $asyncResponse = $esClient->request('GET', 'http://localhost/stream'); + + $converter = new ResultConverter(); + $result = $converter->convert(new RawHttpResult($asyncResponse), ['stream' => true]); + + $this->assertInstanceOf(StreamResult::class, $result); + + $yielded = []; + foreach ($result->getContent() as $delta) { + $yielded[] = $delta; + } + $this->assertCount(2, $yielded); + /** @var TextChunk $chunk */ + $chunk = $yielded[0]; + $this->assertInstanceOf(TextChunk::class, $chunk); + $this->assertEquals('chatcmpl-123', $chunk->getMetadata()->get('id')); + } +}