Skip to content

Commit 9786ecd

Browse files
committed
feat(worker): implement worker + tests
1 parent 4f52777 commit 9786ecd

2 files changed

Lines changed: 730 additions & 25 deletions

File tree

internal/worker/worker.go

Lines changed: 224 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,239 @@
11
package worker
22

33
import (
4+
"context"
5+
"log"
6+
"sync"
47
"time"
58

6-
"github.com/google/uuid"
79
"github.com/mateusmlo/taskqueue/proto"
8-
"google.golang.org/grpc/codes"
9-
"google.golang.org/grpc/status"
10+
"google.golang.org/grpc"
11+
"google.golang.org/grpc/credentials"
1012
)
1113

1214
type Worker struct {
13-
ID string
14-
Address string
15-
RegisteredAt time.Time
16-
LastHeartbeat time.Time
17-
TaskTypes []string
18-
Capacity int
19-
CurrentLoad int
20-
Metadata map[string]string
21-
}
22-
23-
// FromProtoWorker initializes a Worker instance from a proto.Worker message (server generates ID)
24-
func (w *Worker) FromProtoWorker(pw *proto.Worker) error {
25-
uuid, err := uuid.NewV7()
15+
serverAddr string
16+
conn *grpc.ClientConn
17+
client proto.WorkerServiceClient
18+
19+
id string
20+
capacity int
21+
22+
handlers map[string]TaskHandler
23+
currentLoad int
24+
loadMux sync.RWMutex
25+
26+
ctx context.Context
27+
cancel context.CancelFunc
28+
wg sync.WaitGroup
29+
}
30+
31+
type TaskHandler interface {
32+
Handle(ctx context.Context, payload []byte) ([]byte, error)
33+
}
34+
35+
func NewWorker(serverAddr string, capacity int) *Worker {
36+
ctx, cancel := context.WithCancel(context.Background())
37+
38+
return &Worker{
39+
serverAddr: serverAddr,
40+
capacity: capacity,
41+
handlers: make(map[string]TaskHandler),
42+
ctx: ctx,
43+
cancel: cancel,
44+
}
45+
}
46+
47+
func (w *Worker) RegisterHandler(taskType string, handler TaskHandler) {
48+
w.handlers[taskType] = handler
49+
}
50+
51+
func (w *Worker) Start() error {
52+
tcr, err := credentials.NewClientTLSFromFile("./cert/server.crt", "localhost")
2653
if err != nil {
27-
return status.Errorf(codes.Internal, "failed to generate worker UUID: %v", err)
54+
return err
55+
}
56+
57+
conn, err := grpc.NewClient(w.serverAddr, grpc.WithTransportCredentials(tcr))
58+
if err != nil {
59+
return err
60+
}
61+
62+
w.conn = conn
63+
w.client = proto.NewWorkerServiceClient(w.conn)
64+
65+
if err := w.register(); err != nil {
66+
w.conn.Close()
67+
return err
68+
}
69+
70+
w.wg.Add(2)
71+
go w.heartbeatLoop()
72+
go w.fetchLoop()
73+
74+
return nil
75+
}
76+
77+
func (w *Worker) Stop() {
78+
w.cancel()
79+
w.wg.Wait()
80+
81+
if w.conn != nil {
82+
if err := w.conn.Close(); err != nil {
83+
log.Printf("Error closing gRPC connection: %v", err)
84+
}
85+
86+
w.conn = nil
87+
}
88+
}
89+
90+
func (w *Worker) heartbeatLoop() {
91+
defer w.wg.Done()
92+
93+
ticker := time.NewTicker(10 * time.Second)
94+
defer ticker.Stop()
95+
96+
for {
97+
select {
98+
case <-ticker.C:
99+
req := w.buildHeartbeatRequest()
100+
_, err := w.client.Heartbeat(w.ctx, req)
101+
if err != nil {
102+
log.Printf("Worker heartbeat error: %v", err)
103+
}
104+
case <-w.ctx.Done():
105+
return
106+
}
107+
}
108+
}
109+
110+
func (w *Worker) fetchLoop() {
111+
defer w.wg.Done()
112+
113+
ticker := time.NewTicker(2 * time.Second)
114+
defer ticker.Stop()
115+
116+
for {
117+
select {
118+
case <-ticker.C:
119+
req := w.buildFetchTasksRequest()
120+
res, err := w.client.FetchTask(w.ctx, req)
121+
if err != nil {
122+
log.Printf("Worker fetch task error: %v", err)
123+
continue
124+
}
125+
126+
if !res.HasTask {
127+
continue
128+
}
129+
130+
handler, exists := w.handlers[res.Task.Type]
131+
if !exists {
132+
log.Printf("No handler registered for task type: %s", res.Task.Type)
133+
continue
134+
}
135+
136+
w.incrementLoad()
137+
138+
handleTask := w.getTaskHandler(handler)
139+
140+
go handleTask(res.Task)
141+
case <-w.ctx.Done():
142+
return
143+
}
28144
}
145+
}
146+
147+
func (w *Worker) getTaskHandler(handler TaskHandler) func(task *proto.Task) {
148+
return func(task *proto.Task) {
149+
defer w.decrementLoad()
150+
151+
result, err := handler.Handle(w.ctx, task.Payload)
152+
submitReq := &proto.SubmitResultRequest{
153+
TaskId: task.Id,
154+
}
155+
if err != nil {
156+
submitReq.Error = err.Error()
157+
submitReq.Result = nil
158+
} else {
159+
submitReq.Error = ""
160+
submitReq.Result = result
161+
}
29162

30-
w.ID = uuid.String()
31-
w.TaskTypes = pw.TaskTypes
32-
w.Address = pw.Metadata["address"]
33-
w.Capacity = int(pw.Capacity)
34-
w.CurrentLoad = 0
35-
w.Metadata = pw.Metadata
36-
w.RegisteredAt = time.Now()
37-
w.LastHeartbeat = time.Now()
163+
_, err = w.client.SubmitResult(w.ctx, submitReq)
164+
if err != nil {
165+
log.Printf("Error submitting task result: %v", err)
166+
}
167+
}
168+
}
38169

170+
func (w *Worker) getCurrentLoad() int32 {
171+
w.loadMux.RLock()
172+
defer w.loadMux.RUnlock()
173+
174+
return int32(w.currentLoad)
175+
}
176+
177+
func (w *Worker) incrementLoad() {
178+
w.loadMux.Lock()
179+
defer w.loadMux.Unlock()
180+
181+
w.currentLoad++
182+
}
183+
184+
func (w *Worker) decrementLoad() {
185+
w.loadMux.Lock()
186+
defer w.loadMux.Unlock()
187+
188+
if w.currentLoad > 0 {
189+
w.currentLoad--
190+
}
191+
}
192+
193+
func (w *Worker) register() error {
194+
req := w.buildRegisterRequest()
195+
196+
res, err := w.client.RegisterWorker(w.ctx, req)
197+
if err != nil {
198+
return err
199+
}
200+
201+
w.id = res.WorkerId
39202
return nil
40203
}
204+
205+
func (w *Worker) buildRegisterRequest() *proto.RegisterWorkerRequest {
206+
taskTypes := make([]string, 0, len(w.handlers))
207+
for taskType := range w.handlers {
208+
taskTypes = append(taskTypes, taskType)
209+
}
210+
211+
return &proto.RegisterWorkerRequest{
212+
Worker: &proto.Worker{
213+
TaskTypes: taskTypes,
214+
Capacity: int32(w.capacity),
215+
Metadata: map[string]string{
216+
"address": w.serverAddr,
217+
},
218+
},
219+
}
220+
}
221+
222+
func (w *Worker) buildFetchTasksRequest() *proto.FetchTaskRequest {
223+
taskTypes := make([]string, 0, len(w.handlers))
224+
for taskType := range w.handlers {
225+
taskTypes = append(taskTypes, taskType)
226+
}
227+
228+
return &proto.FetchTaskRequest{
229+
WorkerId: w.id,
230+
TaskTypes: taskTypes,
231+
}
232+
}
233+
234+
func (w *Worker) buildHeartbeatRequest() *proto.HeartbeatRequest {
235+
return &proto.HeartbeatRequest{
236+
WorkerId: w.id,
237+
CurrentLoad: w.getCurrentLoad(),
238+
}
239+
}

0 commit comments

Comments
 (0)