@@ -12,6 +12,7 @@ import (
1212 "log/slog"
1313 "sync"
1414 "sync/atomic"
15+ "time"
1516
1617 "google.golang.org/grpc/codes"
1718 "google.golang.org/grpc/status"
@@ -31,18 +32,20 @@ import (
3132
3233var _ commandService = (* CommandService )(nil )
3334
34- const (
35- createConnectionMaxElapsedTime = 0
36- )
35+ const createConnectionMaxElapsedTime = 0
36+
37+ var timeToWaitBetweenChecks = 5 * time . Second
3738
3839type (
3940 CommandService struct {
4041 commandServiceClient mpi.CommandServiceClient
4142 subscribeClient mpi.CommandService_SubscribeClient
4243 agentConfig * config.Config
4344 isConnected * atomic.Bool
45+ connectionResetInProgress * atomic.Bool
4446 subscribeChannel chan * mpi.ManagementPlaneRequest
4547 configApplyRequestQueue map [string ][]* mpi.ManagementPlaneRequest // key is the instance ID
48+ requestsInProgress map [string ]* mpi.ManagementPlaneRequest // key is the correlation ID
4649 resource * mpi.Resource
4750 subscribeClientMutex sync.Mutex
4851 configApplyRequestQueueMutex sync.Mutex
@@ -56,19 +59,16 @@ func NewCommandService(
5659 agentConfig * config.Config ,
5760 subscribeChannel chan * mpi.ManagementPlaneRequest ,
5861) * CommandService {
59- isConnected := & atomic.Bool {}
60- isConnected .Store (false )
61-
62- commandService := & CommandService {
63- commandServiceClient : commandServiceClient ,
64- agentConfig : agentConfig ,
65- isConnected : isConnected ,
66- subscribeChannel : subscribeChannel ,
67- configApplyRequestQueue : make (map [string ][]* mpi.ManagementPlaneRequest ),
68- resource : & mpi.Resource {},
62+ return & CommandService {
63+ commandServiceClient : commandServiceClient ,
64+ agentConfig : agentConfig ,
65+ isConnected : & atomic.Bool {},
66+ connectionResetInProgress : & atomic.Bool {},
67+ subscribeChannel : subscribeChannel ,
68+ configApplyRequestQueue : make (map [string ][]* mpi.ManagementPlaneRequest ),
69+ resource : & mpi.Resource {},
70+ requestsInProgress : make (map [string ]* mpi.ManagementPlaneRequest ),
6971 }
70-
71- return commandService
7272}
7373
7474func (cs * CommandService ) IsConnected () bool {
@@ -181,6 +181,11 @@ func (cs *CommandService) SendDataPlaneResponse(ctx context.Context, response *m
181181 return err
182182 }
183183
184+ if response .GetCommandResponse ().GetStatus () == mpi .CommandResponse_COMMAND_STATUS_OK ||
185+ response .GetCommandResponse ().GetStatus () == mpi .CommandResponse_COMMAND_STATUS_FAILURE {
186+ delete (cs .requestsInProgress , response .GetMessageMeta ().GetCorrelationId ())
187+ }
188+
184189 return backoff .Retry (
185190 cs .sendDataPlaneResponseCallback (ctx , response ),
186191 backoffHelpers .Context (backOffCtx , cs .agentConfig .Client .Backoff ),
@@ -272,6 +277,33 @@ func (cs *CommandService) CreateConnection(
272277}
273278
274279func (cs * CommandService ) UpdateClient (ctx context.Context , client mpi.CommandServiceClient ) error {
280+ cs .connectionResetInProgress .Store (true )
281+ defer cs .connectionResetInProgress .Store (false )
282+
283+ // Wait for any in-progress requests to complete before updating the client
284+ start := time .Now ()
285+
286+ for len (cs .requestsInProgress ) > 0 {
287+ if time .Since (start ) >= cs .agentConfig .Client .Grpc .ConnectionResetTimeout {
288+ slog .WarnContext (
289+ ctx ,
290+ "Timeout reached while waiting for in-progress requests to complete" ,
291+ "number_of_requests_in_progress" , len (cs .requestsInProgress ),
292+ )
293+
294+ break
295+ }
296+
297+ slog .InfoContext (
298+ ctx ,
299+ "Waiting for in-progress requests to complete before updating command service gRPC client" ,
300+ "max_wait_time" , cs .agentConfig .Client .Grpc .ConnectionResetTimeout ,
301+ "number_of_requests_in_progress" , len (cs .requestsInProgress ),
302+ )
303+
304+ time .Sleep (timeToWaitBetweenChecks )
305+ }
306+
275307 cs .subscribeClientMutex .Lock ()
276308 cs .commandServiceClient = client
277309 cs .subscribeClientMutex .Unlock ()
@@ -379,7 +411,7 @@ func (cs *CommandService) sendResponseForQueuedConfigApplyRequests(
379411 cs .configApplyRequestQueue [instanceID ] = cs .configApplyRequestQueue [instanceID ][indexOfConfigApplyRequest + 1 :]
380412 slog .DebugContext (ctx , "Removed config apply requests from queue" , "queue" , cs .configApplyRequestQueue [instanceID ])
381413
382- if len (cs .configApplyRequestQueue [instanceID ]) > 0 {
414+ if len (cs .configApplyRequestQueue [instanceID ]) > 0 && ! cs . connectionResetInProgress . Load () {
383415 cs .subscribeChannel <- cs .configApplyRequestQueue [instanceID ][len (cs .configApplyRequestQueue [instanceID ])- 1 ]
384416 }
385417
@@ -423,6 +455,12 @@ func (cs *CommandService) dataPlaneHealthCallback(
423455//nolint:revive // cognitive complexity is 18
424456func (cs * CommandService ) receiveCallback (ctx context.Context ) func () error {
425457 return func () error {
458+ if cs .connectionResetInProgress .Load () {
459+ slog .DebugContext (ctx , "Connection reset in progress, skipping receive from subscribe stream" )
460+
461+ return nil
462+ }
463+
426464 cs .subscribeClientMutex .Lock ()
427465
428466 if cs .subscribeClient == nil {
@@ -463,6 +501,8 @@ func (cs *CommandService) receiveCallback(ctx context.Context) func() error {
463501 default :
464502 cs .subscribeChannel <- request
465503 }
504+
505+ cs .requestsInProgress [request .GetMessageMeta ().GetCorrelationId ()] = request
466506 }
467507
468508 return nil
@@ -495,7 +535,7 @@ func (cs *CommandService) queueConfigApplyRequests(ctx context.Context, request
495535
496536 instanceID := request .GetConfigApplyRequest ().GetOverview ().GetConfigVersion ().GetInstanceId ()
497537 cs .configApplyRequestQueue [instanceID ] = append (cs .configApplyRequestQueue [instanceID ], request )
498- if len (cs .configApplyRequestQueue [instanceID ]) == 1 {
538+ if len (cs .configApplyRequestQueue [instanceID ]) == 1 && ! cs . connectionResetInProgress . Load () {
499539 cs .subscribeChannel <- request
500540 } else {
501541 slog .DebugContext (
0 commit comments