Skip to content

Commit 25d5219

Browse files
committed
feat(ws): adds WebSocketClientTransport
1 parent 391ec19 commit 25d5219

4 files changed

Lines changed: 383 additions & 0 deletions

File tree

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.client.transport;
6+
7+
import java.net.URI;
8+
import java.net.http.HttpClient;
9+
import java.net.http.WebSocket;
10+
import java.time.Duration;
11+
import java.util.concurrent.CompletableFuture;
12+
import java.util.concurrent.CompletionStage;
13+
import java.util.concurrent.atomic.AtomicReference;
14+
import java.util.function.Consumer;
15+
import java.util.function.Function;
16+
17+
import org.slf4j.Logger;
18+
import org.slf4j.LoggerFactory;
19+
20+
import com.fasterxml.jackson.core.type.TypeReference;
21+
import com.fasterxml.jackson.databind.ObjectMapper;
22+
23+
import io.modelcontextprotocol.spec.McpClientTransport;
24+
import io.modelcontextprotocol.spec.McpSchema;
25+
import io.modelcontextprotocol.util.Assert;
26+
import reactor.core.publisher.Mono;
27+
import reactor.core.publisher.Sinks;
28+
import reactor.util.retry.Retry;
29+
30+
/**
31+
* The WebSocket (WS) implementation of the
32+
* {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with WS
33+
* transport specification, using Java's HttpClient.
34+
*
35+
* @author Aliaksei Darafeyeu
36+
*/
37+
public class WebSocketClientTransport implements McpClientTransport {
38+
39+
private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketClientTransport.class);
40+
41+
private final HttpClient httpClient;
42+
43+
private final ObjectMapper objectMapper;
44+
45+
private final URI uri;
46+
47+
private final AtomicReference<WebSocket> webSocketRef = new AtomicReference<>();
48+
49+
private final AtomicReference<TransportState> state = new AtomicReference<>(TransportState.DISCONNECTED);
50+
51+
private final Sinks.Many<Throwable> errorSink = Sinks.many().multicast().onBackpressureBuffer();
52+
53+
/**
54+
* The constructor for the WebSocketClientTransport.
55+
* @param uri the URI to connect to
56+
* @param clientBuilder the HttpClient builder
57+
* @param objectMapper the ObjectMapper for JSON serialization/deserialization
58+
*/
59+
WebSocketClientTransport(final URI uri, final HttpClient.Builder clientBuilder, final ObjectMapper objectMapper) {
60+
this.uri = uri;
61+
this.httpClient = clientBuilder.build();
62+
this.objectMapper = objectMapper;
63+
}
64+
65+
/**
66+
* Creates a new WebSocketClientTransport instance with the specified URI.
67+
* @param uri the URI to connect to
68+
* @return a new Builder instance
69+
*/
70+
public static Builder builder(final URI uri) {
71+
return new Builder().uri(uri);
72+
}
73+
74+
/**
75+
* The state of the Transport connection.
76+
*/
77+
public enum TransportState {
78+
79+
DISCONNECTED, CONNECTING, CONNECTED, CLOSED
80+
81+
}
82+
83+
/**
84+
* A builder for creating instances of WebSocketClientTransport.
85+
*/
86+
public static class Builder {
87+
88+
private URI uri;
89+
90+
private final HttpClient.Builder clientBuilder = HttpClient.newBuilder()
91+
.version(HttpClient.Version.HTTP_1_1)
92+
.connectTimeout(Duration.ofSeconds(10));
93+
94+
private ObjectMapper objectMapper = new ObjectMapper();
95+
96+
public Builder uri(final URI uri) {
97+
this.uri = uri;
98+
return this;
99+
}
100+
101+
public Builder customizeClient(final Consumer<HttpClient.Builder> clientCustomizer) {
102+
Assert.notNull(clientCustomizer, "clientCustomizer must not be null");
103+
clientCustomizer.accept(clientBuilder);
104+
return this;
105+
}
106+
107+
public Builder objectMapper(final ObjectMapper objectMapper) {
108+
this.objectMapper = objectMapper;
109+
return this;
110+
}
111+
112+
public WebSocketClientTransport build() {
113+
return new WebSocketClientTransport(uri, clientBuilder, objectMapper);
114+
}
115+
116+
}
117+
118+
public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
119+
if (!state.compareAndSet(TransportState.DISCONNECTED, TransportState.CONNECTING)) {
120+
return Mono.error(new IllegalStateException("WebSocket is already connecting or connected"));
121+
}
122+
123+
return Mono.fromFuture(httpClient.newWebSocketBuilder().buildAsync(uri, new WebSocket.Listener() {
124+
private final StringBuilder messageBuffer = new StringBuilder();
125+
126+
@Override
127+
public void onOpen(WebSocket webSocket) {
128+
webSocketRef.set(webSocket);
129+
state.set(TransportState.CONNECTED);
130+
}
131+
132+
@Override
133+
public CompletionStage<?> onText(WebSocket webSocket, CharSequence data, boolean last) {
134+
messageBuffer.append(data);
135+
if (last) {
136+
final String fullMessage = messageBuffer.toString();
137+
messageBuffer.setLength(0);
138+
try {
139+
final McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper,
140+
fullMessage);
141+
handler.apply(Mono.just(msg)).subscribe();
142+
}
143+
catch (Exception e) {
144+
errorSink.tryEmitNext(e);
145+
LOGGER.error("Error processing WS event", e);
146+
}
147+
}
148+
149+
webSocket.request(1);
150+
return CompletableFuture.completedFuture(null);
151+
}
152+
153+
@Override
154+
public void onError(WebSocket webSocket, Throwable error) {
155+
errorSink.tryEmitNext(error);
156+
state.set(TransportState.CLOSED);
157+
LOGGER.error("WS connection error", error);
158+
}
159+
160+
@Override
161+
public CompletionStage<?> onClose(WebSocket webSocket, int statusCode, String reason) {
162+
state.set(TransportState.CLOSED);
163+
return CompletableFuture.completedFuture(null);
164+
}
165+
166+
})).then();
167+
}
168+
169+
@Override
170+
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
171+
172+
return Mono.defer(() -> {
173+
WebSocket ws = webSocketRef.get();
174+
if (ws == null && state.get() == TransportState.CONNECTING) {
175+
return Mono.error(new IllegalStateException("WebSocket is connecting."));
176+
}
177+
178+
if (ws == null || state.get() == TransportState.DISCONNECTED || state.get() == TransportState.CLOSED) {
179+
return Mono.error(new IllegalStateException("WebSocket is closed."));
180+
}
181+
182+
try {
183+
String json = objectMapper.writeValueAsString(message);
184+
return Mono.fromFuture(ws.sendText(json, true)).then();
185+
}
186+
catch (Exception e) {
187+
return Mono.error(e);
188+
}
189+
}).retryWhen(Retry.backoff(3, Duration.ofSeconds(3)).filter(err -> {
190+
if (err instanceof IllegalStateException) {
191+
return err.getMessage().equals("WebSocket is connecting.");
192+
}
193+
return true;
194+
})).onErrorResume(e -> {
195+
LOGGER.error("Failed to send message after retries", e);
196+
errorSink.tryEmitNext(e);
197+
return Mono.error(new IllegalStateException("WebSocket send failed after retries", e));
198+
});
199+
200+
}
201+
202+
@Override
203+
public Mono<Void> closeGracefully() {
204+
WebSocket webSocket = webSocketRef.getAndSet(null);
205+
if (webSocket != null && state.get() == TransportState.CONNECTED) {
206+
state.set(TransportState.CLOSED);
207+
return Mono.fromFuture(webSocket.sendClose(WebSocket.NORMAL_CLOSURE, "Closing")).then();
208+
}
209+
return Mono.empty();
210+
}
211+
212+
@Override
213+
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
214+
return objectMapper.convertValue(data, typeRef);
215+
}
216+
217+
public TransportState getState() {
218+
return state.get();
219+
}
220+
221+
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.client.transport;
6+
7+
import static org.junit.jupiter.api.Assertions.assertEquals;
8+
9+
import java.net.URI;
10+
import java.util.List;
11+
12+
import org.junit.jupiter.api.AfterAll;
13+
import org.junit.jupiter.api.BeforeAll;
14+
import org.junit.jupiter.api.BeforeEach;
15+
import org.junit.jupiter.api.Test;
16+
import org.testcontainers.containers.GenericContainer;
17+
import org.testcontainers.images.builder.ImageFromDockerfile;
18+
19+
import io.modelcontextprotocol.spec.McpSchema;
20+
import reactor.core.publisher.Mono;
21+
import reactor.test.StepVerifier;
22+
23+
/**
24+
* Tests for the {@link WebSocketClientTransport} class.
25+
*
26+
* @author Aliaksei Darafeyeu
27+
*/
28+
class WebSocketClientTransportTest {
29+
30+
private static GenericContainer<?> wsContainer;
31+
32+
private static URI websocketUri;
33+
34+
private WebSocketClientTransport transport;
35+
36+
@BeforeAll
37+
static void startContainer() {
38+
wsContainer = new GenericContainer<>(
39+
new ImageFromDockerfile().withFileFromClasspath("server.js", "ws/server.js")
40+
.withFileFromClasspath("Dockerfile", "ws/Dockerfile"))
41+
.withExposedPorts(8080);
42+
43+
wsContainer.start();
44+
45+
int port = wsContainer.getMappedPort(8080);
46+
websocketUri = URI.create("ws://localhost:" + port);
47+
}
48+
49+
@BeforeEach
50+
public void setUp() {
51+
transport = WebSocketClientTransport.builder(websocketUri).build();
52+
}
53+
54+
@AfterAll
55+
static void tearDown() {
56+
wsContainer.stop();
57+
}
58+
59+
@Test
60+
void testConnectSuccessfully() {
61+
// Try to connect to the WebSocket server
62+
Mono<Void> connection = transport.connect(message -> Mono.empty());
63+
64+
// Wait for the connection to complete
65+
StepVerifier.create(connection).expectComplete().verify();
66+
67+
// Ensure that connection is established
68+
assertEquals(WebSocketClientTransport.TransportState.CONNECTED, transport.getState());
69+
}
70+
71+
@Test
72+
void testSendMessage() {
73+
// Connect to the server
74+
Mono<Void> connection = transport.connect(message -> Mono.empty());
75+
76+
// Ensure connection is successful
77+
StepVerifier.create(connection).expectComplete().verify();
78+
79+
// Create a simple message to send
80+
var messageRequest = new McpSchema.CreateMessageRequest(
81+
List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))),
82+
null, null, null, null, 0, null, null);
83+
McpSchema.JSONRPCMessage message = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION,
84+
McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, "test-id", messageRequest);
85+
86+
// Send a message to the server
87+
Mono<Void> sendMessage = transport.sendMessage(message);
88+
89+
// Ensure message is sent successfully
90+
StepVerifier.create(sendMessage).expectComplete().verify();
91+
}
92+
93+
@Test
94+
void testCloseConnectionGracefully() {
95+
Mono<Void> connection = transport.connect(message -> Mono.empty());
96+
97+
StepVerifier.create(connection).expectComplete().verify();
98+
99+
// Close the connection gracefully
100+
Mono<Void> closeConnection = transport.closeGracefully();
101+
102+
// Verify that the connection is closed successfully
103+
StepVerifier.create(closeConnection).expectComplete().verify();
104+
105+
assertEquals(WebSocketClientTransport.TransportState.CLOSED, transport.getState());
106+
}
107+
108+
@Test
109+
void testSendMessageAfterConnectionClosed() {
110+
// Send a message before connection is established
111+
// Create a simple message to send
112+
var messageRequest = new McpSchema.CreateMessageRequest(
113+
List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))),
114+
null, null, null, null, 0, null, null);
115+
McpSchema.JSONRPCMessage message = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION,
116+
McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, "test-id", messageRequest);
117+
118+
Mono<Void> sendMessageBeforeConnect = transport.sendMessage(message);
119+
120+
// Verify that the transport returns an error because the connection is closed
121+
StepVerifier.create(sendMessageBeforeConnect).expectError(IllegalStateException.class).verify();
122+
}
123+
124+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Use a Node.js base image
2+
FROM node:14
3+
4+
# Set the working directory inside the container
5+
WORKDIR /usr/src/app
6+
7+
# Copy the server.js file into the container
8+
COPY server.js /usr/src/app/
9+
10+
# Install dependencies (e.g., the ws package)
11+
RUN npm init -y && npm install ws
12+
13+
# Expose the port for WebSocket (e.g., 8080)
14+
EXPOSE 8080
15+
16+
# Command to run the WebSocket server
17+
CMD ["node", "server.js"]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Import the WebSocket package
2+
const WebSocket = require('ws');
3+
4+
// Set up the WebSocket server to listen on port 8080
5+
const wss = new WebSocket.Server({ port: 8080 });
6+
7+
// When a new WebSocket connection is established
8+
wss.on('connection', function connection(ws) {
9+
console.log('New client connected');
10+
11+
// When a message is received from the client
12+
ws.on('message', function incoming(message) {
13+
console.log('received: %s', message);
14+
});
15+
16+
// Send a welcome message to the client
17+
ws.send('Welcome to the WebSocket server!');
18+
});
19+
20+
// Log the WebSocket server start
21+
console.log('WebSocket server is listening on ws://localhost:8080');

0 commit comments

Comments
 (0)