Skip to content

Commit 7673626

Browse files
committed
fix(server): prevents race condition w/ maps
1 parent e6c5776 commit 7673626

1 file changed

Lines changed: 52 additions & 24 deletions

File tree

internal/server/server.go

Lines changed: 52 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -132,26 +132,36 @@ func (s *Server) SubmitTask(ctx context.Context, req *proto.SubmitTaskRequest) (
132132

133133
// GetTaskStatus retrieves the status of a task by its ID
134134
func (s *Server) GetTaskStatus(ctx context.Context, req *proto.GetTaskStatusRequest) (*proto.GetTaskStatusResponse, error) {
135-
task, err := s.findTask(req.TaskId)
136-
if err != nil {
137-
return nil, err
135+
s.tasksMux.RLock()
136+
task, exists := s.tasks[req.TaskId]
137+
if !exists {
138+
s.tasksMux.RUnlock()
139+
return nil, status.Errorf(codes.NotFound, "task %s not found", req.TaskId)
138140
}
141+
taskStatus := task.Status
142+
s.tasksMux.RUnlock()
139143

140-
return &proto.GetTaskStatusResponse{Status: proto.TaskStatus(task.Status)}, nil
144+
return &proto.GetTaskStatusResponse{Status: proto.TaskStatus(taskStatus)}, nil
141145
}
142146

143147
// GetTaskResult retrieves the result of a completed task by its ID
144148
func (s *Server) GetTaskResult(ctx context.Context, req *proto.GetTaskResultRequest) (*proto.GetTaskResultResponse, error) {
145-
task, err := s.findTask(req.TaskId)
146-
if err != nil {
147-
return nil, err
149+
s.tasksMux.RLock()
150+
task, exists := s.tasks[req.TaskId]
151+
if !exists {
152+
s.tasksMux.RUnlock()
153+
return nil, status.Errorf(codes.NotFound, "task %s not found", req.TaskId)
148154
}
149155

150156
if task.Status != COMPLETED {
157+
s.tasksMux.RUnlock()
151158
return nil, status.Errorf(codes.FailedPrecondition, "task %s not completed yet", req.TaskId)
152159
}
153160

154-
return &proto.GetTaskResultResponse{Result: task.Result}, nil
161+
result := task.Result
162+
s.tasksMux.RUnlock()
163+
164+
return &proto.GetTaskResultResponse{Result: result}, nil
155165
}
156166

157167
// RegisterWorker handles worker registration requests
@@ -169,28 +179,33 @@ func (s *Server) RegisterWorker(ctx context.Context, req *proto.RegisterWorkerRe
169179

170180
// Heartbeat processes heartbeat messages from workers
171181
func (s *Server) Heartbeat(ctx context.Context, req *proto.HeartbeatRequest) (*proto.HeartbeatResponse, error) {
172-
worker, err := s.findWorker(req.WorkerId)
173-
if err != nil {
174-
return nil, err
182+
s.workersMux.Lock()
183+
worker, exists := s.workers[req.WorkerId]
184+
if !exists {
185+
s.workersMux.Unlock()
186+
return nil, status.Errorf(codes.NotFound, "worker %s not found", req.WorkerId)
175187
}
176188

177189
worker.LastHeartbeat = time.Now()
178190
worker.CurrentLoad = int(req.CurrentLoad)
191+
currentLoad := worker.CurrentLoad
192+
s.workersMux.Unlock()
179193

180-
return &proto.HeartbeatResponse{Success: true, CurrentLoad: int32(worker.CurrentLoad)}, nil
194+
return &proto.HeartbeatResponse{Success: true, CurrentLoad: int32(currentLoad)}, nil
181195
}
182196

183197
// SubmitResult processes the result submission from workers
184198
func (s *Server) SubmitResult(ctx context.Context, req *proto.SubmitResultRequest) (*proto.SubmitResultResponse, error) {
185-
task, err := s.findTask(req.TaskId)
186-
if err != nil {
187-
return nil, err
199+
s.tasksMux.Lock()
200+
task, exists := s.tasks[req.TaskId]
201+
if !exists {
202+
s.tasksMux.Unlock()
203+
return nil, status.Errorf(codes.NotFound, "task %s not found", req.TaskId)
188204
}
189205

190206
now := time.Now()
191207
task.CompletedAt = &now
192-
193-
defer s.decrementCurrentLoad(task.WorkerID)
208+
workerID := task.WorkerID
194209

195210
if req.Error != "" {
196211
task.Error = req.Error
@@ -200,30 +215,43 @@ func (s *Server) SubmitResult(ctx context.Context, req *proto.SubmitResultReques
200215
task.Status = PENDING
201216
task.StartedAt = nil
202217
task.CompletedAt = nil
218+
s.tasksMux.Unlock()
203219

204220
s.appendTaskToQueue(task)
221+
s.decrementCurrentLoad(workerID)
222+
223+
return &proto.SubmitResultResponse{Success: true, Result: req.Result}, nil
205224
} else {
206225
task.Status = FAILED
226+
s.tasksMux.Unlock()
207227

228+
s.decrementCurrentLoad(workerID)
208229
return nil, status.Errorf(codes.DeadlineExceeded, "task %s failed after maximum retries: %s", req.TaskId, req.Error)
209230
}
210231
} else {
211232
task.Status = COMPLETED
212233
task.Result = req.Result
213-
}
234+
s.tasksMux.Unlock()
214235

215-
return &proto.SubmitResultResponse{Success: true, Result: req.Result}, nil
236+
s.decrementCurrentLoad(workerID)
237+
return &proto.SubmitResultResponse{Success: true, Result: req.Result}, nil
238+
}
216239
}
217240

218241
func (s *Server) FetchTask(ctx context.Context, req *proto.FetchTaskRequest) (*proto.FetchTaskResponse, error) {
219-
worker, err := s.findWorker(req.WorkerId)
220-
if err != nil {
221-
return nil, err
242+
s.workersMux.RLock()
243+
worker, exists := s.workers[req.WorkerId]
244+
if !exists {
245+
s.workersMux.RUnlock()
246+
return nil, status.Errorf(codes.NotFound, "worker %s not found", req.WorkerId)
222247
}
223248

224249
if worker.CurrentLoad >= worker.Capacity {
250+
s.workersMux.RUnlock()
225251
return &proto.FetchTaskResponse{HasTask: false}, nil
226252
}
253+
workerID := worker.ID
254+
s.workersMux.RUnlock()
227255

228256
s.queuesMux.Lock()
229257
defer s.queuesMux.Unlock()
@@ -241,10 +269,10 @@ func (s *Server) FetchTask(ctx context.Context, req *proto.FetchTaskRequest) (*p
241269
s.tasksMux.Lock()
242270
task.Status = RUNNING
243271
task.StartedAt = &now
244-
task.WorkerID = worker.ID
272+
task.WorkerID = workerID
245273
s.tasksMux.Unlock()
246274

247-
s.incrementCurrentLoad(worker.ID)
275+
s.incrementCurrentLoad(workerID)
248276

249277
return &proto.FetchTaskResponse{Task: task.toProtoTask(), HasTask: true}, nil
250278
}

0 commit comments

Comments
 (0)