11/*
2- * Copyright 2002-2023 the original author or authors.
2+ * Copyright 2002-2024 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
4444import org .springframework .graphql .server .support .GraphQlWebSocketMessage ;
4545import org .springframework .http .HttpHeaders ;
4646import org .springframework .http .codec .CodecConfigurer ;
47+ import org .springframework .lang .Nullable ;
4748import org .springframework .util .Assert ;
4849import org .springframework .util .CollectionUtils ;
4950import org .springframework .web .reactive .socket .CloseStatus ;
@@ -72,10 +73,13 @@ public class GraphQlWebSocketHandler implements WebSocketHandler {
7273
7374 private final WebSocketGraphQlInterceptor webSocketInterceptor ;
7475
75- private final WebSocketCodecDelegate webSocketCodecDelegate ;
76+ private final WebSocketCodecDelegate codecDelegate ;
7677
7778 private final Duration initTimeoutDuration ;
7879
80+ @ Nullable
81+ private final Duration keepAliveDuration ;
82+
7983
8084 /**
8185 * Create a new instance.
@@ -87,12 +91,30 @@ public class GraphQlWebSocketHandler implements WebSocketHandler {
8791 public GraphQlWebSocketHandler (
8892 WebGraphQlHandler graphQlHandler , CodecConfigurer codecConfigurer , Duration connectionInitTimeout ) {
8993
94+ this (graphQlHandler , codecConfigurer , connectionInitTimeout , null );
95+ }
96+
97+ /**
98+ * Create a new instance.
99+ * @param graphQlHandler common handler for GraphQL over WebSocket requests
100+ * @param codecConfigurer codec configurer for JSON encoding and decoding
101+ * @param connectionInitTimeout how long to wait after the establishment of
102+ * the WebSocket for the {@code "connection_ini"} message from the client.
103+ * @param keepAliveDuration how frequently to send ping messages; if not
104+ * set then ping messages are not sent.
105+ * @since 1.3
106+ */
107+ public GraphQlWebSocketHandler (
108+ WebGraphQlHandler graphQlHandler , CodecConfigurer codecConfigurer ,
109+ Duration connectionInitTimeout , @ Nullable Duration keepAliveDuration ) {
110+
90111 Assert .notNull (graphQlHandler , "WebGraphQlHandler is required" );
91112
92113 this .graphQlHandler = graphQlHandler ;
93114 this .webSocketInterceptor = this .graphQlHandler .getWebSocketInterceptor ();
94- this .webSocketCodecDelegate = new WebSocketCodecDelegate (codecConfigurer );
115+ this .codecDelegate = new WebSocketCodecDelegate (codecConfigurer );
95116 this .initTimeoutDuration = connectionInitTimeout ;
117+ this .keepAliveDuration = keepAliveDuration ;
96118 }
97119
98120
@@ -137,7 +159,7 @@ public Mono<Void> handle(WebSocketSession session) {
137159 .subscribe ();
138160
139161 return session .send (session .receive ().flatMap ((webSocketMessage ) -> {
140- GraphQlWebSocketMessage message = this .webSocketCodecDelegate .decode (webSocketMessage );
162+ GraphQlWebSocketMessage message = this .codecDelegate .decode (webSocketMessage );
141163 String id = message .getId ();
142164 Map <String , Object > payload = message .getPayload ();
143165 switch (message .resolvedType ()) {
@@ -159,7 +181,7 @@ public Mono<Void> handle(WebSocketSession session) {
159181 .doOnTerminate (() -> subscriptions .remove (id ));
160182 }
161183 case PING -> {
162- return Flux .just (this .webSocketCodecDelegate .encode (session , GraphQlWebSocketMessage .pong (null )));
184+ return Flux .just (this .codecDelegate .encode (session , GraphQlWebSocketMessage .pong (null )));
163185 }
164186 case COMPLETE -> {
165187 if (id != null ) {
@@ -176,11 +198,16 @@ public Mono<Void> handle(WebSocketSession session) {
176198 if (!connectionInitPayloadRef .compareAndSet (null , payload )) {
177199 return GraphQlStatus .close (session , GraphQlStatus .TOO_MANY_INIT_REQUESTS_STATUS );
178200 }
179- return this .webSocketInterceptor .handleConnectionInitialization (sessionInfo , payload )
201+ Flux < WebSocketMessage > flux = this .webSocketInterceptor .handleConnectionInitialization (sessionInfo , payload )
180202 .defaultIfEmpty (Collections .emptyMap ())
181- .map ((ackPayload ) -> this .webSocketCodecDelegate .encodeConnectionAck (session , ackPayload ))
182- .flux ()
183- .onErrorResume ((ex ) -> GraphQlStatus .close (session , GraphQlStatus .UNAUTHORIZED_STATUS ));
203+ .map ((ackPayload ) -> this .codecDelegate .encodeConnectionAck (session , ackPayload ))
204+ .flux ();
205+ if (this .keepAliveDuration != null ) {
206+ flux = flux .mergeWith (Flux .interval (this .keepAliveDuration , this .keepAliveDuration )
207+ .filter ((aLong ) -> !this .codecDelegate .checkMessagesEncodedAndClear ())
208+ .map ((aLong ) -> this .codecDelegate .encode (session , GraphQlWebSocketMessage .ping (null ))));
209+ }
210+ return flux .onErrorResume ((ex ) -> GraphQlStatus .close (session , GraphQlStatus .UNAUTHORIZED_STATUS ));
184211 }
185212 default -> {
186213 return GraphQlStatus .close (session , GraphQlStatus .INVALID_MESSAGE_STATUS );
@@ -218,14 +245,14 @@ private Flux<WebSocketMessage> handleResponse(WebSocketSession session, String i
218245 }
219246
220247 return responseFlux
221- .map ((responseMap ) -> this .webSocketCodecDelegate .encodeNext (session , id , responseMap ))
222- .concatWith (Mono .fromCallable (() -> this .webSocketCodecDelegate .encodeComplete (session , id )))
248+ .map ((responseMap ) -> this .codecDelegate .encodeNext (session , id , responseMap ))
249+ .concatWith (Mono .fromCallable (() -> this .codecDelegate .encodeComplete (session , id )))
223250 .onErrorResume ((ex ) -> {
224251 if (ex instanceof SubscriptionExistsException ) {
225252 CloseStatus status = new CloseStatus (4409 , "Subscriber for " + id + " already exists" );
226253 return GraphQlStatus .close (session , status );
227254 }
228- return Mono .fromCallable (() -> this .webSocketCodecDelegate .encodeError (session , id , ex ));
255+ return Mono .fromCallable (() -> this .codecDelegate .encodeError (session , id , ex ));
229256 });
230257 }
231258
0 commit comments