Skip to content

Commit 5fc6ce1

Browse files
committed
Encapsulate worker's tasks map
To make sure that access to it will be provided only by thread-safe worker's methods
1 parent 01f8cb4 commit 5fc6ce1

5 files changed

Lines changed: 43 additions & 51 deletions

File tree

cmd/worker/main.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
11
package main
22

33
import (
4-
"dirigeant/task"
54
"dirigeant/worker"
65
"net/http"
76

87
"github.com/go-chi/chi/v5"
98
"github.com/go-chi/chi/v5/middleware"
10-
"github.com/google/uuid"
119
)
1210

1311
func main() {
1412
api := &worker.Api{
15-
Worker: &worker.Worker{
16-
Tasks: make(map[uuid.UUID]*task.Task),
17-
},
13+
Worker: worker.NewWorker(),
1814
}
1915
r := chi.NewRouter()
2016

tests/worker/start_task_test.go

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,12 @@ import (
1212
"testing"
1313
"time"
1414

15-
"github.com/google/uuid"
1615
"github.com/stretchr/testify/assert"
1716
)
1817

1918
func TestStartTask__ShouldPersistTask(t *testing.T) {
2019
api := &worker.Api{
21-
Worker: &worker.Worker{
22-
Tasks: make(map[uuid.UUID]*task.Task),
23-
},
20+
Worker: worker.NewWorker(),
2421
}
2522
testTask := helper.PrintFileTask("print-task", helper.HostsFilePath)
2623
request := helper.NewTaskPostRequest(testTask)
@@ -30,15 +27,13 @@ func TestStartTask__ShouldPersistTask(t *testing.T) {
3027

3128
assert.Equal(t, http.StatusCreated, responseRecorder.Code, "Response status code should be 201 Created")
3229
assert.Empty(t, responseRecorder.Body, "Response body should be empty")
33-
assert.Equal(t, 1, len(api.Worker.Tasks), "Tasks map should contain 1 task")
34-
assert.NotNil(t, api.Worker.Tasks[testTask.ID], "Persisted task ID should match the one from request")
30+
assert.Equal(t, 1, api.Worker.LenTasks(), "Tasks map should contain 1 task")
31+
assert.NotNil(t, api.Worker.GetTask(testTask.ID), "Persisted task ID should match the one from request")
3532
}
3633

3734
func TestStartTask__ShouldReturnAnErrorIfCreatingTheSameTaskTwice(t *testing.T) {
3835
api := &worker.Api{
39-
Worker: &worker.Worker{
40-
Tasks: make(map[uuid.UUID]*task.Task),
41-
},
36+
Worker: worker.NewWorker(),
4237
}
4338
testTask := helper.PrintFileTask("print-task", helper.HostsFilePath)
4439

@@ -50,8 +45,8 @@ func TestStartTask__ShouldReturnAnErrorIfCreatingTheSameTaskTwice(t *testing.T)
5045

5146
assert.Equal(t, http.StatusCreated, firstResponseRecorder.Code, "Response status code should be 201 Created")
5247
assert.Empty(t, firstResponseRecorder.Body, "Response body should be empty")
53-
assert.Equal(t, 1, len(api.Worker.Tasks), "Tasks map should contain 1 task")
54-
assert.NotNil(t, api.Worker.Tasks[testTask.ID], "Persisted task ID should match the one from request")
48+
assert.Equal(t, 1, api.Worker.LenTasks(), "Tasks map should contain 1 task")
49+
assert.NotNil(t, api.Worker.GetTask(testTask.ID), "Persisted task ID should match the one from request")
5550

5651
// 2 - Create the same task for the second time
5752
secondRequest := helper.NewTaskPostRequest(testTask)
@@ -61,15 +56,13 @@ func TestStartTask__ShouldReturnAnErrorIfCreatingTheSameTaskTwice(t *testing.T)
6156

6257
assert.Equal(t, http.StatusConflict, secondResponseRecorder.Code, "Response status code should be 409 Conflict")
6358
assert.Equal(t, fmt.Sprintf("Error when executing the task: %s", task.ErrAlreadyExists), secondResponseRecorder.Body.String(), "Response body should contain error message")
64-
assert.Equal(t, 1, len(api.Worker.Tasks), "Tasks map should contain 1 task")
65-
assert.NotNil(t, api.Worker.Tasks[testTask.ID], "Persisted task ID should match the one from request")
59+
assert.Equal(t, 1, api.Worker.LenTasks(), "Tasks map should contain 1 task")
60+
assert.NotNil(t, api.Worker.GetTask(testTask.ID), "Persisted task ID should match the one from request")
6661
}
6762

6863
func TestStartTask__AllButOneRequestsShouldFailIfCreatingTheSameTaskSimultaneously(t *testing.T) {
6964
api := &worker.Api{
70-
Worker: &worker.Worker{
71-
Tasks: make(map[uuid.UUID]*task.Task),
72-
},
65+
Worker: worker.NewWorker(),
7366
}
7467
testTask := helper.PrintFileTask("print-task", helper.HostsFilePath)
7568
numOfRequests := 10
@@ -104,15 +97,13 @@ func TestStartTask__AllButOneRequestsShouldFailIfCreatingTheSameTaskSimultaneous
10497

10598
assert.Equal(t, 1, succeededRequests, "There should be only 1 succeeded request")
10699
assert.Equal(t, numOfRequests-1, conflictedRequests, "There should be only N-1 conflicted requests")
107-
assert.Equal(t, 1, len(api.Worker.Tasks), "Tasks map should contain 1 task")
108-
assert.NotNil(t, api.Worker.Tasks[testTask.ID], "Persisted task ID should match the one from request")
100+
assert.Equal(t, 1, api.Worker.LenTasks(), "Tasks map should contain 1 task")
101+
assert.NotNil(t, api.Worker.GetTask(testTask.ID), "Persisted task ID should match the one from request")
109102
}
110103

111104
func TestStartTask__ShouldHandleClientClosedRequest(t *testing.T) {
112105
api := &worker.Api{
113-
Worker: &worker.Worker{
114-
Tasks: make(map[uuid.UUID]*task.Task),
115-
},
106+
Worker: worker.NewWorker(),
116107
}
117108
testTask := helper.PingTask("ping-task", "127.0.0.1")
118109
ctx, cancel := context.WithCancel(context.TODO())
@@ -133,12 +124,12 @@ func TestStartTask__ShouldHandleClientClosedRequest(t *testing.T) {
133124
assert.Equal(t, 499, createResponseRecorder.Code, "Response status code should be 499 Client Closed Request")
134125
assert.Equal(t, "Error when executing the task: client closed request", createResponseRecorder.Body.String(), "Response body should contain error message")
135126
assert.NotEmpty(t, stdout, "Task logs shouldn't be empty")
136-
assert.Empty(t, api.Worker.Tasks, "Tasks map should be empty")
127+
assert.Zero(t, api.Worker.LenTasks(), "Tasks map should be empty")
137128
}()
138129

139130
time.Sleep(1 * time.Second)
140131

141-
assert.Equal(t, 1, len(api.Worker.Tasks), "Tasks map should contain 1 task")
132+
assert.Equal(t, 1, api.Worker.LenTasks(), "Tasks map should contain 1 task")
142133

143134
// 2 - Cancel a request
144135
cancel()

tests/worker/stop_task_test.go

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@ func TestStopTask__ShouldReturnAnErrorIfNotFound(t *testing.T) {
3030

3131
func TestStopTask__ShouldStopCompletedTask(t *testing.T) {
3232
api := &worker.Api{
33-
Worker: &worker.Worker{
34-
Tasks: make(map[uuid.UUID]*task.Task),
35-
},
33+
Worker: worker.NewWorker(),
3634
}
3735
testTask := helper.PrintFileTask("print-task", helper.HostsFilePath)
3836

@@ -43,7 +41,7 @@ func TestStopTask__ShouldStopCompletedTask(t *testing.T) {
4341

4442
assert.Equal(t, http.StatusCreated, responseRecorder.Code, "Response status code should be 201 Created")
4543
assert.Empty(t, responseRecorder.Body, "Response body should be empty")
46-
assert.Equal(t, 1, len(api.Worker.Tasks), "Tasks map should contain 1 task")
44+
assert.Equal(t, 1, api.Worker.LenTasks(), "Tasks map should contain 1 task")
4745

4846
// 2 - Delete a task
4947
request = helper.NewTaskDeleteRequest(testTask.ID)
@@ -53,14 +51,12 @@ func TestStopTask__ShouldStopCompletedTask(t *testing.T) {
5351

5452
assert.Equal(t, http.StatusNoContent, responseRecorder.Code, "Response status code should be 204 No Content")
5553
assert.Empty(t, responseRecorder.Body, "Response body should be empty")
56-
assert.Empty(t, api.Worker.Tasks, "Tasks map should be empty")
54+
assert.Zero(t, api.Worker.LenTasks(), "Tasks map should be empty")
5755
}
5856

5957
func TestStopTask__ShouldStopRunningTask(t *testing.T) {
6058
api := &worker.Api{
61-
Worker: &worker.Worker{
62-
Tasks: make(map[uuid.UUID]*task.Task),
63-
},
59+
Worker: worker.NewWorker(),
6460
}
6561
testTask := helper.PingTask("ping-task", "127.0.0.1")
6662

@@ -77,12 +73,12 @@ func TestStopTask__ShouldStopRunningTask(t *testing.T) {
7773

7874
assert.Equal(t, http.StatusInternalServerError, createResponseRecorder.Code, "Response status code should be 500 Internal Server Error")
7975
assert.Equal(t, fmt.Sprintf("Error when executing the task: %s", helper.SignalKilledErrMessage), createResponseRecorder.Body.String(), "Response body should contain error message")
80-
assert.Empty(t, api.Worker.Tasks, "Tasks map should be empty")
76+
assert.Zero(t, api.Worker.LenTasks(), "Tasks map should be empty")
8177
}()
8278

8379
time.Sleep(1 * time.Second)
8480

85-
assert.Equal(t, 1, len(api.Worker.Tasks), "Tasks map should contain 1 task")
81+
assert.Equal(t, 1, api.Worker.LenTasks(), "Tasks map should contain 1 task")
8682

8783
// 2 - Delete a task
8884
deleteRequest := helper.NewTaskDeleteRequest(testTask.ID)
@@ -92,7 +88,7 @@ func TestStopTask__ShouldStopRunningTask(t *testing.T) {
9288

9389
assert.Equal(t, http.StatusNoContent, deleteResponseRecorder.Code, "Response status code should be 204 No Content")
9490
assert.Empty(t, deleteResponseRecorder.Body, "Response body should be empty")
95-
assert.Empty(t, api.Worker.Tasks, "Tasks map should be empty")
91+
assert.Zero(t, api.Worker.LenTasks(), "Tasks map should be empty")
9692

9793
wg.Wait()
9894
}

tests/worker/task_logs_test.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
package worker
22

33
import (
4-
"dirigeant/task"
54
"dirigeant/tests/helper"
65
"dirigeant/worker"
76
"fmt"
87
"net/http"
98
"net/http/httptest"
109
"testing"
1110

12-
"github.com/google/uuid"
1311
"github.com/stretchr/testify/assert"
1412
)
1513

@@ -40,9 +38,7 @@ func TestTaskLogs__PrintFile(t *testing.T) {
4038
for _, tc := range tcs {
4139
t.Run(tc.name, func(t *testing.T) {
4240
api := &worker.Api{
43-
Worker: &worker.Worker{
44-
Tasks: make(map[uuid.UUID]*task.Task),
45-
},
41+
Worker: worker.NewWorker(),
4642
}
4743
request := helper.NewTaskPostRequest(helper.PrintFileTask(tc.name, tc.path))
4844
responseRecorder := httptest.NewRecorder()

worker/worker.go

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,46 @@ import (
1414
type Worker struct {
1515
sync.RWMutex
1616

17-
Tasks map[uuid.UUID]*task.Task
17+
tasks map[uuid.UUID]*task.Task
18+
}
19+
20+
func NewWorker() *Worker {
21+
return &Worker{
22+
tasks: make(map[uuid.UUID]*task.Task),
23+
}
24+
}
25+
26+
func (w *Worker) LenTasks() int {
27+
w.RLock()
28+
defer w.RUnlock()
29+
30+
return len(w.tasks)
1831
}
1932

2033
func (w *Worker) ListTasks() iter.Seq[*task.Task] {
2134
w.RLock()
2235
defer w.RUnlock()
2336

24-
return maps.Values(w.Tasks)
37+
return maps.Values(w.tasks)
2538
}
2639

2740
func (w *Worker) GetTask(id uuid.UUID) *task.Task {
2841
w.RLock()
2942
defer w.RUnlock()
3043

31-
return w.Tasks[id]
44+
return w.tasks[id]
3245
}
3346

3447
func (w *Worker) StartTask(t task.Task) error {
3548
w.Lock()
3649

37-
if _, ok := w.Tasks[t.ID]; ok {
50+
if _, ok := w.tasks[t.ID]; ok {
3851
w.Unlock()
3952
return task.ErrAlreadyExists
4053
}
4154

4255
t.Cmd = exec.Command(t.Executable, t.Args...)
43-
w.Tasks[t.ID] = &t
56+
w.tasks[t.ID] = &t
4457

4558
w.Unlock()
4659

@@ -57,7 +70,7 @@ func (w *Worker) StopTask(id uuid.UUID) error {
5770
w.Lock()
5871
defer w.Unlock()
5972

60-
t := w.Tasks[id]
73+
t := w.tasks[id]
6174
if t == nil {
6275
return task.ErrNotExists
6376
}
@@ -68,7 +81,7 @@ func (w *Worker) StopTask(id uuid.UUID) error {
6881
}
6982
}
7083

71-
delete(w.Tasks, t.ID)
84+
delete(w.tasks, t.ID)
7285

7386
return nil
7487
}

0 commit comments

Comments
 (0)