@@ -23,6 +23,7 @@ import org.springframework.ai.model.tool.ToolCallingManager
2323import org.springframework.ai.model.tool.ToolExecutionResult
2424import org.springframework.util.MimeType
2525import org.springframework.util.MimeTypeUtils
26+ import reactor.core.publisher.Flux
2627import java.util.*
2728
2829class OpenAIChatModel (
@@ -35,6 +36,81 @@ class OpenAIChatModel(
3536 private val toolExecutionEligibilityPredicate = DefaultToolExecutionEligibilityPredicate ()
3637
3738 override fun call (prompt : Prompt ): ChatResponse {
39+ val requestPrompt = buildRequestPrompt(prompt)
40+ return internalCall(requestPrompt, null )
41+ }
42+
43+ private fun internalCall (prompt : Prompt , previousChatResponse : ChatResponse ? ): ChatResponse {
44+ val completion = openAIClient.chat().completions().create(buildChatCompletionCreateParams(prompt))
45+ val generations = completion.choices().map { choice ->
46+ buildGeneration(
47+ choice, mapOf (
48+ " id" to completion.id(),
49+ " index" to choice.index(),
50+ " finishReason" to choice.finishReason().value().name
51+ )
52+ )
53+ }
54+ val response = ChatResponse .builder().generations(generations).build()
55+ if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.options, response)) {
56+ val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response)
57+ if (toolExecutionResult.returnDirect()) {
58+ return ChatResponse .builder()
59+ .from(response)
60+ .generations(ToolExecutionResult .buildGenerations(toolExecutionResult))
61+ .build()
62+ } else {
63+ return this .internalCall(
64+ Prompt (toolExecutionResult.conversationHistory(), prompt.options),
65+ response
66+ )
67+ }
68+ }
69+ return response
70+ }
71+
72+ override fun stream (prompt : Prompt ): Flux <ChatResponse > {
73+ val requestPrompt = buildRequestPrompt(prompt)
74+ return internalStream(requestPrompt, null )
75+ }
76+
77+ private fun internalStream (prompt : Prompt , previousChatResponse : ChatResponse ? ): Flux <ChatResponse > {
78+ return Flux .fromStream(openAIClient.chat().completions().createStreaming(buildChatCompletionCreateParams(prompt)).stream().map { chunk ->
79+ val generations = chunk.choices().map { choice ->
80+ buildGeneration(
81+ choice, mapOf (
82+ " id" to chunk.id(),
83+ " index" to choice.index(),
84+ " finishReason" to choice.finishReason().map { reason -> reason.value().name }.orElse(" " )
85+ )
86+ )
87+ }.toList()
88+ ChatResponse .builder().generations(generations).build()
89+ }).flatMap { response ->
90+ if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.options, response)) {
91+ Flux .defer {
92+ val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response)
93+ if (toolExecutionResult.returnDirect()) {
94+ Flux .just(
95+ ChatResponse .builder()
96+ .from(response)
97+ .generations(ToolExecutionResult .buildGenerations(toolExecutionResult))
98+ .build()
99+ )
100+ } else {
101+ this .internalStream(
102+ Prompt (toolExecutionResult.conversationHistory(), prompt.options),
103+ response
104+ )
105+ }
106+ }
107+ } else {
108+ Flux .just(response)
109+ }
110+ }
111+ }
112+
113+ private fun buildRequestPrompt (prompt : Prompt ): Prompt {
38114 var runtimeOptions: OpenAiChatOptions ? = null
39115 if (prompt.options != null ) {
40116 runtimeOptions = if (prompt.options is ToolCallingChatOptions ) {
@@ -81,10 +157,10 @@ class OpenAIChatModel(
81157 requestOptions.toolCallbacks = this .defaultOptions.toolCallbacks
82158 requestOptions.toolContext = this .defaultOptions.toolContext
83159 }
84- return internalCall( prompt, null )
160+ return prompt.mutate().chatOptions(requestOptions).build( )
85161 }
86162
87- private fun internalCall (prompt : Prompt , previousChatResponse : ChatResponse ? ): ChatResponse {
163+ private fun buildChatCompletionCreateParams (prompt : Prompt ): ChatCompletionCreateParams {
88164 val paramsBuilder = ChatCompletionCreateParams .builder()
89165
90166 prompt.instructions.forEach { message ->
@@ -211,35 +287,10 @@ class OpenAIChatModel(
211287 paramsBuilder.tools(tools)
212288 }
213289 }
214-
215- val completion = openAIClient.chat().completions().create(paramsBuilder.build())
216- val generations = completion.choices().map { choice ->
217- buildGeneration(
218- choice, mapOf (
219- " id" to completion.id(),
220- " index" to choice.index(),
221- " finishReason" to choice.finishReason().value().name
222- )
223- )
224- }
225- val response = ChatResponse .builder().generations(generations).build()
226- if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.options, response)) {
227- val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response)
228- if (toolExecutionResult.returnDirect()) {
229- return ChatResponse .builder()
230- .from(response)
231- .generations(ToolExecutionResult .buildGenerations(toolExecutionResult))
232- .build()
233- } else {
234- return this .internalCall(
235- Prompt (toolExecutionResult.conversationHistory(), prompt.options),
236- response
237- )
238- }
239- }
240- return response
290+ return paramsBuilder.build()
241291 }
242292
293+
243294 private fun buildGeneration (
244295 choice : ChatCompletion .Choice ,
245296 metadata : Map <String , Any >
@@ -261,6 +312,28 @@ class OpenAIChatModel(
261312 return Generation (assistantMessage, metadataBuilder.build())
262313 }
263314
315+ private fun buildGeneration (
316+ choice : ChatCompletionChunk .Choice ,
317+ metadata : Map <String , Any >
318+ ): Generation {
319+ val toolCalls = choice.delta().toolCalls().map { calls ->
320+ calls.filter { it.id().isPresent }
321+ .map { toolCall ->
322+ AssistantMessage .ToolCall (
323+ toolCall.id().orElse(" " ),
324+ " function" ,
325+ toolCall.function().flatMap { it.name() }.orElse(" " ),
326+ toolCall.function().flatMap { it.arguments() }.orElse(" " )
327+ )
328+ }
329+ }.orElse(listOf ())
330+ val finishReason = choice.finishReason().map { it.value().name }.orElse(" " )
331+ val metadataBuilder = ChatGenerationMetadata .builder().finishReason(finishReason)
332+ val assistantMessage =
333+ AssistantMessage (choice.delta().content().orElse(" " ), metadata, toolCalls, listOf ())
334+ return Generation (assistantMessage, metadataBuilder.build())
335+ }
336+
264337 private fun fromAudioData (audioData : Any ): String {
265338 return if (audioData is ByteArray ) {
266339 Base64 .getEncoder().encodeToString(audioData)
0 commit comments