@@ -132,26 +132,36 @@ func (s *Server) SubmitTask(ctx context.Context, req *proto.SubmitTaskRequest) (
132132
133133// GetTaskStatus retrieves the status of a task by its ID
134134func (s * Server ) GetTaskStatus (ctx context.Context , req * proto.GetTaskStatusRequest ) (* proto.GetTaskStatusResponse , error ) {
135- task , err := s .findTask (req .TaskId )
136- if err != nil {
137- return nil , err
135+ s .tasksMux .RLock ()
136+ task , exists := s .tasks [req .TaskId ]
137+ if ! exists {
138+ s .tasksMux .RUnlock ()
139+ return nil , status .Errorf (codes .NotFound , "task %s not found" , req .TaskId )
138140 }
141+ taskStatus := task .Status
142+ s .tasksMux .RUnlock ()
139143
140- return & proto.GetTaskStatusResponse {Status : proto .TaskStatus (task . Status )}, nil
144+ return & proto.GetTaskStatusResponse {Status : proto .TaskStatus (taskStatus )}, nil
141145}
142146
143147// GetTaskResult retrieves the result of a completed task by its ID
144148func (s * Server ) GetTaskResult (ctx context.Context , req * proto.GetTaskResultRequest ) (* proto.GetTaskResultResponse , error ) {
145- task , err := s .findTask (req .TaskId )
146- if err != nil {
147- return nil , err
149+ s .tasksMux .RLock ()
150+ task , exists := s .tasks [req .TaskId ]
151+ if ! exists {
152+ s .tasksMux .RUnlock ()
153+ return nil , status .Errorf (codes .NotFound , "task %s not found" , req .TaskId )
148154 }
149155
150156 if task .Status != COMPLETED {
157+ s .tasksMux .RUnlock ()
151158 return nil , status .Errorf (codes .FailedPrecondition , "task %s not completed yet" , req .TaskId )
152159 }
153160
154- return & proto.GetTaskResultResponse {Result : task .Result }, nil
161+ result := task .Result
162+ s .tasksMux .RUnlock ()
163+
164+ return & proto.GetTaskResultResponse {Result : result }, nil
155165}
156166
157167// RegisterWorker handles worker registration requests
@@ -169,28 +179,33 @@ func (s *Server) RegisterWorker(ctx context.Context, req *proto.RegisterWorkerRe
169179
170180// Heartbeat processes heartbeat messages from workers
171181func (s * Server ) Heartbeat (ctx context.Context , req * proto.HeartbeatRequest ) (* proto.HeartbeatResponse , error ) {
172- worker , err := s .findWorker (req .WorkerId )
173- if err != nil {
174- return nil , err
182+ s .workersMux .Lock ()
183+ worker , exists := s .workers [req .WorkerId ]
184+ if ! exists {
185+ s .workersMux .Unlock ()
186+ return nil , status .Errorf (codes .NotFound , "worker %s not found" , req .WorkerId )
175187 }
176188
177189 worker .LastHeartbeat = time .Now ()
178190 worker .CurrentLoad = int (req .CurrentLoad )
191+ currentLoad := worker .CurrentLoad
192+ s .workersMux .Unlock ()
179193
180- return & proto.HeartbeatResponse {Success : true , CurrentLoad : int32 (worker . CurrentLoad )}, nil
194+ return & proto.HeartbeatResponse {Success : true , CurrentLoad : int32 (currentLoad )}, nil
181195}
182196
183197// SubmitResult processes the result submission from workers
184198func (s * Server ) SubmitResult (ctx context.Context , req * proto.SubmitResultRequest ) (* proto.SubmitResultResponse , error ) {
185- task , err := s .findTask (req .TaskId )
186- if err != nil {
187- return nil , err
199+ s .tasksMux .Lock ()
200+ task , exists := s .tasks [req .TaskId ]
201+ if ! exists {
202+ s .tasksMux .Unlock ()
203+ return nil , status .Errorf (codes .NotFound , "task %s not found" , req .TaskId )
188204 }
189205
190206 now := time .Now ()
191207 task .CompletedAt = & now
192-
193- defer s .decrementCurrentLoad (task .WorkerID )
208+ workerID := task .WorkerID
194209
195210 if req .Error != "" {
196211 task .Error = req .Error
@@ -200,30 +215,43 @@ func (s *Server) SubmitResult(ctx context.Context, req *proto.SubmitResultReques
200215 task .Status = PENDING
201216 task .StartedAt = nil
202217 task .CompletedAt = nil
218+ s .tasksMux .Unlock ()
203219
204220 s .appendTaskToQueue (task )
221+ s .decrementCurrentLoad (workerID )
222+
223+ return & proto.SubmitResultResponse {Success : true , Result : req .Result }, nil
205224 } else {
206225 task .Status = FAILED
226+ s .tasksMux .Unlock ()
207227
228+ s .decrementCurrentLoad (workerID )
208229 return nil , status .Errorf (codes .DeadlineExceeded , "task %s failed after maximum retries: %s" , req .TaskId , req .Error )
209230 }
210231 } else {
211232 task .Status = COMPLETED
212233 task .Result = req .Result
213- }
234+ s . tasksMux . Unlock ()
214235
215- return & proto.SubmitResultResponse {Success : true , Result : req .Result }, nil
236+ s .decrementCurrentLoad (workerID )
237+ return & proto.SubmitResultResponse {Success : true , Result : req .Result }, nil
238+ }
216239}
217240
218241func (s * Server ) FetchTask (ctx context.Context , req * proto.FetchTaskRequest ) (* proto.FetchTaskResponse , error ) {
219- worker , err := s .findWorker (req .WorkerId )
220- if err != nil {
221- return nil , err
242+ s .workersMux .RLock ()
243+ worker , exists := s .workers [req .WorkerId ]
244+ if ! exists {
245+ s .workersMux .RUnlock ()
246+ return nil , status .Errorf (codes .NotFound , "worker %s not found" , req .WorkerId )
222247 }
223248
224249 if worker .CurrentLoad >= worker .Capacity {
250+ s .workersMux .RUnlock ()
225251 return & proto.FetchTaskResponse {HasTask : false }, nil
226252 }
253+ workerID := worker .ID
254+ s .workersMux .RUnlock ()
227255
228256 s .queuesMux .Lock ()
229257 defer s .queuesMux .Unlock ()
@@ -241,10 +269,10 @@ func (s *Server) FetchTask(ctx context.Context, req *proto.FetchTaskRequest) (*p
241269 s .tasksMux .Lock ()
242270 task .Status = RUNNING
243271 task .StartedAt = & now
244- task .WorkerID = worker . ID
272+ task .WorkerID = workerID
245273 s .tasksMux .Unlock ()
246274
247- s .incrementCurrentLoad (worker . ID )
275+ s .incrementCurrentLoad (workerID )
248276
249277 return & proto.FetchTaskResponse {Task : task .toProtoTask (), HasTask : true }, nil
250278 }
0 commit comments