|
13 | 13 | import io.modelcontextprotocol.client.McpAsyncClient; |
14 | 14 | import io.modelcontextprotocol.client.McpClient; |
15 | 15 | import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; |
16 | | -import io.modelcontextprotocol.client.transport.ServerParameters; |
| 16 | +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; |
17 | 17 | import io.modelcontextprotocol.client.transport.StdioClientTransport; |
| 18 | +import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper; |
18 | 19 | import io.modelcontextprotocol.spec.McpSchema; |
19 | 20 | import java.lang.reflect.ParameterizedType; |
20 | 21 | import java.lang.reflect.Type; |
21 | | -import java.net.http.HttpClient; |
22 | 22 | import java.time.Duration; |
23 | 23 | import java.util.ArrayList; |
24 | 24 | import java.util.Arrays; |
25 | 25 | import java.util.HashMap; |
26 | 26 | import java.util.List; |
27 | 27 | import java.util.Map; |
| 28 | +import java.util.Objects; |
28 | 29 | import java.util.Optional; |
29 | 30 | import java.util.function.Consumer; |
30 | 31 | import java.util.function.Function; |
|
35 | 36 | import org.springframework.ai.chat.client.ChatClient; |
36 | 37 | import org.springframework.ai.chat.client.ChatClient.ChatClientRequestSpec; |
37 | 38 | import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; |
38 | | -import org.springframework.ai.mcp.client.autoconfigure.NamedClientMcpTransport; |
39 | | -import org.springframework.ai.mcp.client.autoconfigure.properties.McpSseClientProperties.SseParameters; |
| 39 | +import org.springframework.ai.mcp.client.common.autoconfigure.NamedClientMcpTransport; |
40 | 40 | import org.springframework.ai.tool.StaticToolCallbackProvider; |
41 | 41 | import org.springframework.ai.tool.ToolCallbackProvider; |
42 | 42 | import org.springframework.core.ParameterizedTypeReference; |
@@ -234,24 +234,40 @@ protected ToolCallbackProvider getToolCallbackProvider() { |
234 | 234 | List<NamedClientMcpTransport> transports = new ArrayList<>(); |
235 | 235 | var stdioProperties = mcpClientConfiguration.stdioClientProperties(); |
236 | 236 | if (stdioProperties != null) { |
237 | | - for (Map.Entry<String, ServerParameters> serverParameters : stdioProperties.toServerParameters() |
| 237 | + for (var serverParameters : stdioProperties.toServerParameters() |
238 | 238 | .entrySet()) { |
239 | | - var transport = new StdioClientTransport(serverParameters.getValue()); |
| 239 | + var transport = new StdioClientTransport(serverParameters.getValue(), |
| 240 | + new JacksonMcpJsonMapper(objectMapper)); |
240 | 241 | transports.add(new NamedClientMcpTransport(serverParameters.getKey(), |
241 | 242 | transport)); |
242 | 243 | } |
243 | 244 | } |
| 245 | + |
244 | 246 | var sseProperties = mcpClientConfiguration.sseClientProperties(); |
245 | 247 | if (sseProperties != null) { |
246 | | - for (Map.Entry<String, SseParameters> serverParameters : sseProperties.connections() |
| 248 | + for (var serverParameters : sseProperties.connections() |
247 | 249 | .entrySet()) { |
248 | 250 | String baseUrl = serverParameters.getValue().url(); |
249 | | - String sseEndpoint = serverParameters.getValue().sseEndpoint() != null |
250 | | - ? serverParameters.getValue().sseEndpoint() : "/sse"; |
| 251 | + String sseEndpoint = Objects.requireNonNullElse(serverParameters.getValue().sseEndpoint(), |
| 252 | + "/sse"); |
251 | 253 | var transport = HttpClientSseClientTransport.builder(baseUrl) |
252 | 254 | .sseEndpoint(sseEndpoint) |
253 | | - .clientBuilder(HttpClient.newBuilder()) |
254 | | - .objectMapper(objectMapper) |
| 255 | + .jsonMapper(new JacksonMcpJsonMapper(objectMapper)) |
| 256 | + .build(); |
| 257 | + transports.add(new NamedClientMcpTransport(serverParameters.getKey(), transport)); |
| 258 | + } |
| 259 | + } |
| 260 | + |
| 261 | + var streamableHttpProperties = mcpClientConfiguration.streamableHttpClientProperties(); |
| 262 | + if (streamableHttpProperties != null) { |
| 263 | + for (var serverParameters : streamableHttpProperties.connections() |
| 264 | + .entrySet()) { |
| 265 | + String baseUrl = serverParameters.getValue().url(); |
| 266 | + String endpoint = Objects.requireNonNullElse(serverParameters.getValue().endpoint(), |
| 267 | + "/mcp"); |
| 268 | + var transport = HttpClientStreamableHttpTransport.builder(baseUrl) |
| 269 | + .endpoint(endpoint) |
| 270 | + .jsonMapper(new JacksonMcpJsonMapper(objectMapper)) |
255 | 271 | .build(); |
256 | 272 | transports.add(new NamedClientMcpTransport(serverParameters.getKey(), transport)); |
257 | 273 | } |
|
0 commit comments