Skip to content

Commit 5ca061f

Browse files
committed
Make worker thread-safe
1 parent 9d9917c commit 5ca061f

2 files changed

Lines changed: 62 additions & 1 deletion

File tree

tests/worker/start_task_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"net/http"
99
"net/http/httptest"
10+
"sync"
1011
"testing"
1112

1213
"github.com/google/uuid"
@@ -61,3 +62,46 @@ func TestStartTask__ShouldReturnAnErrorIfCreatingTheSameTaskTwice(t *testing.T)
6162
assert.Equal(t, 1, len(api.Worker.Tasks), "Tasks map should contain 1 task")
6263
assert.NotNil(t, api.Worker.Tasks[testTask.ID], "Persisted task ID should match the one from request")
6364
}
65+
66+
func TestStartTask__AllButOneRequestsShouldFailIfCreatingTheSameTaskSimultaneously(t *testing.T) {
67+
api := &worker.Api{
68+
Worker: &worker.Worker{
69+
Tasks: make(map[uuid.UUID]*task.Task),
70+
},
71+
}
72+
testTask := helper.PrintFileTask("print-task", helper.HostsFilePath)
73+
numOfRequests := 10
74+
requests := make([]*http.Request, numOfRequests)
75+
responseRecorders := make([]*httptest.ResponseRecorder, numOfRequests)
76+
77+
var wg sync.WaitGroup
78+
for i := range numOfRequests {
79+
wg.Add(1)
80+
81+
requests[i] = helper.NewTaskPostRequest(testTask)
82+
responseRecorders[i] = httptest.NewRecorder()
83+
84+
go func() {
85+
defer wg.Done()
86+
87+
api.HandleCreateTask(responseRecorders[i], requests[i])
88+
}()
89+
}
90+
91+
wg.Wait()
92+
93+
succeededRequests, conflictedRequests := 0, 0
94+
for i := range numOfRequests {
95+
switch responseRecorders[i].Code {
96+
case http.StatusCreated:
97+
succeededRequests++
98+
case http.StatusConflict:
99+
conflictedRequests++
100+
}
101+
}
102+
103+
assert.Equal(t, 1, succeededRequests, "There should be only 1 succeeded request")
104+
assert.Equal(t, numOfRequests-1, conflictedRequests, "There should be only N-1 conflicted requests")
105+
assert.Equal(t, 1, len(api.Worker.Tasks), "Tasks map should contain 1 task")
106+
assert.NotNil(t, api.Worker.Tasks[testTask.ID], "Persisted task ID should match the one from request")
107+
}

worker/worker.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,44 @@ import (
66
"maps"
77
"os"
88
"os/exec"
9+
"sync"
910

1011
"github.com/google/uuid"
1112
)
1213

1314
type Worker struct {
15+
sync.RWMutex
16+
1417
Tasks map[uuid.UUID]*task.Task
1518
}
1619

1720
func (w *Worker) ListTasks() iter.Seq[*task.Task] {
21+
w.RLock()
22+
defer w.RUnlock()
23+
1824
return maps.Values(w.Tasks)
1925
}
2026

2127
func (w *Worker) GetTask(id uuid.UUID) *task.Task {
28+
w.RLock()
29+
defer w.RUnlock()
30+
2231
return w.Tasks[id]
2332
}
2433

2534
func (w *Worker) StartTask(t task.Task) error {
35+
w.Lock()
36+
2637
if _, ok := w.Tasks[t.ID]; ok {
38+
w.Unlock()
2739
return task.ErrAlreadyExists
2840
}
2941

3042
t.Cmd = exec.Command(t.Executable, t.Args...)
3143
w.Tasks[t.ID] = &t
3244

45+
w.Unlock()
46+
3347
stdout, err := t.Cmd.CombinedOutput()
3448
os.Stdout.Write(stdout)
3549
if err != nil {
@@ -40,7 +54,10 @@ func (w *Worker) StartTask(t task.Task) error {
4054
}
4155

4256
func (w *Worker) StopTask(id uuid.UUID) error {
43-
t := w.GetTask(id)
57+
w.Lock()
58+
defer w.Unlock()
59+
60+
t := w.Tasks[id]
4461
if t == nil {
4562
return task.ErrNotExists
4663
}

0 commit comments

Comments
 (0)