Skip to content

Commit aeb1249

Browse files
committed
Fix concurrency issue
1 parent c0f9a11 commit aeb1249

7 files changed

Lines changed: 295 additions & 200 deletions

File tree

pkg/eventconsumer/event_consumer.go

Lines changed: 106 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/fystack/mpcium/pkg/messaging"
1717
"github.com/fystack/mpcium/pkg/monitoring"
1818
"github.com/fystack/mpcium/pkg/mpc/node"
19+
"github.com/fystack/mpcium/pkg/mpc/session"
1920
"github.com/fystack/mpcium/pkg/tsslimiter"
2021
"github.com/fystack/mpcium/pkg/types"
2122
"github.com/nats-io/nats.go"
@@ -29,10 +30,10 @@ const (
2930

3031
// Default version for keygen
3132
DefaultVersion int = 1
32-
SessionTimeout = 1 * time.Minute
33+
SessionTimeout = 15 * time.Second
3334
MaxConcurrentSessions = 5
3435
// how long the entire handler will wait for *all* sessions + publishing:
35-
HandlerTimeout = 2 * time.Minute
36+
HandlerTimeout = 20 * time.Second
3637
)
3738

3839
type EventConsumer interface {
@@ -77,7 +78,8 @@ func NewEventConsumer(
7778
identityStore identity.Store,
7879
) EventConsumer {
7980
limiter := tsslimiter.NewWeightedLimiter(concurrency.GetTSSConcurrencyLimit())
80-
limiterQueue := tsslimiter.NewWeightedQueue(limiter, 100)
81+
bufferSize := 100
82+
limiterQueue := tsslimiter.NewWeightedQueue(limiter, bufferSize)
8183

8284
ec := &eventConsumer{
8385
node: node,
@@ -122,16 +124,17 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error {
122124
// Create session limiter channel with capacity 5
123125
sub, err := ec.pubsub.Subscribe(MPCGenerateEvent, func(natMsg *nats.Msg) {
124126
logger.Info("Received key generation event", "subject", natMsg.Subject)
125-
// This blocks if max sessions are already running
126-
// go func(data []byte) {
127-
go func(data []byte) {
128-
// Ack the message immediately to prevent redelivery from JetStream. This is critical.
129-
130-
if err := ec.handleKeyGenerationEvent(context.Background(), data); err != nil {
127+
job := tsslimiter.SessionJob{
128+
Type: tsslimiter.SessionKeygenCombined,
129+
Run: func() error {
130+
return ec.handleKeyGenerationEvent(context.Background(), natMsg)
131+
},
132+
OnError: func(err error) {
131133
logger.Error("Failed to handle key generation event", err)
132-
}
133-
}(natMsg.Data)
134-
// }(natMsg.Data)
134+
},
135+
Name: fmt.Sprintf("keygen-%s", string(natMsg.Data)),
136+
}
137+
ec.limiterQueue.Enqueue(job)
135138
})
136139

137140
if err != nil {
@@ -142,7 +145,11 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error {
142145
return nil
143146
}
144147

145-
func (ec *eventConsumer) handleKeyGenerationEvent(parentCtx context.Context, raw []byte) error {
148+
func (ec *eventConsumer) handleKeyGenerationEvent(parentCtx context.Context, natMsg *nats.Msg) error {
149+
raw := natMsg.Data
150+
ctx, handlerCancel := context.WithTimeout(parentCtx, HandlerTimeout)
151+
defer handlerCancel()
152+
146153
// 1) decode and verify
147154
var msg types.GenerateKeyMessage
148155
if err := json.Unmarshal(raw, &msg); err != nil {
@@ -155,104 +162,97 @@ func (ec *eventConsumer) handleKeyGenerationEvent(parentCtx context.Context, raw
155162
walletID := msg.WalletID
156163
successEvent := &event.KeygenSuccessEvent{WalletID: walletID}
157164

158-
// 2) give this handler its own timeout
159-
handlerCtx, handlerCancel := context.WithTimeout(parentCtx, HandlerTimeout)
160-
defer handlerCancel()
165+
// 2) prepare both sessions
166+
s0, err := ec.node.CreateKeygenSession(types.KeyTypeSecp256k1, walletID, ec.mpcThreshold, ec.genKeySuccessQueue)
167+
if err != nil {
168+
return fmt.Errorf("create ECDSA session: %w", err)
169+
}
161170

162-
// wait for the sessions to return (even if they timed out)
163-
var wg sync.WaitGroup
164-
// wait for *both* callbacks to fire before publishing
165-
var cbWg sync.WaitGroup
166-
cbWg.Add(2)
171+
s1, err := ec.node.CreateKeygenSession(types.KeyTypeEd25519, walletID, ec.mpcThreshold, ec.genKeySuccessQueue)
172+
if err != nil {
173+
s0.Close()
174+
return fmt.Errorf("create EDDSA session: %w", err)
175+
}
167176

168-
var eventMutex sync.Mutex
177+
s0.Listen(ctx)
178+
s1.Listen(ctx)
169179

170-
// 3) enqueue ECDSA & EDDSA jobs
171-
for _, keyType := range []types.KeyType{types.KeyTypeSecp256k1, types.KeyTypeEd25519} {
172-
keyType := keyType
180+
defer s0.Close()
181+
defer s1.Close()
173182

174-
s, err := ec.node.CreateKeygenSession(keyType, walletID, ec.mpcThreshold, ec.genKeySuccessQueue)
175-
if err != nil {
176-
return fmt.Errorf("create %v session: %w", keyType, err)
183+
runKeygen := func(s session.Session, keyType types.KeyType) error {
184+
sessionCtx, sessionCancel := context.WithTimeout(ctx, SessionTimeout)
185+
defer sessionCancel()
186+
187+
// 1. Wait for all parties to be ready to start
188+
if err := s.WaitForReady(sessionCtx, fmt.Sprintf("KEYGEN-start:%s", keyType)); err != nil {
189+
return fmt.Errorf("failed to wait for ready: %w", err)
177190
}
178-
s.Listen()
179-
180-
wg.Add(1)
181-
run := func() {
182-
defer wg.Done()
183-
defer s.Close()
184-
185-
// give each session its own shorter timeout
186-
sessionCtx, sessionCancel := context.WithTimeout(handlerCtx, SessionTimeout)
187-
defer sessionCancel()
188-
189-
s.StartKeygen(sessionCtx, s.Send, func(data []byte) {
190-
defer cbWg.Done() // signal that this keyType actually called back
191-
192-
logger.Info("[callback] StartKeygen fired", "walletID", walletID, "keyType", keyType)
193-
194-
// save the share
195-
if err := s.SaveKey(
196-
ec.node.GetReadyPeersIncludeSelf(),
197-
ec.mpcThreshold,
198-
DefaultVersion,
199-
data,
200-
); err != nil {
201-
logger.Error("Failed to save key", err, "walletID", walletID, "keyType", keyType)
202-
}
203191

204-
// extract & record the pubkey
205-
if pubKey, err := s.GetPublicKey(data); err == nil {
206-
eventMutex.Lock()
207-
switch keyType {
208-
case types.KeyTypeSecp256k1:
209-
successEvent.ECDSAPubKey = pubKey
210-
case types.KeyTypeEd25519:
211-
successEvent.EDDSAPubKey = pubKey
212-
}
213-
eventMutex.Unlock()
214-
} else {
215-
logger.Error("Failed to get public key", err, "walletID", walletID, "keyType", keyType)
192+
doneCh := make(chan error, 1)
193+
194+
// 2. Start the key generation protocol
195+
s.StartKeygen(sessionCtx, s.Send, func(data []byte) {
196+
logger.Info("[callback] StartKeygen fired", "walletID", walletID, "keyType", keyType)
197+
if err := s.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data); err != nil {
198+
logger.Error("Failed to save key", err, "walletID", walletID, "keyType", keyType)
199+
doneCh <- err
200+
return
201+
}
202+
203+
if pubKey, err := s.GetPublicKey(data); err == nil {
204+
switch keyType {
205+
case types.KeyTypeSecp256k1:
206+
successEvent.ECDSAPubKey = pubKey
207+
case types.KeyTypeEd25519:
208+
successEvent.EDDSAPubKey = pubKey
216209
}
217-
})
218-
}
210+
} else {
211+
logger.Error("Failed to get public key", err, "walletID", walletID, "keyType", keyType)
212+
doneCh <- err
213+
return
214+
}
219215

220-
var sessionType tsslimiter.SessionType
221-
if keyType == types.KeyTypeSecp256k1 {
222-
sessionType = tsslimiter.SessionKeygenECDSA
223-
} else {
224-
sessionType = tsslimiter.SessionKeygenEDDSA
225-
}
216+
// 3. Wait for all parties to confirm completion
217+
if err := s.WaitForReady(sessionCtx, fmt.Sprintf("KEYGEN-complete:%s", keyType)); err != nil {
218+
doneCh <- fmt.Errorf("failed to wait for completion: %w", err)
219+
return
220+
}
226221

227-
ec.limiterQueue.Enqueue(tsslimiter.SessionJob{
228-
Type: sessionType,
229-
Run: run,
222+
doneCh <- nil
230223
})
231-
}
232-
233-
// 4) wait for both session goroutines to return
234-
wg.Wait()
235224

236-
// 5) now wait for both callbacks (or handler timeout)
237-
doneCb := make(chan struct{})
238-
go func() {
239-
cbWg.Wait()
240-
close(doneCb)
241-
}()
225+
select {
226+
case err := <-doneCh:
227+
if err != nil {
228+
return fmt.Errorf("keygen onComplete failed: %w", err)
229+
}
230+
return nil
231+
case err := <-s.ErrCh():
232+
return fmt.Errorf("session error during keygen: %w", err)
233+
case <-sessionCtx.Done():
234+
return fmt.Errorf("keygen timed out: %w", sessionCtx.Err())
235+
}
236+
}
242237

243-
select {
244-
case <-handlerCtx.Done():
245-
logger.Warn("Keygen callbacks did not all fire before timeout", "walletID", walletID)
246-
return handlerCtx.Err()
247-
case <-doneCb:
248-
// both callbacks have run
238+
logger.Info("Starting ECDSA key generation...", "walletID", walletID)
239+
if err := runKeygen(s0, types.KeyTypeSecp256k1); err != nil {
240+
return fmt.Errorf("ECDSA keygen failed: %w", err)
249241
}
242+
logger.Info("ECDSA key generation completed.", "walletID", walletID)
250243

251-
// 6) marshal & publish success
244+
logger.Info("Starting EDDSA key generation...", "walletID", walletID)
245+
if err := runKeygen(s1, types.KeyTypeEd25519); err != nil {
246+
return fmt.Errorf("EDDSA keygen failed: %w", err)
247+
}
248+
logger.Info("EDDSA key generation completed.", "walletID", walletID)
249+
// 3) Send reply to keygen consumer after both keygens complete
250+
// 4) marshal & publish success
252251
successBytes, err := json.Marshal(successEvent)
253252
if err != nil {
254253
return fmt.Errorf("marshal success event: %w", err)
255254
}
255+
256256
if err := ec.genKeySuccessQueue.Enqueue(
257257
fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID),
258258
successBytes,
@@ -263,6 +263,15 @@ func (ec *eventConsumer) handleKeyGenerationEvent(parentCtx context.Context, raw
263263
return fmt.Errorf("enqueue success event: %w", err)
264264
}
265265

266+
if natMsg.Reply != "" {
267+
err = ec.pubsub.Publish(natMsg.Reply, successBytes)
268+
if err != nil {
269+
logger.Error("Failed to publish reply", err)
270+
} else {
271+
logger.Info("Reply sent to keygen consumer", "reply", natMsg.Reply, "walletID", walletID)
272+
}
273+
}
274+
266275
logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID)
267276
return nil
268277
}
@@ -323,7 +332,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error {
323332
return
324333
}
325334

326-
go signingSession.Listen()
335+
go signingSession.Listen(context.Background())
327336

328337
txBigInt := new(big.Int).SetBytes(msg.Tx)
329338
go func() {
@@ -366,6 +375,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error {
366375

367376
logger.Info("Signing completed", "walletID", msg.WalletID, "txID", msg.TxID, "data", len(data))
368377
ec.removeSession(msg.WalletID, msg.TxID)
378+
369379
})
370380
}()
371381

@@ -436,8 +446,8 @@ func (ec *eventConsumer) handleReshareEvent(ctx context.Context, raw []byte) err
436446
return fmt.Errorf("create new session: %w", err)
437447
}
438448

439-
go oldSession.Listen()
440-
go newSession.Listen()
449+
go oldSession.Listen(context.Background())
450+
go newSession.Listen(context.Background())
441451

442452
successEvent := &event.ResharingSuccessEvent{WalletID: msg.WalletID}
443453

pkg/eventconsumer/keygen_consumer.go

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ import (
1313

1414
const (
1515
// Maximum time to wait for a keygen response.
16-
keygenResponseTimeout = 30 * time.Second
16+
keygenResponseTimeout = 90 * time.Second
1717
// How often to poll for the reply message.
18-
keygenPollingInterval = 500 * time.Millisecond
18+
keygenPollingInterval = 1 * time.Second
1919
)
2020

2121
// KeygenConsumer represents a consumer that processes keygen events.
@@ -50,7 +50,18 @@ func (sc *keygenConsumer) Run(ctx context.Context) error {
5050
logger.Info("Starting key generation event consumer")
5151

5252
go func() {
53-
ticker := time.NewTicker(30 * time.Second)
53+
// Initial fetch
54+
logger.Info("Calling to fetch key generation events...")
55+
err := sc.keygenRequestQueue.Fetch(5, func(msg jetstream.Msg) error {
56+
sc.handleKeygenEvent(msg)
57+
return nil
58+
})
59+
if err != nil && !errors.Is(err, context.DeadlineExceeded) {
60+
logger.Error("Error fetching key generation events", err)
61+
}
62+
63+
// Then start the ticker
64+
ticker := time.NewTicker(15 * time.Second)
5465
defer ticker.Stop()
5566

5667
for {
@@ -62,17 +73,12 @@ func (sc *keygenConsumer) Run(ctx context.Context) error {
6273
case <-ticker.C:
6374
logger.Info("Calling to fetch key generation events...")
6475

65-
// No need for a separate fetch context since the fetch operation
66-
// is synchronous and completes before we'd cancel it
67-
err := sc.keygenRequestQueue.Fetch(2, func(msg jetstream.Msg) error {
76+
err := sc.keygenRequestQueue.Fetch(5, func(msg jetstream.Msg) error {
6877
sc.handleKeygenEvent(msg)
6978
return nil
7079
})
71-
72-
if err != nil {
73-
if !errors.Is(err, context.DeadlineExceeded) {
74-
logger.Error("Error fetching key generation events", err)
75-
}
80+
if err != nil && !errors.Is(err, context.DeadlineExceeded) {
81+
logger.Error("Error fetching key generation events", err)
7682
}
7783
}
7884
}
@@ -89,28 +95,22 @@ func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) {
8995
replySub, err := sc.natsConn.SubscribeSync(replyInbox)
9096
if err != nil {
9197
logger.Error("KeygenConsumer: Failed to subscribe to reply inbox", err)
92-
_ = msg.Nak()
98+
_ = msg.Term()
9399
return
94100
}
95-
defer func() {
96-
if err := replySub.Unsubscribe(); err != nil {
97-
logger.Warn("KeygenConsumer: Failed to unsubscribe from reply inbox", err)
98-
}
99-
}()
101+
defer replySub.Unsubscribe()
100102

101103
// Publish the keygen event with the reply inbox.
102104
if err := sc.pubsub.PublishWithReply(MPCGenerateEvent, replyInbox, msg.Data()); err != nil {
103105
logger.Error("KeygenConsumer: Failed to publish keygen event with reply", err)
104-
_ = msg.Nak()
106+
_ = msg.Term()
105107
return
106108
}
107109

108-
// Poll for the reply message until timeout.
109110
deadline := time.Now().Add(keygenResponseTimeout)
110111
for time.Now().Before(deadline) {
111112
replyMsg, err := replySub.NextMsg(keygenPollingInterval)
112113
if err != nil {
113-
// If timeout occurs, continue trying.
114114
if err == nats.ErrTimeout {
115115
continue
116116
}
@@ -119,15 +119,18 @@ func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) {
119119
}
120120
if replyMsg != nil {
121121
logger.Info("KeygenConsumer: Completed keygen event reply received")
122-
if ackErr := msg.Ack(); ackErr != nil {
123-
logger.Error("KeygenConsumer: ACK failed", ackErr)
122+
if err := msg.Ack(); err != nil && !messaging.IsAlreadyAcknowledged(err) {
123+
logger.Error("KeygenConsumer: ACK failed", err)
124124
}
125125
return
126126
}
127127
}
128128

129+
// Timeout
129130
logger.Warn("KeygenConsumer: Timeout waiting for keygen event response")
130-
_ = msg.Nak()
131+
if err := msg.Term(); err != nil {
132+
logger.Error("KeygenConsumer: Failed to terminate message", err)
133+
}
131134
}
132135

133136
// Close unsubscribes from the JetStream subject and cleans up resources.

0 commit comments

Comments
 (0)