From c41870e437c959affd668cfdc2ed456be2603ace Mon Sep 17 00:00:00 2001 From: vietddude Date: Thu, 16 Apr 2026 17:50:58 +0700 Subject: [PATCH 01/23] Implement coordinator runtime for MPC, including request handling, session management, and participant event processing. Update Go module version and dependencies. Add README documentation for the coordinator and its components. --- cmd/mpcium-coordinator/README.md | 69 ++ cmd/mpcium-coordinator/main.go | 141 ++++ .../architecture/external-cosigner-runtime.md | 637 ++++++++++++++++++ go.mod | 7 +- internal/coordinator/README.md | 136 ++++ internal/coordinator/coordinator.go | 493 ++++++++++++++ internal/coordinator/coordinator_test.go | 235 +++++++ internal/coordinator/errors.go | 37 + internal/coordinator/keyinfo.go | 56 ++ internal/coordinator/presence.go | 48 ++ internal/coordinator/publisher.go | 56 ++ internal/coordinator/runtime.go | 72 ++ internal/coordinator/signing.go | 78 +++ internal/coordinator/store.go | 328 +++++++++ internal/coordinator/topics.go | 33 + internal/coordinator/types.go | 81 +++ 16 files changed, 2505 insertions(+), 2 deletions(-) create mode 100644 cmd/mpcium-coordinator/README.md create mode 100644 cmd/mpcium-coordinator/main.go create mode 100644 docs/architecture/external-cosigner-runtime.md create mode 100644 internal/coordinator/README.md create mode 100644 internal/coordinator/coordinator.go create mode 100644 internal/coordinator/coordinator_test.go create mode 100644 internal/coordinator/errors.go create mode 100644 internal/coordinator/keyinfo.go create mode 100644 internal/coordinator/presence.go create mode 100644 internal/coordinator/publisher.go create mode 100644 internal/coordinator/runtime.go create mode 100644 internal/coordinator/signing.go create mode 100644 internal/coordinator/store.go create mode 100644 internal/coordinator/topics.go create mode 100644 internal/coordinator/types.go diff --git a/cmd/mpcium-coordinator/README.md b/cmd/mpcium-coordinator/README.md new file mode 100644 index 00000000..fd3b3e73 --- /dev/null +++ b/cmd/mpcium-coordinator/README.md @@ -0,0 +1,69 @@ +# Mpcium Coordinator MVP + +This runtime implements the v1 control-plane coordinator from `docs/architecture/external-cosigner-runtime.md`. + +It owns: + +- NATS request-reply intake on `mpc.v1.request.keygen`, `mpc.v1.request.sign`, and `mpc.v1.request.reshare` +- pinned participant validation +- session lifecycle state +- signed control fan-out to `mpc.v1.peer..control` +- participant event intake from `mpc.v1.session..event` +- terminal result publishing to `mpc.v1.session..result` + +It does not implement relay, MQTT mailboxing, p2p MPC packet routing, or legacy `mpc.*` subjects. + +## Run + +```sh +go run ./cmd/mpcium-coordinator \ + --nats-url nats://127.0.0.1:4222 \ + --coordinator-id coordinator-01 \ + --coordinator-private-key-hex \ + --snapshot-dir ./coordinator-snapshots \ + --relay-available=true +``` + +The same settings can be provided through environment variables: + +- `NATS_URL` +- `COORDINATOR_ID` +- `COORDINATOR_PRIVATE_KEY_HEX` +- `COORDINATOR_SNAPSHOT_DIR` +- `COORDINATOR_RELAY_AVAILABLE` +- `COORDINATOR_TICK_INTERVAL` + +Each operation has its own request shape. The operation comes from the NATS subject, so a sign request to `mpc.v1.request.sign` looks like: + +```json +{ + "request_id": "req_123", + "ttl_sec": 120, + "threshold": 2, + "participants": [ + { "peer_id": "peer-node-01", "transport": "nats" }, + { "peer_id": "peer-node-02", "transport": "nats" } + ], + "wallet_id": "wallet_123", + "key_type": "secp256k1", + "tx_id": "tx_456", + "tx_hash": "0xabc" +} +``` + +For keygen, send `wallet_id`, `threshold`, and the full keygen participant set to `mpc.v1.request.keygen`. `key_type` is optional; when omitted, participants should generate both `secp256k1` and `ed25519` for that wallet/session. For sign, send exactly the participants selected for this signing session; MVP validation requires `len(participants) == threshold`. + +Internal `nats` participants must publish online presence before requests are accepted: + +```json +{ + "v": 1, + "type": "peer.presence", + "peer_id": "peer-node-01", + "status": "online", + "transport": "nats", + "last_seen_at": "2026-04-16T10:00:00Z" +} +``` + +Publish it to `mpc.v1.peer.peer-node-01.presence`. diff --git a/cmd/mpcium-coordinator/main.go b/cmd/mpcium-coordinator/main.go new file mode 100644 index 00000000..e18e44d0 --- /dev/null +++ b/cmd/mpcium-coordinator/main.go @@ -0,0 +1,141 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "os/signal" + "strconv" + "syscall" + "time" + + "github.com/fystack/mpcium/internal/coordinator" + "github.com/nats-io/nats.go" +) + +func main() { + if err := run(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func run() error { + natsURL := flag.String("nats-url", envDefault("NATS_URL", nats.DefaultURL), "NATS server URL") + coordinatorID := flag.String("coordinator-id", envDefault("COORDINATOR_ID", ""), "stable coordinator ID") + privateKeyHex := flag.String("coordinator-private-key-hex", envDefault("COORDINATOR_PRIVATE_KEY_HEX", ""), "hex encoded Ed25519 private key") + snapshotDir := flag.String("snapshot-dir", envDefault("COORDINATOR_SNAPSHOT_DIR", "coordinator-snapshots"), "directory for coordinator session snapshots") + relayAvailable := flag.Bool("relay-available", envBoolDefault("COORDINATOR_RELAY_AVAILABLE", true), "whether relay is available for MQTT participants") + defaultSessionTTLSec := flag.Int("default-session-ttl-sec", envIntDefault("COORDINATOR_DEFAULT_SESSION_TTL_SEC", 120), "default session TTL in seconds") + tickInterval := flag.Duration("tick-interval", envDurationDefault("COORDINATOR_TICK_INTERVAL", time.Second), "session timeout scan interval") + flag.Parse() + + if *coordinatorID == "" { + return fmt.Errorf("coordinator-id is required") + } + if *privateKeyHex == "" { + return fmt.Errorf("coordinator-private-key-hex is required") + } + + signer, err := coordinator.NewEd25519SignerFromHex(*privateKeyHex) + if err != nil { + return err + } + + nc, err := nats.Connect(*natsURL) + if err != nil { + return fmt.Errorf("connect to NATS: %w", err) + } + defer nc.Close() + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + snapshotStore := coordinator.NewAtomicFileSnapshotStore(*snapshotDir) + sessionStore, err := coordinator.NewMemorySessionStore(ctx, snapshotStore) + if err != nil { + return fmt.Errorf("restore coordinator state: %w", err) + } + keyInfoStore := coordinator.NewMemoryKeyInfoStore() + if err := coordinator.RestoreKeyInfoFromSnapshotStore(ctx, snapshotStore, keyInfoStore); err != nil { + return fmt.Errorf("restore key info: %w", err) + } + _ = relayAvailable + presence := coordinator.NewInMemoryPresenceView() + coord, err := coordinator.NewCoordinator(coordinator.CoordinatorConfig{ + CoordinatorID: *coordinatorID, + Signer: signer, + EventVerifier: coordinator.Ed25519SessionEventVerifier{}, + Store: sessionStore, + KeyInfoStore: keyInfoStore, + Presence: presence, + Controls: coordinator.NewNATSControlPublisher(nc), + Results: coordinator.NewNATSResultPublisher(nc), + DefaultSessionTTL: time.Duration(*defaultSessionTTLSec) * time.Second, + }) + if err != nil { + return err + } + + runtime := coordinator.NewNATSRuntime(nc, coord, presence) + if err := runtime.Start(ctx); err != nil { + return err + } + + ticker := time.NewTicker(*tickInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return runtime.Stop() + case <-ticker.C: + if _, err := coord.Tick(ctx); err != nil { + fmt.Fprintln(os.Stderr, "coordinator tick error:", err) + } + } + } +} + +func envDefault(name string, fallback string) string { + if value := os.Getenv(name); value != "" { + return value + } + return fallback +} + +func envBoolDefault(name string, fallback bool) bool { + value := os.Getenv(name) + if value == "" { + return fallback + } + parsed, err := strconv.ParseBool(value) + if err != nil { + return fallback + } + return parsed +} + +func envDurationDefault(name string, fallback time.Duration) time.Duration { + value := os.Getenv(name) + if value == "" { + return fallback + } + parsed, err := time.ParseDuration(value) + if err != nil { + return fallback + } + return parsed +} + +func envIntDefault(name string, fallback int) int { + value := os.Getenv(name) + if value == "" { + return fallback + } + parsed, err := strconv.Atoi(value) + if err != nil { + return fallback + } + return parsed +} diff --git a/docs/architecture/external-cosigner-runtime.md b/docs/architecture/external-cosigner-runtime.md new file mode 100644 index 00000000..800a572d --- /dev/null +++ b/docs/architecture/external-cosigner-runtime.md @@ -0,0 +1,637 @@ +# External And Mobile Cosigner Runtime Design + +## Status +Draft + +## Date +2026-04-16 + +## Summary +This document captures the current design proposal for supporting external cosigners and mobile cosigners through a new runtime model. + +Core direction: + +- Internal `mpcium node` workers continue to communicate over `NATS`. +- External and mobile cosigners connect through `MQTT`. +- `Relay` is also the MQTT broker, but it is transport-only. +- `Coordinator` is the control-plane authority for request intake, session management, participant selection, and lifecycle transitions. +- MPC packets are end-to-end peer-to-peer messages after a dedicated key exchange round. + +## Problem Statement +The current model assumes requests such as `keygen`, `sign`, and `reshare` are published into `NATS`, and internal nodes consume them directly. + +That model is not sufficient once we introduce: + +- external cosigner services +- mobile cosigners +- uncertain peer connectivity over MQTT +- store-and-forward behavior for offline external peers +- a dedicated key exchange round before encrypted MPC traffic starts + +## Goals + +- Support internal participants over `NATS`. +- Support external and mobile participants over `MQTT`. +- Keep MPC traffic end-to-end encrypted between participants. +- Allow offline external peers to reconnect and resume through relay mailbox behavior. +- Separate control-plane concerns from transport concerns. +- Keep the data plane simple enough for a new runtime implementation. + +## Non-Goals + +- Relay does not execute MPC logic. +- Relay does not decrypt MPC payloads. +- Coordinator does not inspect or process MPC round payloads. +- Mobile or end-user apps do not own session truth. + +## High-Level Decisions + +1. Introduce a dedicated `Coordinator` runtime. +2. Keep `Relay` separate from `Coordinator` at the responsibility level. +3. Keep `mpcium node` as a worker runtime, not a session authority. +4. Introduce a dedicated `Round 0` key exchange before normal MPC rounds. +5. Treat all MPC round packets after key exchange as end-to-end encrypted peer-to-peer messages. +6. Use versioned topic namespaces for the new runtime instead of reusing the current ad-hoc subject format. +7. Start with coordinator fan-out for control messages instead of a dedicated control broadcast topic. + +## Runtime Layout + +### Coordinator +Owns the control plane. + +Responsibilities: + +- receive `keygen`, `sign`, and `reshare` requests +- create `sessionId` +- choose participant set +- check peer presence +- send control messages to internal and external peers +- manage session state +- enforce timeout, abort, and completion rules + +### Relay +Owns transport for external peers. + +Responsibilities: + +- act as MQTT broker +- track online and offline status +- expose peer presence +- hold pending messages for offline peers +- forward messages by `sessionId` and `peerId` + +Explicitly out of scope: + +- no session ownership +- no participant selection +- no MPC computation +- no payload decryption + +### MPCIUM Node +Owns internal MPC worker behavior. + +Responsibilities: + +- join assigned sessions +- perform signed key exchange +- run MPC rounds +- emit lifecycle events such as ready, failed, completed + +### External Cosigner / Mobile Cosigner +Own the same participant-side behavior as internal nodes, but connect through MQTT via relay. + +Responsibilities: + +- connect and reconnect through relay +- receive control messages +- join sessions +- perform signed key exchange +- run MPC rounds +- emit lifecycle events + +## Control Plane And Data Plane + +### Control Plane +Owned by `Coordinator`. + +Used for: + +- request intake +- session creation +- participant assignment +- presence checks +- session start +- begin key exchange +- abort and timeout +- completion and failure aggregation + +### Data Plane +Owned by participants plus transport. + +Used for: + +- peer-to-peer key exchange hello messages +- encrypted MPC round packets + +Key rule: + +- `Coordinator` manages session lifecycle but never reads MPC packet bodies. +- `Relay` forwards packets but never interprets MPC semantics. + +## Proposed Deployment + +- `coordinator-runtime` +- `relay-runtime` +- `mpcium node` +- `external cosigner` +- `mobile cosigner` + +The coordinator may initially be deployed in the same environment as internal services, but it should remain a separate runtime or module from both relay and worker nodes. + +## Architecture Diagram + +```mermaid +flowchart LR + App["App / API Client"] -->|"request keygen / sign / reshare"| Coord["Coordinator"] + + Coord -->|"create session + choose participants"| NATS["NATS"] + Coord -->|"session control for external peers"| Relay["Relay + MQTT Broker"] + + subgraph Internal["Internal MPC Cluster"] + NATS --> Worker1["mpcium node"] + NATS --> Worker2["mpcium node"] + NATS --> Worker3["mpcium node"] + end + + subgraph External["External Cosigner Side"] + Relay <-->|"MQTT connect / reconnect"| Cosigner["External Cosigner"] + Relay <-->|"MQTT connect / reconnect"| Mobile["Mobile Cosigner"] + end + + Relay --> Presence["Presence / mailbox / route only"] + Coord --> State["Session state / participants / timeout"] + + Worker1 -. "p2p MPC msg via NATS" .-> NATS + Worker2 -. "p2p MPC msg via NATS" .-> NATS + Worker3 -. "p2p MPC msg via NATS" .-> NATS + + Cosigner -. "p2p MPC msg via MQTT" .-> Relay + Mobile -. "p2p MPC msg via MQTT" .-> Relay +``` + +## Session Flow + +```mermaid +sequenceDiagram + participant A as "App / API Client" + participant C as "Coordinator" + participant N as "NATS" + participant I as "Internal mpcium node" + participant R as "Relay + MQTT Broker" + participant X as "External / Mobile Cosigner" + + A->>C: Request keygen / sign / reshare + C->>C: Create sessionId + participants + ttl + + C->>R: Check external peer presence + C->>N: Publish session_start for internal participants + C->>R: Forward session_start for external participants + + N-->>I: Deliver session_start + + alt cosigner online + R-->>X: Deliver session_start + else cosigner offline + R->>R: Store pending session + X->>R: Reconnect + R-->>X: Deliver pending session + end + + I->>C: joined / ready + X->>R: joined / ready + R->>C: joined / ready + + C->>N: begin_key_exchange + C->>R: begin_key_exchange + + Note over I,X: Round 0: signed key exchange + Note over I,X: Rounds 1..N: e2e encrypted p2p MPC messages + + I->>N: p2p msg + N-->>R: if target is external + R-->>X: route only + + X->>R: p2p msg + R->>N: if target is internal + N-->>I: route only + + I->>C: completed / failed + X->>R: completed / failed + R->>C: completed / failed + + C-->>A: Result / status +``` + +## Session Lifecycle + +```mermaid +flowchart TD + A["created"] --> B["waiting_participants"] + B --> C["ready"] + C --> D["key_exchange"] + D --> E["active_mpc"] + E --> F["completed"] + + B --> G["expired"] + C --> G + D --> H["failed"] + E --> H +``` + +## Topic Namespace + +### Design Principles + +- All new runtime topics are versioned. +- Topic naming should be transport-neutral at the logical level. +- `NATS` and `MQTT` use the same logical structure with different separators. +- Each peer has a stable control inbox and a session-scoped p2p inbox. +- Phase 1 avoids broadcast topics for control messages. Coordinator fans out control messages to each participant inbox. + +### Logical Namespace + +- `request.` +- `peer..control` +- `peer..session..p2p` +- `session..event` +- `peer..presence` + +### Transport Mapping + +| Purpose | NATS | MQTT | +| --- | --- | --- | +| Keygen request | `mpc.v1.request.keygen` | not required | +| Sign request | `mpc.v1.request.sign` | not required | +| Reshare request | `mpc.v1.request.reshare` | not required | +| Peer control inbox | `mpc.v1.peer..control` | `mpc/v1/peer//control` | +| Peer p2p inbox | `mpc.v1.peer..session..p2p` | `mpc/v1/peer//session//p2p` | +| Session event stream | `mpc.v1.session..event` | `mpc/v1/session//event` | +| Session terminal result | `mpc.v1.session..result` | not required | +| Peer presence | `mpc.v1.peer..presence` | `mpc/v1/peer//presence` | + +### Relay Bridge Expectations + +Relay should bridge: + +- external peer control inboxes +- external peer p2p inboxes +- presence events +- session lifecycle events from external participants back to coordinator + +Relay should not mutate session payload shape. + +## Message Model + +All messages use an envelope with explicit versioning and typed payloads. + +### Control Envelope + +```json +{ + "v": 1, + "type": "session.start", + "session_id": "sess_01HXYZ", + "op": "sign", + "request_id": "req_01HXYZ", + "correlation_id": "corr_01HXYZ", + "from": "coordinator", + "to": "peer-mobile-01", + "ts": "2026-04-16T10:00:00Z", + "ttl_sec": 120, + "sig": "base64(coordinator_signature)", + "body": {} +} +``` + +Control message types: + +- `session.start` +- `key_exchange.begin` +- `mpc.begin` +- `session.abort` + +Deferred control message types: + +- `session.cancel` +- `session.resume` + +### P2P Envelope + +```json +{ + "v": 1, + "type": "mpc.packet", + "session_id": "sess_01HXYZ", + "op": "sign", + "from": "peer-node-01", + "to": "peer-mobile-01", + "round": 2, + "seq": 14, + "ts": "2026-04-16T10:00:08Z", + "encryption": { + "alg": "x25519-chacha20poly1305", + "kid": "kx_01", + "nonce": "base64(...)" + }, + "ciphertext": "base64(...)" +} +``` + +### Session Event Envelope + +```json +{ + "v": 1, + "type": "peer.joined", + "session_id": "sess_01HXYZ", + "op": "sign", + "from": "peer-mobile-01", + "ts": "2026-04-16T10:00:03Z", + "body": {} +} +``` + +Session event types: + +- `peer.joined` +- `peer.ready` +- `peer.key_exchange_done` +- `peer.failed` +- `session.completed` +- `session.failed` +- `session.timed_out` + +### Presence Event + +```json +{ + "v": 1, + "type": "peer.presence", + "peer_id": "peer-mobile-01", + "status": "online", + "transport": "mqtt", + "conn_id": "conn_8f3a", + "last_seen_at": "2026-04-16T10:00:01Z" +} +``` + +## Suggested Message Bodies + +### session.start + +```json +{ + "threshold": 2, + "participants": [ + { "peer_id": "peer-node-01", "transport": "nats" }, + { "peer_id": "peer-node-02", "transport": "nats" }, + { "peer_id": "peer-mobile-01", "transport": "mqtt" } + ], + "key_type": "secp256k1", + "payload": { + "wallet_id": "wallet_123", + "tx_id": "tx_456", + "tx_hash": "0xabc" + } +} +``` + +### key_exchange.begin + +```json +{ + "exchange_id": "kx_01", + "curve": "x25519", + "participants": [ + "peer-node-01", + "peer-node-02", + "peer-mobile-01" + ] +} +``` + +### key_exchange.hello + +This message is sent peer-to-peer during round 0. It is not encrypted yet, but it must be signed by the sender identity. + +```json +{ + "v": 1, + "type": "key_exchange.hello", + "session_id": "sess_01HXYZ", + "from": "peer-mobile-01", + "to": "peer-node-01", + "ts": "2026-04-16T10:00:05Z", + "body": { + "exchange_id": "kx_01", + "identity_key_id": "id_mobile_01", + "ephemeral_pubkey": "base64(...)" + }, + "sig": "base64(peer_identity_signature)" +} +``` + +## Security Model + +- Control messages are signed by `Coordinator`. +- Key exchange hello messages are signed by the sender identity. +- All MPC packets after key exchange are end-to-end encrypted. +- `Relay` must only route based on metadata such as target peer and session. +- AEAD additional authenticated data should bind at least: + - `session_id` + - `from` + - `to` + - `round` + - `seq` + +## Session Ownership Rules + +- `Coordinator` is the single authority for session lifecycle. +- `Relay` may expose presence and mailbox state, but it is not the session owner. +- `mpcium node` workers do not decide session creation or participant assignment. +- External cosigners and mobile cosigners do not own global session truth. + +## Why Coordinator Is Not Relay + +`Relay` and `Coordinator` should remain separate responsibilities because they scale and fail differently. + +- Relay scales with connections and message forwarding. +- Coordinator scales with active session count and lifecycle state. +- Relay should remain transport-focused. +- Coordinator should remain control-focused. + +They may be co-located in early deployment, but the boundary should stay explicit in code and APIs. + +## Why Coordinator Is Not The MPC Node + +`mpcium node` should remain a worker runtime. + +If every node also acts as coordinator, the system must solve duplicate orchestration, leader election, and split-brain. That adds control-plane complexity into the worker path and makes debugging much harder. + +A separate coordinator runtime keeps: + +- one session authority +- simpler workers +- clearer retry and timeout behavior +- cleaner protocol boundaries + +## Coordinator V1 Runtime + +The first coordinator implementation is intentionally scoped to the new runtime only. It does not bridge legacy subjects such as `mpc.keygen_request.*` or the existing `eventconsumer` flow. + +Locked v1 choices: + +- Request intake is NATS request-reply on `mpc.v1.request.keygen`, `mpc.v1.request.sign`, and `mpc.v1.request.reshare`. +- The request reply is only an accept/reject response. Accepted responses include `session_id`, `status_subject`, `result_subject`, and `expires_at`. +- Terminal output is asynchronous and published on `mpc.v1.session..result`. +- Participant selection is request-pinned. Coordinator validates the requested participant set instead of auto-ranking peers. +- Internal `nats` participants must be online before a session is accepted. +- External or mobile `mqtt` participants may be offline if relay is available; Coordinator sends control messages and relies on relay mailbox delivery until the session TTL expires. +- Session state is in memory with an atomic JSON snapshot per session after each state transition. +- V1 is a singleton coordinator. Multi-coordinator leader election, CAS state storage, and delivery acknowledgments are deferred. + +Coordinator-owned lifecycle: + +1. Validate the signed request envelope and operation body. +2. Create `session_id`, initialize `created`, and fan out signed `session.start`. +3. Move to `waiting_participants` until all selected participants emit `peer.joined` and `peer.ready`. +4. Fan out `key_exchange.begin`, move to `key_exchange`, and wait for `peer.key_exchange_done` from every selected participant. +5. Fan out `mpc.begin`, move to `active_mpc`, and wait for terminal participant events. +6. Mark `completed` only when every selected participant emits `session.completed` with the same `result_hash`. +7. Mark `failed` on participant failure or result-hash mismatch. +8. Mark `expired` when absolute session TTL passes, then fan out `session.abort` and publish a terminal timeout result. + +## Open Questions + +- How long should pending session and pending control messages stay in relay mailbox? +- Should `session..event` also be mirrored to MQTT for external debugging and observability? +- Do we need resumable session tokens for mobile reconnect flows? +- Do we want a dedicated event for delivery acknowledgment from relay to coordinator? +- When production requires more than one coordinator, which durable store and leader election model should replace v1 in-memory snapshots? + +## Suggested Next Steps + +1. Define relay presence and mailbox behavior in more detail. +2. Implement the transport bridge behavior for MQTT peers. +3. Integrate participant runtimes with `session.start`, `key_exchange.begin`, `mpc.begin`, and `session.abort`. +4. Integrate round 0 key exchange into participant runtimes. +5. Replace v1 JSON contracts with protobuf if the runtime protocol standardizes on generated schemas. +6. Add durable coordinator state and leader election before running multiple coordinators. + +## Suggested Implementation Order + +The implementation should not start with full end-to-end signing. + +Even though `keygen` must happen before real `sign`, the first milestone should focus on the shared runtime foundation that both operations need. + +### Phase 1: Control Plane Foundation + +Implement: + +- coordinator runtime skeleton +- session store with in-memory state +- request intake +- participant selection +- session lifecycle states +- control message fan-out +- session event handling + +Deliverable: + +- coordinator can create a session and drive participants through `created -> waiting_participants -> ready` + +### Phase 2: Relay Foundation + +Implement: + +- relay runtime skeleton +- MQTT broker integration +- external peer presence +- online and offline tracking +- control message forwarding + +Deliverable: + +- external peers can connect, expose presence, and receive control messages through relay + +### Phase 3: Participant Control Integration + +Implement: + +- internal worker subscription to control inbox +- external cosigner subscription to control inbox +- `session.start` +- `peer.joined` +- `peer.ready` +- basic timeout and abort handling + +Deliverable: + +- coordinator can start a session and collect participant readiness from both internal and external peers + +### Phase 4: Round 0 Key Exchange + +Implement: + +- `key_exchange.begin` +- `key_exchange.hello` +- signed identity verification +- pairwise key derivation +- `peer.key_exchange_done` + +Deliverable: + +- all participants derive session keys and coordinator can transition a session into `active_mpc` + +### Phase 5: Generic P2P Transport + +Implement: + +- session-scoped p2p inbox +- internal transport over NATS +- external transport over MQTT through relay +- metadata-based routing in relay +- encrypted packet envelope + +Deliverable: + +- participants can exchange encrypted p2p packets over the new transport without yet running full MPC + +### Phase 6: MPC Operation Integration + +At this point, integrate real MPC operations in dependency order: + +1. `keygen` +2. `sign` +3. `reshare` + +Rationale: + +- `keygen` creates the key material needed by later operations +- `sign` depends on existing key material +- `reshare` depends on both lifecycle control and participant-change handling + +### Phase 7: Reliability And Recovery + +Implement: + +- relay mailbox +- offline replay for external peers +- session persistence +- reconnect and resume behavior +- delivery acknowledgment if needed +- observability and tracing + +Deliverable: + +- the runtime can recover from disconnects and partial failures with predictable behavior diff --git a/go.mod b/go.mod index 2d6dd8fe..38d1262d 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/fystack/mpcium -go 1.25.8 +go 1.26 require ( filippo.io/age v1.3.1 @@ -8,7 +8,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.32.7 github.com/aws/aws-sdk-go-v2/credentials v1.19.7 github.com/aws/aws-sdk-go-v2/service/kms v1.49.5 - github.com/bnb-chain/tss-lib/v2 v2.0.2 + github.com/bnb-chain/tss-lib/v2 v2.0.3 github.com/btcsuite/btcd v0.25.0 github.com/btcsuite/btcd/btcec/v2 v2.3.6 github.com/btcsuite/btcutil v1.0.2 @@ -86,6 +86,7 @@ require ( github.com/spf13/pflag v1.0.10 // indirect github.com/stretchr/objx v0.5.3 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/vietddude/mpcium-sdk v0.0.0 go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel v1.39.0 // indirect go.opentelemetry.io/otel/metric v1.39.0 // indirect @@ -104,3 +105,5 @@ require ( replace github.com/agl/ed25519 => github.com/binance-chain/edwards25519 v0.0.0-20200305024217-f36fc4b53d43 replace github.com/bnb-chain/tss-lib/v2 => github.com/fystack/tss-lib/v2 v2.0.3 + +replace github.com/vietddude/mpcium-sdk => ../sdk diff --git a/internal/coordinator/README.md b/internal/coordinator/README.md new file mode 100644 index 00000000..8e92856d --- /dev/null +++ b/internal/coordinator/README.md @@ -0,0 +1,136 @@ +# Coordinator Package + +This package implements the control-plane coordinator for the new MPC runtime. + +It is responsible for: + +- request intake on versioned subjects (`keygen`, `sign`, `reshare`) +- session creation and lifecycle state transitions +- participant readiness and key exchange gating +- control message fan-out to participants +- participant event handling +- terminal result publishing +- timeout/abort handling + +It is not responsible for: + +- MPC cryptographic round computation +- relay mailbox behavior implementation +- decrypting participant-to-participant MPC packets + +## Main Responsibilities + +1. Accept operation requests over NATS: + - `mpc.v1.request.keygen` + - `mpc.v1.request.sign` + - `mpc.v1.request.reshare` +2. Validate request shape and participant constraints. +3. Create a new `session_id` and initial session state. +4. Fan out `session.start` control messages to each selected participant. +5. Track participant events and advance session phases. +6. Publish terminal result on `mpc.v1.session..result`. + +## Runtime Components + +- `Coordinator`: + core orchestration logic and state machine. +- `NATSRuntime`: + wiring from subjects to coordinator handlers. +- `MemorySessionStore`: + in-memory session state. +- `AtomicFileSnapshotStore`: + optional JSON snapshots for session persistence across restarts. +- `InMemoryPresenceView`: + online/offline view used during request validation. +- `NATSControlPublisher` / `NATSResultPublisher`: + delivery adapters for control and result messages. + +## Request Models + +The operation is determined by subject. Each operation has its own request struct: + +- `KeygenRequest` +- `SignRequest` +- `ReshareRequest` + +Validation rules: + +- keygen: `threshold + 1 <= len(participants)` +- sign: `len(participants) == threshold` +- reshare: validate `new_threshold` and `new_participants` consistency +- `key_type`: + - keygen: optional; empty means default key types (`secp256k1`, `ed25519`) + - sign: required + - reshare: required + +## Session Lifecycle + +States: + +- `created` +- `waiting_participants` +- `key_exchange` +- `active_mpc` +- `completed` / `failed` / `expired` + +State flow: + +1. `created`: + session object allocated. +2. `waiting_participants`: + wait until all selected participants report `peer.joined` and `peer.ready`. +3. `key_exchange`: + coordinator fans out `key_exchange.begin`, then waits for `peer.key_exchange_done` from all participants. +4. `active_mpc`: + coordinator fans out `mpc.begin`, then waits for terminal participant events. +5. terminal: + - `completed` when all participants emit `session.completed` with identical `result_hash` + - `failed` on participant/session failure or hash mismatch + - `expired` when TTL passes + +## Control and Event Flow + +Request intake: + +1. client publishes request to one of: + - `mpc.v1.request.keygen` + - `mpc.v1.request.sign` + - `mpc.v1.request.reshare` +2. coordinator validates request and participant availability. +3. coordinator returns accept/reject response via NATS request-reply. + +Session control: + +1. coordinator publishes `session.start` to each participant control inbox: + - `mpc.v1.peer..control` +2. participants report events on: + - `mpc.v1.session..event` +3. coordinator transitions session via `advance(...)` and emits: + - `key_exchange.begin` + - `mpc.begin` + - `session.abort` (when failing/expiring) + +Result publishing: + +- terminal result is published to: + - `mpc.v1.session..result` + +## Important Internal Functions + +- `HandleRequest(...)`: + parse op-specific request, validate, create session, send `session.start`. +- `HandleSessionEvent(...)`: + process participant event and update participant/session state. +- `advance(...)`: + state machine transition logic from readiness -> key exchange -> active MPC -> completed. +- `failSession(...)`: + terminal failure handling and abort broadcast. +- `expireSession(...)`: + TTL terminal handling and abort broadcast. + +## Notes for Embedders + +- coordinator is currently singleton-oriented. +- snapshots are best-effort persistence, not distributed consensus. +- event signature verification is pluggable but currently permissive by default. +- this package is internal control-plane logic; transport and participant runtimes integrate around it. diff --git a/internal/coordinator/coordinator.go b/internal/coordinator/coordinator.go new file mode 100644 index 00000000..a0485470 --- /dev/null +++ b/internal/coordinator/coordinator.go @@ -0,0 +1,493 @@ +package coordinator + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +type CoordinatorConfig struct { + CoordinatorID string + Signer Signer + EventVerifier SessionEventVerifier + Store *MemorySessionStore + KeyInfoStore *MemoryKeyInfoStore + Presence PresenceView + Controls ControlPublisher + Results ResultPublisher + DefaultSessionTTL time.Duration + Now func() time.Time +} + +type Coordinator struct { + id string + signer Signer + eventVerifier SessionEventVerifier + store *MemorySessionStore + keyInfoStore *MemoryKeyInfoStore + presence PresenceView + controls ControlPublisher + results ResultPublisher + defaultSessionTTL time.Duration + now func() time.Time +} + +func NewCoordinator(cfg CoordinatorConfig) (*Coordinator, error) { + if cfg.CoordinatorID == "" { + return nil, fmt.Errorf("coordinator ID is required") + } + if cfg.Signer == nil { + return nil, fmt.Errorf("signer is required") + } + if cfg.Store == nil { + return nil, fmt.Errorf("session store is required") + } + if cfg.Presence == nil { + return nil, fmt.Errorf("presence view is required") + } + if cfg.Controls == nil { + return nil, fmt.Errorf("control publisher is required") + } + if cfg.Results == nil { + return nil, fmt.Errorf("result publisher is required") + } + if cfg.Now == nil { + cfg.Now = func() time.Time { return time.Now().UTC() } + } + if cfg.DefaultSessionTTL <= 0 { + cfg.DefaultSessionTTL = 120 * time.Second + } + return &Coordinator{ + id: cfg.CoordinatorID, + signer: cfg.Signer, + eventVerifier: cfg.EventVerifier, + store: cfg.Store, + keyInfoStore: cfg.KeyInfoStore, + presence: cfg.Presence, + controls: cfg.Controls, + results: cfg.Results, + defaultSessionTTL: cfg.DefaultSessionTTL, + now: cfg.Now, + }, nil +} + +func (c *Coordinator) HandleRequest(ctx context.Context, op Operation, raw []byte) ([]byte, error) { + if op == OperationReshare { + return reject(ErrorCodeUnsupported, "reshare is unsupported in this runtime version"), nil + } + req, err := parseRequest(raw) + if err != nil { + return reject(ErrorCodeInvalidJSON, "invalid JSON request"), nil + } + if err := c.validateRequest(ctx, op, req); err != nil { + return rejectFromError(err), nil + } + + now := c.now() + sessionID := "sess_" + uuid.NewString() + start := cloneSessionStart(req.SessionStart) + start.SessionID = sessionID + start.Operation = op.ToSDK() + + participants := cloneParticipants(start.Participants) + states := make(map[string]*ParticipantState, len(participants)) + keys := make(map[string][]byte, len(participants)) + for _, participant := range participants { + states[participant.ParticipantID] = &ParticipantState{} + keys[participant.ParticipantID] = append([]byte(nil), participant.IdentityPublicKey...) + } + + session := &Session{ + ID: sessionID, + RequestID: "req_" + uuid.NewString(), + Op: op, + State: SessionCreated, + Start: start, + Participants: participants, + ParticipantState: states, + ExchangeID: "kx_" + uuid.NewString(), + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: now.Add(c.defaultSessionTTL), + ParticipantKeys: keys, + } + if err := c.store.Create(ctx, session); err != nil { + return rejectFromError(err), nil + } + + if err := c.fanOutSessionStart(ctx, session); err != nil { + _ = c.failSession(ctx, session, ErrorCodeInternal, err.Error()) + return reject(ErrorCodeInternal, "failed to publish session start"), nil + } + session.State = SessionWaitingParticipants + session.UpdatedAt = c.now() + if err := c.store.Save(ctx, session); err != nil { + return reject(ErrorCodeInternal, "failed to save session"), nil + } + + resp := sdkprotocol.RequestAccepted{ + Accepted: true, + SessionID: session.ID, + ExpiresAt: session.ExpiresAt.UTC().Format(time.RFC3339Nano), + } + return json.Marshal(resp) +} + +func (c *Coordinator) HandleSessionEvent(ctx context.Context, raw []byte) error { + var event sdkprotocol.SessionEvent + if err := json.Unmarshal(raw, &event); err != nil { + return newCoordinatorError(ErrorCodeInvalidJSON, "invalid JSON session event") + } + if err := sdkprotocol.ValidateSessionEvent(&event); err != nil { + return newCoordinatorError(ErrorCodeValidation, err.Error()) + } + + session, ok := c.store.Get(ctx, event.SessionID) + if !ok { + return newCoordinatorError(ErrorCodeValidation, "unknown session") + } + if session.State.Terminal() { + return nil + } + state, ok := session.ParticipantState[event.ParticipantID] + if !ok { + return newCoordinatorError(ErrorCodeValidation, "event sender is not a session participant") + } + if event.Sequence <= state.LastSequence { + return newCoordinatorError(ErrorCodeValidation, "replayed session event sequence") + } + if c.eventVerifier != nil { + if err := c.eventVerifier.VerifySessionEvent(ctx, session, &event); err != nil { + return err + } + } + state.LastSequence = event.Sequence + + switch { + case event.PeerJoined != nil: + state.Joined = true + case event.PeerReady != nil: + state.Ready = true + case event.PeerKeyExchangeDone != nil: + if session.State != SessionKeyExchange { + return newCoordinatorError(ErrorCodeInvalidTransition, "key exchange done outside key exchange state") + } + state.KeyExchangeDone = true + case event.SessionCompleted != nil: + if session.State != SessionActiveMPC { + return newCoordinatorError(ErrorCodeInvalidTransition, "completion outside active MPC state") + } + state.Completed = true + if event.SessionCompleted.Result == nil { + return c.failSession(ctx, session, ErrorCodeValidation, "missing result payload") + } + state.ResultHash = canonicalResultHash(event.SessionCompleted.Result) + case event.PeerFailed != nil: + state.Failed = true + state.ErrorCode = ErrorCodeParticipantFailed + state.ErrorMessage = firstNonEmpty(event.PeerFailed.Detail, "participant failed") + return c.failSession(ctx, session, state.ErrorCode, state.ErrorMessage) + case event.SessionFailed != nil: + state.Failed = true + state.ErrorCode = ErrorCodeParticipantFailed + state.ErrorMessage = firstNonEmpty(event.SessionFailed.Detail, "session failed") + return c.failSession(ctx, session, state.ErrorCode, state.ErrorMessage) + default: + return newCoordinatorError(ErrorCodeValidation, "unsupported session event type") + } + + session.UpdatedAt = c.now() + if err := c.advance(ctx, session, &event); err != nil { + return err + } + return c.store.Save(ctx, session) +} + +func (c *Coordinator) Tick(ctx context.Context) (int, error) { + now := c.now() + expired := 0 + for _, session := range c.store.ListActive(ctx) { + if !now.Before(session.ExpiresAt) { + if err := c.expireSession(ctx, session); err != nil { + return expired, err + } + expired++ + } + } + return expired, nil +} + +func parseRequest(raw []byte) (*sdkprotocol.ControlMessage, error) { + var msg sdkprotocol.ControlMessage + if err := json.Unmarshal(raw, &msg); err != nil { + return nil, err + } + return &msg, nil +} + +func (c *Coordinator) validateRequest(ctx context.Context, op Operation, msg *sdkprotocol.ControlMessage) error { + if msg == nil || msg.SessionStart == nil { + return newCoordinatorError(ErrorCodeValidation, "session_start is required") + } + start := msg.SessionStart + start.SessionID = "tmp" + start.Operation = op.ToSDK() + if err := sdkprotocol.ValidateSessionStart(start); err != nil { + return newCoordinatorError(ErrorCodeValidation, err.Error()) + } + if start.Operation != op.ToSDK() { + return newCoordinatorError(ErrorCodeValidation, "operation mismatch between subject and payload") + } + for _, participant := range start.Participants { + if string(participant.PartyKey) != participant.ParticipantID { + return newCoordinatorError(ErrorCodeValidation, "party_key must equal participant_id bytes") + } + if !c.presence.IsOnline(ctx, participant.ParticipantID) { + return newCoordinatorError(ErrorCodeUnavailable, "participant is offline") + } + } + return nil +} + +func (c *Coordinator) advance(ctx context.Context, session *Session, event *sdkprotocol.SessionEvent) error { + switch session.State { + case SessionWaitingParticipants: + if allParticipants(session, func(p *ParticipantState) bool { return p.Joined && p.Ready }) { + if err := c.fanOutKeyExchangeBegin(ctx, session); err != nil { + return err + } + session.State = SessionKeyExchange + } + case SessionKeyExchange: + if allParticipants(session, func(p *ParticipantState) bool { return p.KeyExchangeDone }) { + if err := c.fanOutMPCBegin(ctx, session); err != nil { + return err + } + session.State = SessionActiveMPC + } + case SessionActiveMPC: + if allParticipants(session, func(p *ParticipantState) bool { return p.Completed }) { + result, resultHash, err := c.buildCompletedResult(session, event) + if err != nil { + return c.failSession(ctx, session, ErrorCodeResultHashMismatch, err.Error()) + } + now := c.now() + session.State = SessionCompleted + session.ResultHash = resultHash + session.Result = result + session.CompletedAt = &now + session.UpdatedAt = now + if err := c.store.Save(ctx, session); err != nil { + return err + } + return c.results.PublishResult(ctx, session.ID, result) + } + } + return nil +} + +func (c *Coordinator) fanOutSessionStart(ctx context.Context, session *Session) error { + msg := &sdkprotocol.ControlMessage{ + SessionID: session.ID, + Sequence: c.nextControlSequence(session), + CoordinatorID: c.id, + SessionStart: cloneSessionStart(session.Start), + } + if err := SignControl(ctx, c.signer, msg); err != nil { + return err + } + for _, participant := range session.Participants { + if err := c.controls.PublishControl(ctx, participant.ParticipantID, msg); err != nil { + return err + } + } + return nil +} + +func (c *Coordinator) fanOutKeyExchangeBegin(ctx context.Context, session *Session) error { + msg := &sdkprotocol.ControlMessage{ + SessionID: session.ID, + Sequence: c.nextControlSequence(session), + CoordinatorID: c.id, + KeyExchange: &sdkprotocol.KeyExchangeBegin{ExchangeID: session.ExchangeID}, + } + if err := SignControl(ctx, c.signer, msg); err != nil { + return err + } + for _, participant := range session.Participants { + if err := c.controls.PublishControl(ctx, participant.ParticipantID, msg); err != nil { + return err + } + } + return nil +} + +func (c *Coordinator) fanOutMPCBegin(ctx context.Context, session *Session) error { + msg := &sdkprotocol.ControlMessage{ + SessionID: session.ID, + Sequence: c.nextControlSequence(session), + CoordinatorID: c.id, + MPCBegin: &sdkprotocol.MPCBegin{}, + } + if err := SignControl(ctx, c.signer, msg); err != nil { + return err + } + for _, participant := range session.Participants { + if err := c.controls.PublishControl(ctx, participant.ParticipantID, msg); err != nil { + return err + } + } + return nil +} + +func (c *Coordinator) failSession(ctx context.Context, session *Session, code, message string) error { + now := c.now() + session.State = SessionFailed + session.ErrorCode = code + session.ErrorMessage = message + session.UpdatedAt = now + session.CompletedAt = &now + abort := &sdkprotocol.ControlMessage{ + SessionID: session.ID, + Sequence: c.nextControlSequence(session), + CoordinatorID: c.id, + SessionAbort: &sdkprotocol.SessionAbort{Reason: sdkprotocol.FailureReasonAborted, Detail: message}, + } + if err := SignControl(ctx, c.signer, abort); err != nil { + return err + } + for _, participant := range session.Participants { + if err := c.controls.PublishControl(ctx, participant.ParticipantID, abort); err != nil { + return err + } + } + if err := c.store.Save(ctx, session); err != nil { + return err + } + return c.results.PublishResult(ctx, session.ID, nil) +} + +func (c *Coordinator) expireSession(ctx context.Context, session *Session) error { + now := c.now() + session.State = SessionExpired + session.ErrorCode = ErrorCodeTimeout + session.ErrorMessage = "session TTL expired" + session.UpdatedAt = now + session.CompletedAt = &now + abort := &sdkprotocol.ControlMessage{ + SessionID: session.ID, + Sequence: c.nextControlSequence(session), + CoordinatorID: c.id, + SessionAbort: &sdkprotocol.SessionAbort{Reason: sdkprotocol.FailureReasonTimeout, Detail: session.ErrorMessage}, + } + if err := SignControl(ctx, c.signer, abort); err != nil { + return err + } + for _, participant := range session.Participants { + if err := c.controls.PublishControl(ctx, participant.ParticipantID, abort); err != nil { + return err + } + } + if err := c.store.Save(ctx, session); err != nil { + return err + } + return c.results.PublishResult(ctx, session.ID, nil) +} + +func (c *Coordinator) buildCompletedResult(session *Session, event *sdkprotocol.SessionEvent) (*sdkprotocol.Result, string, error) { + var resultHash string + var result *sdkprotocol.Result + for _, state := range session.ParticipantState { + if state.ResultHash == "" { + return nil, "", fmt.Errorf("participant completed without result hash") + } + if resultHash == "" { + resultHash = state.ResultHash + continue + } + if resultHash != state.ResultHash { + return nil, "", fmt.Errorf("participant result hash mismatch") + } + } + if event == nil || event.SessionCompleted == nil || event.SessionCompleted.Result == nil { + return nil, "", fmt.Errorf("missing completion result") + } + in := event.SessionCompleted.Result + switch session.Op { + case OperationKeygen: + if in.KeyShare == nil { + return nil, "", fmt.Errorf("missing key share result") + } + result = &sdkprotocol.Result{ + KeyShare: &sdkprotocol.KeyShareResult{ + KeyID: in.KeyShare.KeyID, + PublicKey: append([]byte(nil), in.KeyShare.PublicKey...), + }, + } + case OperationSign: + if in.Signature == nil { + return nil, "", fmt.Errorf("missing signature result") + } + result = &sdkprotocol.Result{ + Signature: cloneResult(in).Signature, + } + default: + return nil, "", fmt.Errorf("unsupported operation") + } + return result, canonicalResultHash(result), nil +} + +func (c *Coordinator) nextControlSequence(session *Session) uint64 { + session.ControlSeq++ + return session.ControlSeq +} + +func allParticipants(session *Session, predicate func(*ParticipantState) bool) bool { + for _, participant := range session.ParticipantState { + if !predicate(participant) { + return false + } + } + return true +} + +func canonicalResultHash(result *sdkprotocol.Result) string { + if result == nil { + return "" + } + raw, _ := json.Marshal(result) + sum := sha256.Sum256(raw) + return hex.EncodeToString(sum[:]) +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if value != "" { + return value + } + } + return "" +} + +func reject(code, message string) []byte { + raw, _ := json.Marshal(sdkprotocol.RequestRejected{ + Accepted: false, + ErrorCode: code, + ErrorMessage: message, + }) + return raw +} + +func rejectFromError(err error) []byte { + var coordErr *CoordinatorError + if ok := AsCoordinatorError(err, &coordErr); ok { + return reject(coordErr.Code, coordErr.Message) + } + return reject(ErrorCodeInternal, err.Error()) +} diff --git a/internal/coordinator/coordinator_test.go b/internal/coordinator/coordinator_test.go new file mode 100644 index 00000000..92c13c09 --- /dev/null +++ b/internal/coordinator/coordinator_test.go @@ -0,0 +1,235 @@ +package coordinator + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "testing" + "time" + + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +type fakeSigner struct{} + +func (fakeSigner) Sign(context.Context, []byte) ([]byte, error) { return []byte("sig"), nil } + +type fakeControlPublisher struct { + published map[string][]*sdkprotocol.ControlMessage +} + +func (p *fakeControlPublisher) PublishControl(_ context.Context, participantID string, control *sdkprotocol.ControlMessage) error { + if p.published == nil { + p.published = map[string][]*sdkprotocol.ControlMessage{} + } + cloned := *control + p.published[participantID] = append(p.published[participantID], &cloned) + return nil +} + +type fakeResultPublisher struct { + results map[string]*sdkprotocol.Result +} + +func (p *fakeResultPublisher) PublishResult(_ context.Context, sessionID string, result *sdkprotocol.Result) error { + if p.results == nil { + p.results = map[string]*sdkprotocol.Result{} + } + p.results[sessionID] = result + return nil +} + +func TestTopicHelpersMatchRuntimeNamespace(t *testing.T) { + if got := RequestSubject(OperationKeygen); got != "mpc.v1.request.keygen" { + t.Fatalf("RequestSubject() = %q", got) + } + if got := PeerControlSubject("peer-node-01"); got != "mpc.v1.peer.peer-node-01.control" { + t.Fatalf("PeerControlSubject() = %q", got) + } + if got := SessionEventSubject("sess_123"); got != "mpc.v1.session.sess_123.event" { + t.Fatalf("SessionEventSubject() = %q", got) + } + if got := SessionResultSubject("sess_123"); got != "mpc.v1.session.sess_123.result" { + t.Fatalf("SessionResultSubject() = %q", got) + } +} + +func TestHandleRequestAcceptsAndFansOutSessionStart(t *testing.T) { + coord, controls, _, fixtures := newTestCoordinator(t) + markOnline(t, coord.presence, fixtures["p1"].pub, "p1") + markOnline(t, coord.presence, fixtures["p2"].pub, "p2") + + req := &sdkprotocol.ControlMessage{ + SessionStart: newSessionStart(fixtures), + } + rawReply, err := coord.HandleRequest(context.Background(), OperationSign, mustJSON(t, req)) + if err != nil { + t.Fatal(err) + } + var reply sdkprotocol.RequestAccepted + if err := json.Unmarshal(rawReply, &reply); err != nil { + t.Fatal(err) + } + if !reply.Accepted || reply.SessionID == "" { + t.Fatalf("unexpected reply: %+v", reply) + } + if len(controls.published["p1"]) == 0 || controls.published["p1"][0].SessionStart == nil { + t.Fatalf("missing session start fanout") + } +} + +func TestHandleRequestRejectsOfflineParticipant(t *testing.T) { + coord, _, _, fixtures := newTestCoordinator(t) + _ = fixtures + req := &sdkprotocol.ControlMessage{SessionStart: newSessionStart(fixtures)} + rawReply, err := coord.HandleRequest(context.Background(), OperationSign, mustJSON(t, req)) + if err != nil { + t.Fatal(err) + } + var reply sdkprotocol.RequestRejected + if err := json.Unmarshal(rawReply, &reply); err != nil { + t.Fatal(err) + } + if reply.Accepted { + t.Fatalf("expected rejection") + } + if reply.ErrorCode != ErrorCodeUnavailable { + t.Fatalf("error code = %s, want %s", reply.ErrorCode, ErrorCodeUnavailable) + } +} + +func TestLifecycleCompletesSignAndPublishesResult(t *testing.T) { + ctx := context.Background() + coord, _, results, fixtures := newTestCoordinator(t) + markOnline(t, coord.presence, fixtures["p1"].pub, "p1") + markOnline(t, coord.presence, fixtures["p2"].pub, "p2") + + req := &sdkprotocol.ControlMessage{SessionStart: newSessionStart(fixtures)} + rawReply, err := coord.HandleRequest(ctx, OperationSign, mustJSON(t, req)) + if err != nil { + t.Fatal(err) + } + var reply sdkprotocol.RequestAccepted + if err := json.Unmarshal(rawReply, &reply); err != nil { + t.Fatal(err) + } + + signResult := &sdkprotocol.Result{Signature: &sdkprotocol.SignatureResult{KeyID: "k", Signature: []byte("sig")}} + for _, participant := range []string{"p1", "p2"} { + emitSignedEvent(t, coord, reply.SessionID, fixtures, participant, &sdkprotocol.SessionEvent{PeerJoined: &sdkprotocol.PeerJoined{ParticipantID: participant}}) + emitSignedEvent(t, coord, reply.SessionID, fixtures, participant, &sdkprotocol.SessionEvent{PeerReady: &sdkprotocol.PeerReady{ParticipantID: participant}}) + } + for _, participant := range []string{"p1", "p2"} { + emitSignedEvent(t, coord, reply.SessionID, fixtures, participant, &sdkprotocol.SessionEvent{PeerKeyExchangeDone: &sdkprotocol.PeerKeyExchangeDone{ParticipantID: participant}}) + } + for _, participant := range []string{"p1", "p2"} { + emitSignedEvent(t, coord, reply.SessionID, fixtures, participant, &sdkprotocol.SessionEvent{SessionCompleted: &sdkprotocol.SessionCompleted{Result: signResult}}) + } + + result := results.results[reply.SessionID] + if result == nil || result.Signature == nil { + t.Fatalf("missing published sign result") + } +} + +type participantKey struct { + pub ed25519.PublicKey + priv ed25519.PrivateKey +} + +func newTestCoordinator(t *testing.T) (*Coordinator, *fakeControlPublisher, *fakeResultPublisher, map[string]participantKey) { + t.Helper() + fixtures := map[string]participantKey{} + for _, id := range []string{"p1", "p2"} { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + fixtures[id] = participantKey{pub: pub, priv: priv} + } + + store, err := NewMemorySessionStore(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + controls := &fakeControlPublisher{} + results := &fakeResultPublisher{} + coord, err := NewCoordinator(CoordinatorConfig{ + CoordinatorID: "coordinator-1", + Signer: fakeSigner{}, + EventVerifier: Ed25519SessionEventVerifier{}, + Store: store, + Presence: NewInMemoryPresenceView(), + Controls: controls, + Results: results, + DefaultSessionTTL: 120 * time.Second, + Now: func() time.Time { return time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC) }, + }) + if err != nil { + t.Fatal(err) + } + return coord, controls, results, fixtures +} + +func newSessionStart(keys map[string]participantKey) *sdkprotocol.SessionStart { + return &sdkprotocol.SessionStart{ + SessionID: "client-supplied", + Protocol: sdkprotocol.ProtocolTypeECDSA, + Operation: sdkprotocol.OperationTypeSign, + Threshold: 1, + Participants: []*sdkprotocol.SessionParticipant{ + {ParticipantID: "p1", PartyKey: []byte("p1"), IdentityPublicKey: keys["p1"].pub}, + {ParticipantID: "p2", PartyKey: []byte("p2"), IdentityPublicKey: keys["p2"].pub}, + }, + Sign: &sdkprotocol.SignPayload{ + KeyID: "k", + SigningInput: []byte("message"), + }, + } +} + +func emitSignedEvent(t *testing.T, coord *Coordinator, sessionID string, keys map[string]participantKey, participant string, body *sdkprotocol.SessionEvent) { + t.Helper() + event := &sdkprotocol.SessionEvent{ + SessionID: sessionID, + ParticipantID: participant, + Sequence: uint64(time.Now().UnixNano()), + } + event.PeerJoined = body.PeerJoined + event.PeerReady = body.PeerReady + event.PeerKeyExchangeDone = body.PeerKeyExchangeDone + event.SessionCompleted = body.SessionCompleted + event.SessionFailed = body.SessionFailed + payload, err := sdkprotocol.SessionEventSigningBytes(event) + if err != nil { + t.Fatal(err) + } + event.Signature = ed25519.Sign(keys[participant].priv, payload) + if err := coord.HandleSessionEvent(context.Background(), mustJSON(t, event)); err != nil { + t.Fatal(err) + } +} + +func markOnline(t *testing.T, presence PresenceView, _ ed25519.PublicKey, participantID string) { + t.Helper() + err := presence.ApplyPresence(sdkprotocol.PresenceEvent{ + PeerID: participantID, + Status: sdkprotocol.PresenceStatusOnline, + Transport: sdkprotocol.TransportTypeNATS, + ConnectionID: "conn-" + participantID, + LastSeenUnixMs: 1, + }) + if err != nil { + t.Fatal(err) + } +} + +func mustJSON(t *testing.T, v any) []byte { + t.Helper() + raw, err := json.Marshal(v) + if err != nil { + t.Fatal(err) + } + return raw +} diff --git a/internal/coordinator/errors.go b/internal/coordinator/errors.go new file mode 100644 index 00000000..37d1b82b --- /dev/null +++ b/internal/coordinator/errors.go @@ -0,0 +1,37 @@ +package coordinator + +import ( + "errors" + "fmt" +) + +const ( + ErrorCodeInvalidJSON = "INVALID_JSON" + ErrorCodeValidation = "VALIDATION_ERROR" + ErrorCodeUnauthorized = "UNAUTHORIZED" + ErrorCodeConflict = "CONFLICT" + ErrorCodeUnavailable = "UNAVAILABLE" + ErrorCodeInternal = "INTERNAL_ERROR" + ErrorCodeTimeout = "SESSION_TIMEOUT" + ErrorCodeParticipantFailed = "PARTICIPANT_FAILED" + ErrorCodeResultHashMismatch = "RESULT_HASH_MISMATCH" + ErrorCodeInvalidTransition = "INVALID_TRANSITION" + ErrorCodeUnsupported = "UNSUPPORTED_OPERATION" +) + +type CoordinatorError struct { + Code string + Message string +} + +func (e *CoordinatorError) Error() string { + return fmt.Sprintf("%s: %s", e.Code, e.Message) +} + +func newCoordinatorError(code, message string) *CoordinatorError { + return &CoordinatorError{Code: code, Message: message} +} + +func AsCoordinatorError(err error, target **CoordinatorError) bool { + return errors.As(err, target) +} diff --git a/internal/coordinator/keyinfo.go b/internal/coordinator/keyinfo.go new file mode 100644 index 00000000..40efc24b --- /dev/null +++ b/internal/coordinator/keyinfo.go @@ -0,0 +1,56 @@ +package coordinator + +import ( + "context" + "fmt" + "sync" + "time" +) + +type KeyInfo struct { + WalletID string `json:"wallet_id"` + KeyType string `json:"key_type,omitempty"` + Threshold int `json:"threshold"` + Participants []string `json:"participants"` + PublicKey []byte `json:"public_key,omitempty"` + CreatedAt string `json:"created_at"` +} + +type MemoryKeyInfoStore struct { + mu sync.RWMutex + infos map[string]KeyInfo +} + +func NewMemoryKeyInfoStore() *MemoryKeyInfoStore { + return &MemoryKeyInfoStore{infos: make(map[string]KeyInfo)} +} + +func (s *MemoryKeyInfoStore) Save(info KeyInfo) { + s.mu.Lock() + defer s.mu.Unlock() + if info.CreatedAt == "" { + info.CreatedAt = time.Now().UTC().Format(time.RFC3339Nano) + } + s.infos[info.WalletID] = info +} + +func (s *MemoryKeyInfoStore) Get(walletID string) (KeyInfo, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + info, ok := s.infos[walletID] + return info, ok +} + +func RestoreKeyInfoFromSnapshotStore(ctx context.Context, snapshots SnapshotStore, store *MemoryKeyInfoStore) error { + if snapshots == nil || store == nil { + return nil + } + infos, err := snapshots.LoadKeyInfos(ctx) + if err != nil { + return fmt.Errorf("load key info snapshots: %w", err) + } + for _, info := range infos { + store.Save(info) + } + return nil +} diff --git a/internal/coordinator/presence.go b/internal/coordinator/presence.go new file mode 100644 index 00000000..d452a255 --- /dev/null +++ b/internal/coordinator/presence.go @@ -0,0 +1,48 @@ +package coordinator + +import ( + "context" + "sync" + "time" + + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +type PresenceView interface { + IsOnline(ctx context.Context, peerID string) bool + ApplyPresence(event sdkprotocol.PresenceEvent) error +} + +type InMemoryPresenceView struct { + mu sync.RWMutex + peers map[string]sdkprotocol.PresenceEvent +} + +func NewInMemoryPresenceView() *InMemoryPresenceView { + return &InMemoryPresenceView{ + peers: make(map[string]sdkprotocol.PresenceEvent), + } +} + +func (p *InMemoryPresenceView) IsOnline(_ context.Context, peerID string) bool { + p.mu.RLock() + defer p.mu.RUnlock() + event, ok := p.peers[peerID] + if !ok { + return false + } + return event.Status == sdkprotocol.PresenceStatusOnline && event.Transport == sdkprotocol.TransportTypeNATS +} + +func (p *InMemoryPresenceView) ApplyPresence(event sdkprotocol.PresenceEvent) error { + if event.PeerID == "" { + return newCoordinatorError(ErrorCodeValidation, "invalid presence event") + } + if event.LastSeenUnixMs <= 0 { + event.LastSeenUnixMs = time.Now().UTC().UnixMilli() + } + p.mu.Lock() + defer p.mu.Unlock() + p.peers[event.PeerID] = event + return nil +} diff --git a/internal/coordinator/publisher.go b/internal/coordinator/publisher.go new file mode 100644 index 00000000..8de0b6ba --- /dev/null +++ b/internal/coordinator/publisher.go @@ -0,0 +1,56 @@ +package coordinator + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nats-io/nats.go" + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +type ControlPublisher interface { + PublishControl(ctx context.Context, participantID string, control *sdkprotocol.ControlMessage) error +} + +type ResultPublisher interface { + PublishResult(ctx context.Context, sessionID string, result *sdkprotocol.Result) error +} + +type NATSControlPublisher struct { + nc *nats.Conn +} + +func NewNATSControlPublisher(nc *nats.Conn) *NATSControlPublisher { + return &NATSControlPublisher{nc: nc} +} + +func (p *NATSControlPublisher) PublishControl(ctx context.Context, participantID string, control *sdkprotocol.ControlMessage) error { + if err := ctx.Err(); err != nil { + return err + } + raw, err := json.Marshal(control) + if err != nil { + return fmt.Errorf("marshal control: %w", err) + } + return p.nc.Publish(PeerControlSubject(participantID), raw) +} + +type NATSResultPublisher struct { + nc *nats.Conn +} + +func NewNATSResultPublisher(nc *nats.Conn) *NATSResultPublisher { + return &NATSResultPublisher{nc: nc} +} + +func (p *NATSResultPublisher) PublishResult(ctx context.Context, sessionID string, result *sdkprotocol.Result) error { + if err := ctx.Err(); err != nil { + return err + } + raw, err := json.Marshal(result) + if err != nil { + return fmt.Errorf("marshal result: %w", err) + } + return p.nc.Publish(SessionResultSubject(sessionID), raw) +} diff --git a/internal/coordinator/runtime.go b/internal/coordinator/runtime.go new file mode 100644 index 00000000..ae66bf5c --- /dev/null +++ b/internal/coordinator/runtime.go @@ -0,0 +1,72 @@ +package coordinator + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nats-io/nats.go" + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +type NATSRuntime struct { + nc *nats.Conn + coord *Coordinator + presence PresenceView + subs []*nats.Subscription +} + +func NewNATSRuntime(nc *nats.Conn, coord *Coordinator, presence PresenceView) *NATSRuntime { + return &NATSRuntime{nc: nc, coord: coord, presence: presence} +} + +func (r *NATSRuntime) Start(ctx context.Context) error { + for _, op := range []Operation{OperationKeygen, OperationSign, OperationReshare} { + op := op + sub, err := r.nc.Subscribe(RequestSubject(op), func(msg *nats.Msg) { + reply, err := r.coord.HandleRequest(ctx, op, msg.Data) + if err != nil { + reply = reject(ErrorCodeInternal, err.Error()) + } + if msg.Reply != "" { + _ = msg.Respond(reply) + } + }) + if err != nil { + return fmt.Errorf("subscribe request subject %s: %w", RequestSubject(op), err) + } + r.subs = append(r.subs, sub) + } + + eventSub, err := r.nc.Subscribe(AllSessionEventsSubject(), func(msg *nats.Msg) { + _ = r.coord.HandleSessionEvent(ctx, msg.Data) + }) + if err != nil { + return fmt.Errorf("subscribe session events: %w", err) + } + r.subs = append(r.subs, eventSub) + + presenceSub, err := r.nc.Subscribe(AllPresenceSubject(), func(msg *nats.Msg) { + var event sdkprotocol.PresenceEvent + if err := json.Unmarshal(msg.Data, &event); err != nil { + return + } + _ = r.presence.ApplyPresence(event) + }) + if err != nil { + return fmt.Errorf("subscribe presence events: %w", err) + } + r.subs = append(r.subs, presenceSub) + + return r.nc.Flush() +} + +func (r *NATSRuntime) Stop() error { + for _, sub := range r.subs { + if err := sub.Unsubscribe(); err != nil { + return err + } + } + r.subs = nil + return nil +} diff --git a/internal/coordinator/signing.go b/internal/coordinator/signing.go new file mode 100644 index 00000000..6251377a --- /dev/null +++ b/internal/coordinator/signing.go @@ -0,0 +1,78 @@ +package coordinator + +import ( + "context" + "crypto/ed25519" + "encoding/hex" + "fmt" + + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +type Signer interface { + Sign(ctx context.Context, data []byte) ([]byte, error) +} + +type SessionEventVerifier interface { + VerifySessionEvent(ctx context.Context, session *Session, event *sdkprotocol.SessionEvent) error +} + +type Ed25519Signer struct { + privateKey ed25519.PrivateKey +} + +func NewEd25519SignerFromHex(privateKeyHex string) (*Ed25519Signer, error) { + raw, err := hex.DecodeString(privateKeyHex) + if err != nil { + return nil, fmt.Errorf("decode coordinator private key hex: %w", err) + } + switch len(raw) { + case ed25519.PrivateKeySize: + return &Ed25519Signer{privateKey: ed25519.PrivateKey(raw)}, nil + case ed25519.SeedSize: + return &Ed25519Signer{privateKey: ed25519.NewKeyFromSeed(raw)}, nil + default: + return nil, fmt.Errorf("invalid Ed25519 private key length %d", len(raw)) + } +} + +func (s *Ed25519Signer) Sign(_ context.Context, data []byte) ([]byte, error) { + if len(s.privateKey) != ed25519.PrivateKeySize { + return nil, fmt.Errorf("invalid Ed25519 private key") + } + return ed25519.Sign(s.privateKey, data), nil +} + +type Ed25519SessionEventVerifier struct{} + +func (Ed25519SessionEventVerifier) VerifySessionEvent(_ context.Context, session *Session, event *sdkprotocol.SessionEvent) error { + if session == nil || event == nil { + return newCoordinatorError(ErrorCodeValidation, "invalid session event verification input") + } + pubKey, ok := session.ParticipantKeys[event.ParticipantID] + if !ok || len(pubKey) == 0 { + return newCoordinatorError(ErrorCodeUnauthorized, "unknown participant public key") + } + payload, err := sdkprotocol.SessionEventSigningBytes(event) + if err != nil { + return newCoordinatorError(ErrorCodeValidation, err.Error()) + } + if !ed25519.Verify(ed25519.PublicKey(pubKey), payload, event.Signature) { + return newCoordinatorError(ErrorCodeUnauthorized, "invalid participant event signature") + } + return nil +} + +func SignControl(ctx context.Context, signer Signer, control *sdkprotocol.ControlMessage) error { + control.Signature = nil + bytes, err := sdkprotocol.ControlSigningBytes(control) + if err != nil { + return err + } + sig, err := signer.Sign(ctx, bytes) + if err != nil { + return fmt.Errorf("sign control: %w", err) + } + control.Signature = sig + return nil +} diff --git a/internal/coordinator/store.go b/internal/coordinator/store.go new file mode 100644 index 00000000..9d7a0cd2 --- /dev/null +++ b/internal/coordinator/store.go @@ -0,0 +1,328 @@ +package coordinator + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +type SnapshotStore interface { + SaveSession(ctx context.Context, session *Session) error + LoadSessions(ctx context.Context) ([]*Session, error) + SaveKeyInfo(ctx context.Context, info KeyInfo) error + LoadKeyInfos(ctx context.Context) ([]KeyInfo, error) +} + +type AtomicFileSnapshotStore struct { + dir string +} + +func NewAtomicFileSnapshotStore(dir string) *AtomicFileSnapshotStore { + return &AtomicFileSnapshotStore{dir: dir} +} + +func (s *AtomicFileSnapshotStore) SaveSession(ctx context.Context, session *Session) error { + if err := ctx.Err(); err != nil { + return err + } + if err := os.MkdirAll(s.dir, 0o700); err != nil { + return fmt.Errorf("create snapshot dir: %w", err) + } + path := filepath.Join(s.dir, "session_"+safeFilePart(session.ID)+".json") + return writeJSONAtomic(path, session) +} + +func (s *AtomicFileSnapshotStore) LoadSessions(ctx context.Context) ([]*Session, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + entries, err := os.ReadDir(s.dir) + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("read snapshot dir: %w", err) + } + sessions := make([]*Session, 0) + for _, entry := range entries { + if entry.IsDir() || !strings.HasPrefix(entry.Name(), "session_") || !strings.HasSuffix(entry.Name(), ".json") { + continue + } + raw, err := os.ReadFile(filepath.Join(s.dir, entry.Name())) + if err != nil { + return nil, fmt.Errorf("read session snapshot %s: %w", entry.Name(), err) + } + var session Session + if err := json.Unmarshal(raw, &session); err != nil { + return nil, fmt.Errorf("parse session snapshot %s: %w", entry.Name(), err) + } + sessions = append(sessions, &session) + } + return sessions, nil +} + +func (s *AtomicFileSnapshotStore) SaveKeyInfo(ctx context.Context, info KeyInfo) error { + if err := ctx.Err(); err != nil { + return err + } + if err := os.MkdirAll(s.dir, 0o700); err != nil { + return fmt.Errorf("create snapshot dir: %w", err) + } + path := filepath.Join(s.dir, "keyinfo_"+safeFilePart(info.WalletID)+".json") + return writeJSONAtomic(path, info) +} + +func (s *AtomicFileSnapshotStore) LoadKeyInfos(ctx context.Context) ([]KeyInfo, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + entries, err := os.ReadDir(s.dir) + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("read snapshot dir: %w", err) + } + infos := make([]KeyInfo, 0) + for _, entry := range entries { + if entry.IsDir() || !strings.HasPrefix(entry.Name(), "keyinfo_") || !strings.HasSuffix(entry.Name(), ".json") { + continue + } + raw, err := os.ReadFile(filepath.Join(s.dir, entry.Name())) + if err != nil { + return nil, fmt.Errorf("read key info snapshot %s: %w", entry.Name(), err) + } + var info KeyInfo + if err := json.Unmarshal(raw, &info); err != nil { + return nil, fmt.Errorf("parse key info snapshot %s: %w", entry.Name(), err) + } + infos = append(infos, info) + } + return infos, nil +} + +func writeJSONAtomic(path string, value any) error { + raw, err := json.MarshalIndent(value, "", " ") + if err != nil { + return fmt.Errorf("marshal snapshot: %w", err) + } + tmp := path + ".tmp" + if err := os.WriteFile(tmp, raw, 0o600); err != nil { + return fmt.Errorf("write snapshot temp file: %w", err) + } + if err := os.Rename(tmp, path); err != nil { + return fmt.Errorf("replace snapshot: %w", err) + } + return nil +} + +func safeFilePart(value string) string { + replacer := strings.NewReplacer("/", "_", "\\", "_", ":", "_", "..", "_") + return replacer.Replace(value) +} + +type MemorySessionStore struct { + mu sync.RWMutex + sessions map[string]*Session + requests map[string]string + snapshots SnapshotStore +} + +func NewMemorySessionStore(ctx context.Context, snapshots SnapshotStore) (*MemorySessionStore, error) { + store := &MemorySessionStore{ + sessions: make(map[string]*Session), + requests: make(map[string]string), + snapshots: snapshots, + } + if snapshots == nil { + return store, nil + } + sessions, err := snapshots.LoadSessions(ctx) + if err != nil { + return nil, err + } + for _, session := range sessions { + cloned := cloneSession(session) + store.sessions[cloned.ID] = cloned + if cloned.RequestID != "" { + store.requests[cloned.RequestID] = cloned.ID + } + } + return store, nil +} + +func (s *MemorySessionStore) Create(ctx context.Context, session *Session) error { + s.mu.Lock() + if _, ok := s.sessions[session.ID]; ok { + s.mu.Unlock() + return newCoordinatorError(ErrorCodeConflict, "session already exists") + } + if existingID, ok := s.requests[session.RequestID]; ok && existingID != "" { + s.mu.Unlock() + return newCoordinatorError(ErrorCodeConflict, "request already accepted") + } + s.sessions[session.ID] = cloneSession(session) + s.requests[session.RequestID] = session.ID + s.mu.Unlock() + return s.snapshot(ctx, session) +} + +func (s *MemorySessionStore) Save(ctx context.Context, session *Session) error { + s.mu.Lock() + if _, ok := s.sessions[session.ID]; !ok { + s.mu.Unlock() + return newCoordinatorError(ErrorCodeValidation, "unknown session") + } + s.sessions[session.ID] = cloneSession(session) + s.mu.Unlock() + return s.snapshot(ctx, session) +} + +func (s *MemorySessionStore) Get(_ context.Context, sessionID string) (*Session, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + session, ok := s.sessions[sessionID] + if !ok { + return nil, false + } + return cloneSession(session), true +} + +func (s *MemorySessionStore) GetByRequestID(_ context.Context, requestID string) (*Session, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + sessionID, ok := s.requests[requestID] + if !ok { + return nil, false + } + session, ok := s.sessions[sessionID] + if !ok { + return nil, false + } + return cloneSession(session), true +} + +func (s *MemorySessionStore) ListActive(_ context.Context) []*Session { + s.mu.RLock() + defer s.mu.RUnlock() + sessions := make([]*Session, 0, len(s.sessions)) + for _, session := range s.sessions { + if !session.State.Terminal() { + sessions = append(sessions, cloneSession(session)) + } + } + return sessions +} + +func (s *MemorySessionStore) snapshot(ctx context.Context, session *Session) error { + if s.snapshots == nil { + return nil + } + return s.snapshots.SaveSession(ctx, session) +} + +func cloneSession(session *Session) *Session { + if session == nil { + return nil + } + cloned := *session + cloned.Start = cloneSessionStart(session.Start) + cloned.Participants = cloneParticipants(session.Participants) + cloned.ParticipantKeys = cloneKeyMap(session.ParticipantKeys) + cloned.Result = cloneResult(session.Result) + cloned.ParticipantState = make(map[string]*ParticipantState, len(session.ParticipantState)) + for peerID, state := range session.ParticipantState { + stateCopy := *state + cloned.ParticipantState[peerID] = &stateCopy + } + if session.CompletedAt != nil { + completedAt := *session.CompletedAt + cloned.CompletedAt = &completedAt + } + return &cloned +} + +func cloneSessionStart(start *sdkprotocol.SessionStart) *sdkprotocol.SessionStart { + if start == nil { + return nil + } + cloned := *start + cloned.Participants = cloneParticipants(start.Participants) + if start.Keygen != nil { + keygen := *start.Keygen + cloned.Keygen = &keygen + } + if start.Sign != nil { + sign := *start.Sign + sign.SigningInput = append([]byte(nil), start.Sign.SigningInput...) + if start.Sign.Derivation != nil { + derivation := *start.Sign.Derivation + derivation.Path = append([]uint32(nil), start.Sign.Derivation.Path...) + derivation.Delta = append([]byte(nil), start.Sign.Derivation.Delta...) + sign.Derivation = &derivation + } + cloned.Sign = &sign + } + if start.Reshare != nil { + reshare := *start.Reshare + reshare.NewParticipants = cloneParticipants(start.Reshare.NewParticipants) + cloned.Reshare = &reshare + } + return &cloned +} + +func cloneParticipants(participants []*sdkprotocol.SessionParticipant) []*sdkprotocol.SessionParticipant { + out := make([]*sdkprotocol.SessionParticipant, 0, len(participants)) + for _, participant := range participants { + if participant == nil { + continue + } + cloned := *participant + cloned.PartyKey = append([]byte(nil), participant.PartyKey...) + cloned.IdentityPublicKey = append([]byte(nil), participant.IdentityPublicKey...) + out = append(out, &cloned) + } + return out +} + +func cloneResult(result *sdkprotocol.Result) *sdkprotocol.Result { + if result == nil { + return nil + } + cloned := *result + if result.KeyShare != nil { + keyShare := *result.KeyShare + keyShare.ShareBlob = append([]byte(nil), result.KeyShare.ShareBlob...) + keyShare.PublicKey = append([]byte(nil), result.KeyShare.PublicKey...) + cloned.KeyShare = &keyShare + } + if result.Signature != nil { + signature := *result.Signature + signature.Signature = append([]byte(nil), result.Signature.Signature...) + signature.SignatureRecovery = append([]byte(nil), result.Signature.SignatureRecovery...) + signature.R = append([]byte(nil), result.Signature.R...) + signature.S = append([]byte(nil), result.Signature.S...) + signature.SignedInput = append([]byte(nil), result.Signature.SignedInput...) + signature.PublicKey = append([]byte(nil), result.Signature.PublicKey...) + cloned.Signature = &signature + } + return &cloned +} + +func cloneKeyMap(src map[string][]byte) map[string][]byte { + if len(src) == 0 { + return nil + } + out := make(map[string][]byte, len(src)) + for key, value := range src { + out[key] = append([]byte(nil), value...) + } + return out +} diff --git a/internal/coordinator/topics.go b/internal/coordinator/topics.go new file mode 100644 index 00000000..37cac48d --- /dev/null +++ b/internal/coordinator/topics.go @@ -0,0 +1,33 @@ +package coordinator + +import "fmt" + +const TopicPrefix = "mpc.v1" + +func RequestSubject(op Operation) string { + return fmt.Sprintf("%s.request.%s", TopicPrefix, op) +} + +func PeerControlSubject(peerID string) string { + return fmt.Sprintf("%s.peer.%s.control", TopicPrefix, peerID) +} + +func PeerPresenceSubject(peerID string) string { + return fmt.Sprintf("%s.peer.%s.presence", TopicPrefix, peerID) +} + +func SessionEventSubject(sessionID string) string { + return fmt.Sprintf("%s.session.%s.event", TopicPrefix, sessionID) +} + +func SessionResultSubject(sessionID string) string { + return fmt.Sprintf("%s.session.%s.result", TopicPrefix, sessionID) +} + +func AllPresenceSubject() string { + return fmt.Sprintf("%s.peer.*.presence", TopicPrefix) +} + +func AllSessionEventsSubject() string { + return fmt.Sprintf("%s.session.*.event", TopicPrefix) +} diff --git a/internal/coordinator/types.go b/internal/coordinator/types.go new file mode 100644 index 00000000..f85eb423 --- /dev/null +++ b/internal/coordinator/types.go @@ -0,0 +1,81 @@ +package coordinator + +import ( + "time" + + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +type Operation string + +const ( + OperationKeygen Operation = "keygen" + OperationSign Operation = "sign" + OperationReshare Operation = "reshare" +) + +func (o Operation) Valid() bool { + return o == OperationKeygen || o == OperationSign || o == OperationReshare +} + +func (o Operation) ToSDK() sdkprotocol.OperationType { + switch o { + case OperationKeygen: + return sdkprotocol.OperationTypeKeygen + case OperationSign: + return sdkprotocol.OperationTypeSign + case OperationReshare: + return sdkprotocol.OperationTypeReshare + default: + return sdkprotocol.OperationTypeUnspecified + } +} + +type SessionState string + +const ( + SessionCreated SessionState = "created" + SessionWaitingParticipants SessionState = "waiting_participants" + SessionKeyExchange SessionState = "key_exchange" + SessionActiveMPC SessionState = "active_mpc" + SessionCompleted SessionState = "completed" + SessionFailed SessionState = "failed" + SessionExpired SessionState = "expired" +) + +func (s SessionState) Terminal() bool { + return s == SessionCompleted || s == SessionFailed || s == SessionExpired +} + +type ParticipantState struct { + Joined bool `json:"joined"` + Ready bool `json:"ready"` + KeyExchangeDone bool `json:"key_exchange_done"` + Completed bool `json:"completed"` + Failed bool `json:"failed"` + LastSequence uint64 `json:"last_sequence"` + ResultHash string `json:"result_hash,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` +} + +type Session struct { + ID string `json:"id"` + RequestID string `json:"request_id"` + Op Operation `json:"op"` + State SessionState `json:"state"` + Start *sdkprotocol.SessionStart `json:"start"` + Participants []*sdkprotocol.SessionParticipant `json:"participants"` + ParticipantState map[string]*ParticipantState `json:"participant_state"` + ExchangeID string `json:"exchange_id,omitempty"` + ResultHash string `json:"result_hash,omitempty"` + Result *sdkprotocol.Result `json:"result,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + ExpiresAt time.Time `json:"expires_at"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + ControlSeq uint64 `json:"control_seq"` + ParticipantKeys map[string][]byte `json:"participant_keys"` +} From 11b11f797d57794feaf6ee32fc4620b2ff914508 Mon Sep 17 00:00:00 2001 From: vietddude Date: Thu, 16 Apr 2026 17:51:30 +0700 Subject: [PATCH 02/23] Add initial implementation of cosigner runtime with configuration, identity management, and session handling. Introduce badger storage for session artifacts and implement NATS messaging topics for peer communication. --- cmd/mpcium-cosigner/main.go | 97 +++++++++++ internal/cosigner/config.go | 15 ++ internal/cosigner/identity.go | 40 +++++ internal/cosigner/runtime.go | 297 ++++++++++++++++++++++++++++++++++ internal/cosigner/storage.go | 97 +++++++++++ internal/cosigner/topics.go | 25 +++ 6 files changed, 571 insertions(+) create mode 100644 cmd/mpcium-cosigner/main.go create mode 100644 internal/cosigner/config.go create mode 100644 internal/cosigner/identity.go create mode 100644 internal/cosigner/runtime.go create mode 100644 internal/cosigner/storage.go create mode 100644 internal/cosigner/topics.go diff --git a/cmd/mpcium-cosigner/main.go b/cmd/mpcium-cosigner/main.go new file mode 100644 index 00000000..947e0251 --- /dev/null +++ b/cmd/mpcium-cosigner/main.go @@ -0,0 +1,97 @@ +package main + +import ( + "context" + "encoding/hex" + "flag" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/fystack/mpcium/internal/cosigner" +) + +func main() { + if err := run(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func run() error { + nodeID := flag.String("node-id", envDefault("NODE_ID", ""), "local participant ID") + natsURL := flag.String("nats-url", envDefault("NATS_URL", "nats://127.0.0.1:4222"), "NATS server URL") + coordinatorID := flag.String("coordinator-id", envDefault("COORDINATOR_ID", ""), "coordinator ID") + coordinatorPubHex := flag.String("coordinator-public-key-hex", envDefault("COORDINATOR_PUBLIC_KEY_HEX", ""), "coordinator Ed25519 public key hex") + privateKeyHex := flag.String("identity-private-key-hex", envDefault("IDENTITY_PRIVATE_KEY_HEX", ""), "node Ed25519 private key hex") + dataDir := flag.String("data-dir", envDefault("NODE_V1_DATA_DIR", "node-v1-data"), "node-v1 badger data directory") + maxActive := flag.Int("max-active-sessions", envIntDefault("NODE_V1_MAX_ACTIVE_SESSIONS", 64), "maximum concurrent active sessions") + presenceInterval := flag.Duration("presence-interval", envDurationDefault("NODE_V1_PRESENCE_INTERVAL", 5*time.Second), "presence heartbeat interval") + tickInterval := flag.Duration("tick-interval", envDurationDefault("NODE_V1_TICK_INTERVAL", 100*time.Millisecond), "participant tick interval") + flag.Parse() + + if *nodeID == "" || *coordinatorID == "" || *coordinatorPubHex == "" || *privateKeyHex == "" { + return fmt.Errorf("node-id, coordinator-id, coordinator-public-key-hex, and identity-private-key-hex are required") + } + coordinatorKey, err := hex.DecodeString(*coordinatorPubHex) + if err != nil { + return fmt.Errorf("decode coordinator public key: %w", err) + } + privateKey, err := hex.DecodeString(*privateKeyHex) + if err != nil { + return fmt.Errorf("decode identity private key: %w", err) + } + + runtime, err := cosigner.NewRuntime(cosigner.Config{ + NodeID: *nodeID, + NATSURL: *natsURL, + CoordinatorID: *coordinatorID, + CoordinatorPublicKey: coordinatorKey, + IdentityPrivateKey: privateKey, + DataDir: *dataDir, + MaxActiveSessions: *maxActive, + PresenceInterval: *presenceInterval, + TickInterval: *tickInterval, + }) + if err != nil { + return err + } + defer runtime.Close() + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + return runtime.Run(ctx) +} + +func envDefault(name, fallback string) string { + if value := os.Getenv(name); value != "" { + return value + } + return fallback +} + +func envDurationDefault(name string, fallback time.Duration) time.Duration { + value := os.Getenv(name) + if value == "" { + return fallback + } + parsed, err := time.ParseDuration(value) + if err != nil { + return fallback + } + return parsed +} + +func envIntDefault(name string, fallback int) int { + value := os.Getenv(name) + if value == "" { + return fallback + } + var parsed int + if _, err := fmt.Sscanf(value, "%d", &parsed); err != nil { + return fallback + } + return parsed +} diff --git a/internal/cosigner/config.go b/internal/cosigner/config.go new file mode 100644 index 00000000..e4353a94 --- /dev/null +++ b/internal/cosigner/config.go @@ -0,0 +1,15 @@ +package cosigner + +import "time" + +type Config struct { + NodeID string + NATSURL string + CoordinatorID string + CoordinatorPublicKey []byte + IdentityPrivateKey []byte + DataDir string + MaxActiveSessions int + PresenceInterval time.Duration + TickInterval time.Duration +} diff --git a/internal/cosigner/identity.go b/internal/cosigner/identity.go new file mode 100644 index 00000000..a196a293 --- /dev/null +++ b/internal/cosigner/identity.go @@ -0,0 +1,40 @@ +package cosigner + +import ( + "crypto/ed25519" + "fmt" +) + +type localIdentity struct { + participantID string + publicKey ed25519.PublicKey + privateKey ed25519.PrivateKey +} + +func (i *localIdentity) ParticipantID() string { return i.participantID } +func (i *localIdentity) PublicKey() ed25519.PublicKey { + return i.publicKey +} +func (i *localIdentity) Sign(message []byte) ([]byte, error) { + return ed25519.Sign(i.privateKey, message), nil +} + +type peerLookup struct{ keys map[string]ed25519.PublicKey } + +func (l *peerLookup) LookupParticipant(participantID string) (ed25519.PublicKey, error) { + key, ok := l.keys[participantID] + if !ok { + return nil, fmt.Errorf("peer %s not found", participantID) + } + return key, nil +} + +type coordinatorLookup struct{ keys map[string]ed25519.PublicKey } + +func (l *coordinatorLookup) LookupCoordinator(coordinatorID string) (ed25519.PublicKey, error) { + key, ok := l.keys[coordinatorID] + if !ok { + return nil, fmt.Errorf("coordinator %s not found", coordinatorID) + } + return key, nil +} diff --git a/internal/cosigner/runtime.go b/internal/cosigner/runtime.go new file mode 100644 index 00000000..29ec8300 --- /dev/null +++ b/internal/cosigner/runtime.go @@ -0,0 +1,297 @@ +package cosigner + +import ( + "context" + "crypto/ed25519" + "encoding/json" + "errors" + "fmt" + "sync" + "time" + + "github.com/nats-io/nats.go" + "github.com/vietddude/mpcium-sdk/participant" + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +type Runtime struct { + cfg Config + nc *nats.Conn + stores *badgerStores + identity *localIdentity + coordLookup *coordinatorLookup + sessionsMu sync.RWMutex + sessions map[string]*participant.ParticipantSession + subs []*nats.Subscription +} + +func NewRuntime(cfg Config) (*Runtime, error) { + if cfg.NodeID == "" { + return nil, errors.New("node_id is required") + } + if cfg.NATSURL == "" { + return nil, errors.New("nats_url is required") + } + if cfg.CoordinatorID == "" || len(cfg.CoordinatorPublicKey) != ed25519.PublicKeySize { + return nil, errors.New("valid coordinator key is required") + } + if len(cfg.IdentityPrivateKey) != ed25519.PrivateKeySize { + return nil, errors.New("valid identity private key is required") + } + if cfg.MaxActiveSessions <= 0 { + cfg.MaxActiveSessions = 64 + } + if cfg.PresenceInterval <= 0 { + cfg.PresenceInterval = 5 * time.Second + } + if cfg.TickInterval <= 0 { + cfg.TickInterval = 100 * time.Millisecond + } + + nc, err := nats.Connect(cfg.NATSURL) + if err != nil { + return nil, fmt.Errorf("connect nats: %w", err) + } + stores, err := newBadgerStores(cfg.DataDir) + if err != nil { + nc.Close() + return nil, err + } + private := ed25519.PrivateKey(cfg.IdentityPrivateKey) + public := private.Public().(ed25519.PublicKey) + return &Runtime{ + cfg: cfg, + nc: nc, + stores: stores, + identity: &localIdentity{participantID: cfg.NodeID, publicKey: public, privateKey: private}, + coordLookup: &coordinatorLookup{keys: map[string]ed25519.PublicKey{ + cfg.CoordinatorID: append([]byte(nil), cfg.CoordinatorPublicKey...), + }}, + sessions: map[string]*participant.ParticipantSession{}, + }, nil +} + +func (r *Runtime) Close() error { + for _, sub := range r.subs { + _ = sub.Unsubscribe() + } + if r.nc != nil { + r.nc.Close() + } + if r.stores != nil { + return r.stores.Close() + } + return nil +} + +func (r *Runtime) Run(ctx context.Context) error { + if err := r.subscribe(); err != nil { + return err + } + if err := r.publishPresence(sdkprotocol.PresenceStatusOnline); err != nil { + return err + } + + tick := time.NewTicker(r.cfg.TickInterval) + defer tick.Stop() + presence := time.NewTicker(r.cfg.PresenceInterval) + defer presence.Stop() + for { + select { + case <-ctx.Done(): + _ = r.publishPresence(sdkprotocol.PresenceStatusOffline) + return nil + case <-tick.C: + if err := r.tickSessions(); err != nil { + return err + } + case <-presence.C: + if err := r.publishPresence(sdkprotocol.PresenceStatusOnline); err != nil { + return err + } + } + } +} + +func (r *Runtime) subscribe() error { + controlSub, err := r.nc.Subscribe(controlSubject(r.cfg.NodeID), func(msg *nats.Msg) { + _ = r.handleControl(msg.Data) + }) + if err != nil { + return err + } + r.subs = append(r.subs, controlSub) + + p2pSub, err := r.nc.Subscribe(p2pWildcardSubject(r.cfg.NodeID), func(msg *nats.Msg) { + _ = r.handlePeer(msg.Data) + }) + if err != nil { + return err + } + r.subs = append(r.subs, p2pSub) + + return r.nc.Flush() +} + +func (r *Runtime) handleControl(raw []byte) error { + var msg sdkprotocol.ControlMessage + if err := json.Unmarshal(raw, &msg); err != nil { + return err + } + if err := sdkprotocol.ValidateControlMessage(&msg); err != nil { + return err + } + + if msg.SessionStart != nil { + return r.startSession(&msg) + } + session := r.getSession(msg.SessionID) + if session == nil { + return fmt.Errorf("unknown session %s", msg.SessionID) + } + effects, err := session.HandleControl(&msg) + if err != nil { + return err + } + return r.publishEffects(effects) +} + +func (r *Runtime) startSession(msg *sdkprotocol.ControlMessage) error { + if len(r.sessions) >= r.cfg.MaxActiveSessions { + return errors.New("max active sessions reached") + } + if err := r.verifyControlSignature(msg); err != nil { + return err + } + peerKeys := make(map[string]ed25519.PublicKey, len(msg.SessionStart.Participants)) + for _, participantDef := range msg.SessionStart.Participants { + if participantDef.ParticipantID == r.cfg.NodeID { + continue + } + peerKeys[participantDef.ParticipantID] = append([]byte(nil), participantDef.IdentityPublicKey...) + } + sess, err := participant.New(participant.Config{ + Start: msg.SessionStart, + LocalParticipantID: r.cfg.NodeID, + Identity: r.identity, + Peers: &peerLookup{keys: peerKeys}, + Coordinator: r.coordLookup, + Preparams: r.stores, + Shares: r.stores, + SessionArtifacts: r.stores, + }) + if err != nil { + return err + } + r.sessionsMu.Lock() + r.sessions[msg.SessionID] = sess + r.sessionsMu.Unlock() + + effects, err := sess.Start() + if err != nil { + return err + } + return r.publishEffects(effects) +} + +func (r *Runtime) handlePeer(raw []byte) error { + var msg sdkprotocol.PeerMessage + if err := json.Unmarshal(raw, &msg); err != nil { + return err + } + session := r.getSession(msg.SessionID) + if session == nil { + return fmt.Errorf("unknown session %s", msg.SessionID) + } + effects, err := session.HandlePeer(&msg) + if err != nil { + return err + } + return r.publishEffects(effects) +} + +func (r *Runtime) tickSessions() error { + r.sessionsMu.RLock() + ids := make([]string, 0, len(r.sessions)) + for id := range r.sessions { + ids = append(ids, id) + } + r.sessionsMu.RUnlock() + for _, id := range ids { + session := r.getSession(id) + if session == nil { + continue + } + effects, err := session.Tick(time.Now()) + if err != nil { + return err + } + if err := r.publishEffects(effects); err != nil { + return err + } + } + return nil +} + +func (r *Runtime) publishEffects(effects participant.Effects) error { + for _, peerMsg := range effects.PeerMessages { + raw, err := json.Marshal(peerMsg) + if err != nil { + return err + } + if err := r.nc.Publish(p2pSubject(peerMsg.ToParticipantID, peerMsg.SessionID), raw); err != nil { + return err + } + } + for _, event := range effects.SessionEvents { + raw, err := json.Marshal(event) + if err != nil { + return err + } + if err := r.nc.Publish(sessionEventSubject(event.SessionID), raw); err != nil { + return err + } + } + if effects.Cleanup != nil && effects.Cleanup.DropArtifacts { + _ = r.stores.DeleteSessionArtifacts(effects.Cleanup.SessionID) + } + return nil +} + +func (r *Runtime) publishPresence(status sdkprotocol.PresenceStatus) error { + event := sdkprotocol.PresenceEvent{ + PeerID: r.cfg.NodeID, + Status: status, + Transport: sdkprotocol.TransportTypeNATS, + LastSeenUnixMs: time.Now().UTC().UnixMilli(), + } + if status == sdkprotocol.PresenceStatusOnline { + event.ConnectionID = "nats:" + r.cfg.NodeID + } + raw, err := json.Marshal(event) + if err != nil { + return err + } + return r.nc.Publish(presenceSubject(r.cfg.NodeID), raw) +} + +func (r *Runtime) getSession(sessionID string) *participant.ParticipantSession { + r.sessionsMu.RLock() + defer r.sessionsMu.RUnlock() + return r.sessions[sessionID] +} + +func (r *Runtime) verifyControlSignature(msg *sdkprotocol.ControlMessage) error { + pub, err := r.coordLookup.LookupCoordinator(msg.CoordinatorID) + if err != nil { + return err + } + payload, err := sdkprotocol.ControlSigningBytes(msg) + if err != nil { + return err + } + if !ed25519.Verify(pub, payload, msg.Signature) { + return errors.New("invalid control signature") + } + return nil +} diff --git a/internal/cosigner/storage.go b/internal/cosigner/storage.go new file mode 100644 index 00000000..9e206f5e --- /dev/null +++ b/internal/cosigner/storage.go @@ -0,0 +1,97 @@ +package cosigner + +import ( + "fmt" + "path/filepath" + + "github.com/dgraph-io/badger/v4" + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +type badgerStores struct { + db *badger.DB +} + +func newBadgerStores(dataDir string) (*badgerStores, error) { + opts := badger.DefaultOptions(filepath.Join(dataDir, "node-v1-badger")) + opts.Logger = nil + db, err := badger.Open(opts) + if err != nil { + return nil, err + } + return &badgerStores{db: db}, nil +} + +func (s *badgerStores) Close() error { + if s == nil || s.db == nil { + return nil + } + return s.db.Close() +} + +func (s *badgerStores) LoadPreparams(protocolType sdkprotocol.ProtocolType, keyID string) ([]byte, error) { + return s.load(keyPreparams(protocolType, keyID)) +} + +func (s *badgerStores) SavePreparams(protocolType sdkprotocol.ProtocolType, keyID string, preparams []byte) error { + return s.save(keyPreparams(protocolType, keyID), preparams) +} + +func (s *badgerStores) LoadShare(protocolType sdkprotocol.ProtocolType, keyID string) ([]byte, error) { + return s.load(keyShare(protocolType, keyID)) +} + +func (s *badgerStores) SaveShare(protocolType sdkprotocol.ProtocolType, keyID string, share []byte) error { + return s.save(keyShare(protocolType, keyID), share) +} + +func (s *badgerStores) LoadSessionArtifacts(sessionID string) ([]byte, error) { + return s.load(keyArtifact(sessionID)) +} + +func (s *badgerStores) SaveSessionArtifacts(sessionID string, artifact []byte) error { + return s.save(keyArtifact(sessionID), artifact) +} + +func (s *badgerStores) DeleteSessionArtifacts(sessionID string) error { + return s.db.Update(func(txn *badger.Txn) error { + return txn.Delete([]byte(keyArtifact(sessionID))) + }) +} + +func (s *badgerStores) load(key string) ([]byte, error) { + var value []byte + err := s.db.View(func(txn *badger.Txn) error { + item, err := txn.Get([]byte(key)) + if err != nil { + if err == badger.ErrKeyNotFound { + value = nil + return nil + } + return err + } + return item.Value(func(v []byte) error { + value = append([]byte(nil), v...) + return nil + }) + }) + return value, err +} + +func (s *badgerStores) save(key string, value []byte) error { + return s.db.Update(func(txn *badger.Txn) error { + return txn.Set([]byte(key), append([]byte(nil), value...)) + }) +} + +func keyPreparams(protocolType sdkprotocol.ProtocolType, keyID string) string { + return fmt.Sprintf("preparams:%s:%s", protocolType, keyID) +} + +func keyShare(protocolType sdkprotocol.ProtocolType, keyID string) string { + return fmt.Sprintf("shares:%s:%s", protocolType, keyID) +} + +func keyArtifact(sessionID string) string { + return "artifacts:" + sessionID +} diff --git a/internal/cosigner/topics.go b/internal/cosigner/topics.go new file mode 100644 index 00000000..acc25f06 --- /dev/null +++ b/internal/cosigner/topics.go @@ -0,0 +1,25 @@ +package cosigner + +import "fmt" + +const topicPrefix = "mpc.v1" + +func controlSubject(peerID string) string { + return fmt.Sprintf("%s.peer.%s.control", topicPrefix, peerID) +} + +func p2pSubject(peerID, sessionID string) string { + return fmt.Sprintf("%s.peer.%s.session.%s.p2p", topicPrefix, peerID, sessionID) +} + +func p2pWildcardSubject(peerID string) string { + return fmt.Sprintf("%s.peer.%s.session.*.p2p", topicPrefix, peerID) +} + +func sessionEventSubject(sessionID string) string { + return fmt.Sprintf("%s.session.%s.event", topicPrefix, sessionID) +} + +func presenceSubject(peerID string) string { + return fmt.Sprintf("%s.peer.%s.presence", topicPrefix, peerID) +} From d36a7865c20e0705225c0bbf57c0b3177ce9f35a Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 17 Apr 2026 01:26:05 +0700 Subject: [PATCH 03/23] Implement configuration management for coordinator and cosigner runtimes, including YAML configuration files. Introduce a coordinator client for key generation requests and enhance logging throughout the runtime processes. Update .gitignore to include new configuration files. --- .gitignore | 4 +- cmd/mpcium-coordinator/main.go | 116 +++++------- cmd/mpcium-cosigner/main.go | 104 ++++------- coordinator.config.yaml | 8 + cosigner.config.yaml | 11 ++ cosigner2.config.yaml | 11 ++ examples/coordinatorclient-keygen/main.go | 101 +++++++++++ internal/coordinator/config.go | 68 +++++++ internal/coordinator/coordinator.go | 103 ++++++++++- internal/coordinator/result_hash_test.go | 60 +++++++ internal/coordinator/runtime.go | 13 +- internal/coordinator/signing.go | 3 + internal/coordinator/signing_test.go | 29 +++ internal/cosigner/config.go | 71 +++++++- internal/cosigner/runtime.go | 147 ++++++++++++++- pkg/coordinatorclient/client.go | 208 ++++++++++++++++++++++ 16 files changed, 898 insertions(+), 159 deletions(-) create mode 100644 coordinator.config.yaml create mode 100644 cosigner.config.yaml create mode 100644 cosigner2.config.yaml create mode 100644 examples/coordinatorclient-keygen/main.go create mode 100644 internal/coordinator/config.go create mode 100644 internal/coordinator/result_hash_test.go create mode 100644 internal/coordinator/signing_test.go create mode 100644 pkg/coordinatorclient/client.go diff --git a/.gitignore b/.gitignore index 90d18835..3768ac41 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,6 @@ node2 config.yaml .vscode .vagrant -.chain_code \ No newline at end of file +.chain_code +.gomodcache +.codex \ No newline at end of file diff --git a/cmd/mpcium-coordinator/main.go b/cmd/mpcium-coordinator/main.go index e18e44d0..5de34807 100644 --- a/cmd/mpcium-coordinator/main.go +++ b/cmd/mpcium-coordinator/main.go @@ -2,57 +2,71 @@ package main import ( "context" - "flag" "fmt" "os" "os/signal" - "strconv" "syscall" "time" "github.com/fystack/mpcium/internal/coordinator" + "github.com/fystack/mpcium/pkg/config" + "github.com/fystack/mpcium/pkg/logger" "github.com/nats-io/nats.go" + "github.com/urfave/cli/v3" ) +const coordinatorConfigPath = "coordinator.config.yaml" + func main() { - if err := run(); err != nil { - fmt.Fprintln(os.Stderr, err) + logger.Init(os.Getenv("ENVIRONMENT"), false) + + cmd := &cli.Command{ + Name: "mpcium-coordinator", + Usage: "Run MPC coordinator runtime", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "config", + Aliases: []string{"c"}, + Usage: "Path to coordinator config file", + Value: coordinatorConfigPath, + }, + }, + Action: run, + } + + if err := cmd.Run(context.Background(), os.Args); err != nil { + logger.Error("coordinator exited with error", err) os.Exit(1) } } -func run() error { - natsURL := flag.String("nats-url", envDefault("NATS_URL", nats.DefaultURL), "NATS server URL") - coordinatorID := flag.String("coordinator-id", envDefault("COORDINATOR_ID", ""), "stable coordinator ID") - privateKeyHex := flag.String("coordinator-private-key-hex", envDefault("COORDINATOR_PRIVATE_KEY_HEX", ""), "hex encoded Ed25519 private key") - snapshotDir := flag.String("snapshot-dir", envDefault("COORDINATOR_SNAPSHOT_DIR", "coordinator-snapshots"), "directory for coordinator session snapshots") - relayAvailable := flag.Bool("relay-available", envBoolDefault("COORDINATOR_RELAY_AVAILABLE", true), "whether relay is available for MQTT participants") - defaultSessionTTLSec := flag.Int("default-session-ttl-sec", envIntDefault("COORDINATOR_DEFAULT_SESSION_TTL_SEC", 120), "default session TTL in seconds") - tickInterval := flag.Duration("tick-interval", envDurationDefault("COORDINATOR_TICK_INTERVAL", time.Second), "session timeout scan interval") - flag.Parse() - - if *coordinatorID == "" { - return fmt.Errorf("coordinator-id is required") +func run(ctx context.Context, c *cli.Command) error { + configPath := c.String("config") + config.InitViperConfig(configPath) + + cfg, err := coordinator.LoadRuntimeConfig() + if err != nil { + return err } - if *privateKeyHex == "" { - return fmt.Errorf("coordinator-private-key-hex is required") + if err := cfg.Validate(); err != nil { + return err } - signer, err := coordinator.NewEd25519SignerFromHex(*privateKeyHex) + signer, err := coordinator.NewEd25519SignerFromHex(cfg.PrivateKeyHex) if err != nil { return err } - nc, err := nats.Connect(*natsURL) + nc, err := nats.Connect(cfg.NATSURL) if err != nil { return fmt.Errorf("connect to NATS: %w", err) } defer nc.Close() - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) defer stop() - snapshotStore := coordinator.NewAtomicFileSnapshotStore(*snapshotDir) + snapshotStore := coordinator.NewAtomicFileSnapshotStore(cfg.SnapshotDir) sessionStore, err := coordinator.NewMemorySessionStore(ctx, snapshotStore) if err != nil { return fmt.Errorf("restore coordinator state: %w", err) @@ -61,10 +75,10 @@ func run() error { if err := coordinator.RestoreKeyInfoFromSnapshotStore(ctx, snapshotStore, keyInfoStore); err != nil { return fmt.Errorf("restore key info: %w", err) } - _ = relayAvailable + presence := coordinator.NewInMemoryPresenceView() coord, err := coordinator.NewCoordinator(coordinator.CoordinatorConfig{ - CoordinatorID: *coordinatorID, + CoordinatorID: cfg.ID, Signer: signer, EventVerifier: coordinator.Ed25519SessionEventVerifier{}, Store: sessionStore, @@ -72,7 +86,7 @@ func run() error { Presence: presence, Controls: coordinator.NewNATSControlPublisher(nc), Results: coordinator.NewNATSResultPublisher(nc), - DefaultSessionTTL: time.Duration(*defaultSessionTTLSec) * time.Second, + DefaultSessionTTL: cfg.DefaultSessionTTL, }) if err != nil { return err @@ -83,59 +97,21 @@ func run() error { return err } - ticker := time.NewTicker(*tickInterval) + return runTickLoop(ctx, runtime, coord, cfg.TickInterval) +} + +func runTickLoop(ctx context.Context, runtime *coordinator.NATSRuntime, coord *coordinator.Coordinator, interval time.Duration) error { + ticker := time.NewTicker(interval) defer ticker.Stop() + for { select { case <-ctx.Done(): return runtime.Stop() case <-ticker.C: if _, err := coord.Tick(ctx); err != nil { - fmt.Fprintln(os.Stderr, "coordinator tick error:", err) + logger.Error("coordinator tick error", err) } } } } - -func envDefault(name string, fallback string) string { - if value := os.Getenv(name); value != "" { - return value - } - return fallback -} - -func envBoolDefault(name string, fallback bool) bool { - value := os.Getenv(name) - if value == "" { - return fallback - } - parsed, err := strconv.ParseBool(value) - if err != nil { - return fallback - } - return parsed -} - -func envDurationDefault(name string, fallback time.Duration) time.Duration { - value := os.Getenv(name) - if value == "" { - return fallback - } - parsed, err := time.ParseDuration(value) - if err != nil { - return fallback - } - return parsed -} - -func envIntDefault(name string, fallback int) int { - value := os.Getenv(name) - if value == "" { - return fallback - } - parsed, err := strconv.Atoi(value) - if err != nil { - return fallback - } - return parsed -} diff --git a/cmd/mpcium-cosigner/main.go b/cmd/mpcium-cosigner/main.go index 947e0251..efd44160 100644 --- a/cmd/mpcium-cosigner/main.go +++ b/cmd/mpcium-cosigner/main.go @@ -2,96 +2,60 @@ package main import ( "context" - "encoding/hex" - "flag" - "fmt" "os" "os/signal" "syscall" - "time" "github.com/fystack/mpcium/internal/cosigner" + "github.com/fystack/mpcium/pkg/config" + "github.com/fystack/mpcium/pkg/logger" + "github.com/urfave/cli/v3" ) +const cosignerConfigPath = "cosigner.config.yaml" + func main() { - if err := run(); err != nil { - fmt.Fprintln(os.Stderr, err) + logger.Init(os.Getenv("ENVIRONMENT"), false) + + cmd := &cli.Command{ + Name: "mpcium-cosigner", + Usage: "Run MPC cosigner runtime", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "config", + Aliases: []string{"c"}, + Usage: "Path to cosigner config file", + Value: cosignerConfigPath, + }, + }, + Action: run, + } + + if err := cmd.Run(context.Background(), os.Args); err != nil { + logger.Error("cosigner exited with error", err) os.Exit(1) } } -func run() error { - nodeID := flag.String("node-id", envDefault("NODE_ID", ""), "local participant ID") - natsURL := flag.String("nats-url", envDefault("NATS_URL", "nats://127.0.0.1:4222"), "NATS server URL") - coordinatorID := flag.String("coordinator-id", envDefault("COORDINATOR_ID", ""), "coordinator ID") - coordinatorPubHex := flag.String("coordinator-public-key-hex", envDefault("COORDINATOR_PUBLIC_KEY_HEX", ""), "coordinator Ed25519 public key hex") - privateKeyHex := flag.String("identity-private-key-hex", envDefault("IDENTITY_PRIVATE_KEY_HEX", ""), "node Ed25519 private key hex") - dataDir := flag.String("data-dir", envDefault("NODE_V1_DATA_DIR", "node-v1-data"), "node-v1 badger data directory") - maxActive := flag.Int("max-active-sessions", envIntDefault("NODE_V1_MAX_ACTIVE_SESSIONS", 64), "maximum concurrent active sessions") - presenceInterval := flag.Duration("presence-interval", envDurationDefault("NODE_V1_PRESENCE_INTERVAL", 5*time.Second), "presence heartbeat interval") - tickInterval := flag.Duration("tick-interval", envDurationDefault("NODE_V1_TICK_INTERVAL", 100*time.Millisecond), "participant tick interval") - flag.Parse() - - if *nodeID == "" || *coordinatorID == "" || *coordinatorPubHex == "" || *privateKeyHex == "" { - return fmt.Errorf("node-id, coordinator-id, coordinator-public-key-hex, and identity-private-key-hex are required") - } - coordinatorKey, err := hex.DecodeString(*coordinatorPubHex) - if err != nil { - return fmt.Errorf("decode coordinator public key: %w", err) - } - privateKey, err := hex.DecodeString(*privateKeyHex) +func run(ctx context.Context, c *cli.Command) error { + configPath := c.String("config") + config.InitViperConfig(configPath) + cfg, err := cosigner.LoadConfig() if err != nil { - return fmt.Errorf("decode identity private key: %w", err) + return err } - runtime, err := cosigner.NewRuntime(cosigner.Config{ - NodeID: *nodeID, - NATSURL: *natsURL, - CoordinatorID: *coordinatorID, - CoordinatorPublicKey: coordinatorKey, - IdentityPrivateKey: privateKey, - DataDir: *dataDir, - MaxActiveSessions: *maxActive, - PresenceInterval: *presenceInterval, - TickInterval: *tickInterval, - }) + return runCosigner(ctx, cfg) +} + +func runCosigner(ctx context.Context, cfg cosigner.Config) error { + runtime, err := cosigner.NewRuntime(cfg) if err != nil { return err } defer runtime.Close() - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) defer stop() return runtime.Run(ctx) } - -func envDefault(name, fallback string) string { - if value := os.Getenv(name); value != "" { - return value - } - return fallback -} - -func envDurationDefault(name string, fallback time.Duration) time.Duration { - value := os.Getenv(name) - if value == "" { - return fallback - } - parsed, err := time.ParseDuration(value) - if err != nil { - return fallback - } - return parsed -} - -func envIntDefault(name string, fallback int) int { - value := os.Getenv(name) - if value == "" { - return fallback - } - var parsed int - if _, err := fmt.Sscanf(value, "%d", &parsed); err != nil { - return fallback - } - return parsed -} diff --git a/coordinator.config.yaml b/coordinator.config.yaml new file mode 100644 index 00000000..1c649996 --- /dev/null +++ b/coordinator.config.yaml @@ -0,0 +1,8 @@ +nats: + url: nats://127.0.0.1:4222 + +coordinator: + id: coordinator-01 + private_key_hex: "86ed171146e6003841f1686c0958b68ae84f9992974c2c6febfb9df7f424b3adb64ca8ec459081a299aecc2b2b5d555265b15ddfd29e792ddd08bedb418bdd0d" + snapshot_dir: coordinator-snapshots + relay_available: true diff --git a/cosigner.config.yaml b/cosigner.config.yaml new file mode 100644 index 00000000..57af339b --- /dev/null +++ b/cosigner.config.yaml @@ -0,0 +1,11 @@ +nats: + url: nats://127.0.0.1:4222 + +cosigner: + node_id: peer-node-01 + coordinator: + id: coordinator-01 + public_key_hex: "b64ca8ec459081a299aecc2b2b5d555265b15ddfd29e792ddd08bedb418bdd0d" + identity: + private_key_hex: "b14d168636008a9c766a6c231c182446e4b636cd2116817a89d068ffb5cc49e456a47a1103b610d6c85bf23ddb1f78ff6404f7c6f170d46441a268e105873cc4" + data_dir: node-v1-data diff --git a/cosigner2.config.yaml b/cosigner2.config.yaml new file mode 100644 index 00000000..c4ac3f2b --- /dev/null +++ b/cosigner2.config.yaml @@ -0,0 +1,11 @@ +nats: + url: nats://127.0.0.1:4222 + +cosigner: + node_id: peer-node-02 + coordinator: + id: coordinator-01 + public_key_hex: "b64ca8ec459081a299aecc2b2b5d555265b15ddfd29e792ddd08bedb418bdd0d" + identity: + private_key_hex: "a96d8c0de1b5682740f6487b13dc7477aaa739b900c6f5c3db737ca019163efad9034dd84e0dd10a57d6a09a8267b217051d5f121ff52fca66c2b485be16ae02" + data_dir: node-v1-data-02 diff --git a/examples/coordinatorclient-keygen/main.go b/examples/coordinatorclient-keygen/main.go new file mode 100644 index 00000000..58a30101 --- /dev/null +++ b/examples/coordinatorclient-keygen/main.go @@ -0,0 +1,101 @@ +package main + +import ( + "context" + "crypto/ed25519" + "encoding/hex" + "fmt" + "log" + "time" + + "github.com/fystack/mpcium/pkg/coordinatorclient" + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +func main() { + client, err := coordinatorclient.New(coordinatorclient.Config{ + NATSURL: "nats://127.0.0.1:4222", + Timeout: 5 * time.Second, + }) + if err != nil { + log.Fatalf("create coordinator client: %v", err) + } + defer client.Close() + + participants := []coordinatorclient.KeygenParticipant{ + { + ID: "peer-node-01", + IdentityPublicKey: mustPublicKeyFromPrivateHex("b14d168636008a9c766a6c231c182446e4b636cd2116817a89d068ffb5cc49e456a47a1103b610d6c85bf23ddb1f78ff6404f7c6f170d46441a268e105873cc4"), + }, + { + ID: "peer-node-02", + IdentityPublicKey: mustPublicKeyFromPrivateHex("a96d8c0de1b5682740f6487b13dc7477aaa739b900c6f5c3db737ca019163efad9034dd84e0dd10a57d6a09a8267b217051d5f121ff52fca66c2b485be16ae02"), + }, + } + + for _, participant := range participants { + presenceCtx, cancelPresence := context.WithTimeout(context.Background(), 5*time.Second) + if err := client.PublishPresence(presenceCtx, participant.ID); err != nil { + cancelPresence() + log.Fatalf("publish presence for %s: %v", participant.ID, err) + } + cancelPresence() + } + + const walletID = "wallet_demo_001" + runKeygenForProtocol(client, participants, walletID, sdkprotocol.ProtocolTypeECDSA) + runKeygenForProtocol(client, participants, walletID, sdkprotocol.ProtocolTypeEdDSA) +} + +func mustDecodeHex(value string) []byte { + decoded, err := hex.DecodeString(value) + if err != nil { + panic(err) + } + return decoded +} + +func mustPublicKeyFromPrivateHex(privateKeyHex string) []byte { + privateRaw := mustDecodeHex(privateKeyHex) + var private ed25519.PrivateKey + switch len(privateRaw) { + case ed25519.PrivateKeySize: + private = ed25519.PrivateKey(privateRaw) + case ed25519.SeedSize: + private = ed25519.NewKeyFromSeed(privateRaw) + default: + panic(fmt.Sprintf("invalid ed25519 private key length: %d", len(privateRaw))) + } + public := private.Public().(ed25519.PublicKey) + return append([]byte(nil), public...) +} + +func runKeygenForProtocol(client *coordinatorclient.Client, participants []coordinatorclient.KeygenParticipant, walletID string, protocol sdkprotocol.ProtocolType) { + requestCtx, cancelRequest := context.WithTimeout(context.Background(), 10*time.Second) + resp, err := client.RequestKeygen(requestCtx, coordinatorclient.KeygenRequest{ + Protocol: protocol, + Threshold: 1, + WalletID: walletID, + Participants: participants, + }) + cancelRequest() + if err != nil { + log.Fatalf("request keygen (%s): %v", protocol, err) + } + acceptedAt := time.Now() + + resultCtx, cancelResult := context.WithTimeout(context.Background(), 2*time.Minute) + result, err := client.WaitSessionResult(resultCtx, resp.SessionID) + cancelResult() + if err != nil { + log.Fatalf("wait session result (%s): %v (check both cosigners are running and session events are flowing)", protocol, err) + } + if result == nil { + fmt.Printf("protocol=%s session_id=%s result=empty wait_seconds=%.3f\n", protocol, resp.SessionID, time.Since(acceptedAt).Seconds()) + return + } + fmt.Printf("protocol=%s key_id=%s session_id=%s wait_seconds=%.3f\n", protocol, result.KeyShare.KeyID, resp.SessionID, time.Since(acceptedAt).Seconds()) + if result.KeyShare != nil { + fmt.Printf("public_key_hex=%s\n", hex.EncodeToString(result.KeyShare.PublicKey)) + } +} diff --git a/internal/coordinator/config.go b/internal/coordinator/config.go new file mode 100644 index 00000000..fce88b3d --- /dev/null +++ b/internal/coordinator/config.go @@ -0,0 +1,68 @@ +package coordinator + +import ( + "fmt" + "time" + + "github.com/mitchellh/mapstructure" + "github.com/spf13/viper" +) + +type fileConfig struct { + NATS natsConfig `mapstructure:"nats"` + Coordinator coordinatorConfig `mapstructure:"coordinator"` +} + +type natsConfig struct { + URL string `mapstructure:"url"` +} + +type coordinatorConfig struct { + ID string `mapstructure:"id"` + PrivateKeyHex string `mapstructure:"private_key_hex"` + SnapshotDir string `mapstructure:"snapshot_dir"` +} + +type RuntimeConfig struct { + NATSURL string + ID string + PrivateKeyHex string + SnapshotDir string + DefaultSessionTTL time.Duration + TickInterval time.Duration +} + +func (cfg RuntimeConfig) Validate() error { + if cfg.NATSURL == "" { + return fmt.Errorf("nats-url is required") + } + if cfg.ID == "" { + return fmt.Errorf("coordinator-id is required") + } + if cfg.PrivateKeyHex == "" { + return fmt.Errorf("coordinator-private-key-hex is required") + } + if cfg.SnapshotDir == "" { + return fmt.Errorf("coordinator-snapshot-dir is required") + } + return nil +} + +func LoadRuntimeConfig() (RuntimeConfig, error) { + var cfg fileConfig + if err := viper.Unmarshal(&cfg, viper.DecodeHook(mapstructure.StringToTimeDurationHookFunc())); err != nil { + return RuntimeConfig{}, fmt.Errorf("decode config: %w", err) + } + return cfg.Coordinator.runtimeConfig(cfg.NATS.URL), nil +} + +func (cfg coordinatorConfig) runtimeConfig(natsURL string) RuntimeConfig { + return RuntimeConfig{ + NATSURL: natsURL, + ID: cfg.ID, + PrivateKeyHex: cfg.PrivateKeyHex, + SnapshotDir: cfg.SnapshotDir, + DefaultSessionTTL: 120 * time.Second, + TickInterval: time.Second, + } +} diff --git a/internal/coordinator/coordinator.go b/internal/coordinator/coordinator.go index a0485470..6a6d1dcf 100644 --- a/internal/coordinator/coordinator.go +++ b/internal/coordinator/coordinator.go @@ -6,8 +6,10 @@ import ( "encoding/hex" "encoding/json" "fmt" + "strings" "time" + "github.com/fystack/mpcium/pkg/logger" "github.com/google/uuid" sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) @@ -85,9 +87,43 @@ func (c *Coordinator) HandleRequest(ctx context.Context, op Operation, raw []byt if err != nil { return reject(ErrorCodeInvalidJSON, "invalid JSON request"), nil } - if err := c.validateRequest(ctx, op, req); err != nil { + // Backward compatibility: keygen without protocol means dispatch both ECDSA and EdDSA sessions. + if op == OperationKeygen && req.SessionStart != nil && isProtocolUnspecified(req.SessionStart.Protocol) { + protocols := []sdkprotocol.ProtocolType{sdkprotocol.ProtocolTypeECDSA, sdkprotocol.ProtocolTypeEdDSA} + sessionIDs := make([]string, 0, len(protocols)) + var firstAccepted *sdkprotocol.RequestAccepted + + for _, protocol := range protocols { + cloned := cloneSessionStart(req.SessionStart) + cloned.Protocol = protocol + accepted, err := c.acceptRequest(ctx, op, &sdkprotocol.ControlMessage{SessionStart: cloned}) + if err != nil { + return rejectFromError(err), nil + } + sessionIDs = append(sessionIDs, accepted.SessionID) + if firstAccepted == nil { + firstAccepted = accepted + } + } + + logger.Info("coordinator expanded keygen request without protocol", + "operation", string(op), + "sessions", strings.Join(sessionIDs, ","), + ) + return json.Marshal(firstAccepted) + } + + accepted, err := c.acceptRequest(ctx, op, req) + if err != nil { return rejectFromError(err), nil } + return json.Marshal(accepted) +} + +func (c *Coordinator) acceptRequest(ctx context.Context, op Operation, req *sdkprotocol.ControlMessage) (*sdkprotocol.RequestAccepted, error) { + if err := c.validateRequest(ctx, op, req); err != nil { + return nil, err + } now := c.now() sessionID := "sess_" + uuid.NewString() @@ -118,25 +154,31 @@ func (c *Coordinator) HandleRequest(ctx context.Context, op Operation, raw []byt ParticipantKeys: keys, } if err := c.store.Create(ctx, session); err != nil { - return rejectFromError(err), nil + return nil, err } + logger.Info("coordinator accepted request", + "action", string(op), + "protocol", string(start.Protocol), + "session_id", session.ID, + "participant_count", len(session.Participants), + "wallet_id", keygenWalletID(start), + ) if err := c.fanOutSessionStart(ctx, session); err != nil { _ = c.failSession(ctx, session, ErrorCodeInternal, err.Error()) - return reject(ErrorCodeInternal, "failed to publish session start"), nil + return nil, newCoordinatorError(ErrorCodeInternal, "failed to publish session start") } session.State = SessionWaitingParticipants session.UpdatedAt = c.now() if err := c.store.Save(ctx, session); err != nil { - return reject(ErrorCodeInternal, "failed to save session"), nil + return nil, newCoordinatorError(ErrorCodeInternal, "failed to save session") } - resp := sdkprotocol.RequestAccepted{ + return &sdkprotocol.RequestAccepted{ Accepted: true, SessionID: session.ID, ExpiresAt: session.ExpiresAt.UTC().Format(time.RFC3339Nano), - } - return json.Marshal(resp) + }, nil } func (c *Coordinator) HandleSessionEvent(ctx context.Context, raw []byte) error { @@ -187,7 +229,7 @@ func (c *Coordinator) HandleSessionEvent(ctx context.Context, raw []byte) error if event.SessionCompleted.Result == nil { return c.failSession(ctx, session, ErrorCodeValidation, "missing result payload") } - state.ResultHash = canonicalResultHash(event.SessionCompleted.Result) + state.ResultHash = canonicalOperationResultHash(session.Op, event.SessionCompleted.Result) case event.PeerFailed != nil: state.Failed = true state.ErrorCode = ErrorCodeParticipantFailed @@ -206,6 +248,11 @@ func (c *Coordinator) HandleSessionEvent(ctx context.Context, raw []byte) error if err := c.advance(ctx, session, &event); err != nil { return err } + logger.Debug("coordinator processed session event", + "session_id", session.ID, + "participant_id", event.ParticipantID, + "state", string(session.State), + ) return c.store.Save(ctx, session) } @@ -236,6 +283,9 @@ func (c *Coordinator) validateRequest(ctx context.Context, op Operation, msg *sd return newCoordinatorError(ErrorCodeValidation, "session_start is required") } start := msg.SessionStart + if isProtocolUnspecified(start.Protocol) { + return newCoordinatorError(ErrorCodeValidation, "protocol is required") + } start.SessionID = "tmp" start.Operation = op.ToSDK() if err := sdkprotocol.ValidateSessionStart(start); err != nil { @@ -255,6 +305,10 @@ func (c *Coordinator) validateRequest(ctx context.Context, op Operation, msg *sd return nil } +func isProtocolUnspecified(protocol sdkprotocol.ProtocolType) bool { + return protocol == sdkprotocol.ProtocolTypeUnspecified || string(protocol) == "" +} + func (c *Coordinator) advance(ctx context.Context, session *Session, event *sdkprotocol.SessionEvent) error { switch session.State { case SessionWaitingParticipants: @@ -347,6 +401,11 @@ func (c *Coordinator) fanOutMPCBegin(ctx context.Context, session *Session) erro } func (c *Coordinator) failSession(ctx context.Context, session *Session, code, message string) error { + logger.Error("coordinator failing session", + fmt.Errorf("%s: %s", code, message), + "session_id", session.ID, + "error_code", code, + ) now := c.now() session.State = SessionFailed session.ErrorCode = code @@ -457,6 +516,27 @@ func allParticipants(session *Session, predicate func(*ParticipantState) bool) b return true } +func canonicalOperationResultHash(op Operation, result *sdkprotocol.Result) string { + if result == nil { + return "" + } + switch op { + case OperationKeygen: + if result.KeyShare == nil { + return "" + } + normalized := &sdkprotocol.Result{ + KeyShare: &sdkprotocol.KeyShareResult{ + KeyID: result.KeyShare.KeyID, + PublicKey: append([]byte(nil), result.KeyShare.PublicKey...), + }, + } + return canonicalResultHash(normalized) + default: + return canonicalResultHash(result) + } +} + func canonicalResultHash(result *sdkprotocol.Result) string { if result == nil { return "" @@ -466,6 +546,13 @@ func canonicalResultHash(result *sdkprotocol.Result) string { return hex.EncodeToString(sum[:]) } +func keygenWalletID(start *sdkprotocol.SessionStart) string { + if start == nil || start.Keygen == nil { + return "" + } + return start.Keygen.KeyID +} + func firstNonEmpty(values ...string) string { for _, value := range values { if value != "" { diff --git a/internal/coordinator/result_hash_test.go b/internal/coordinator/result_hash_test.go new file mode 100644 index 00000000..25b64f27 --- /dev/null +++ b/internal/coordinator/result_hash_test.go @@ -0,0 +1,60 @@ +package coordinator + +import ( + "bytes" + "testing" + + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +func TestCanonicalOperationResultHashIgnoresKeygenShareBlob(t *testing.T) { + resultA := &sdkprotocol.Result{ + KeyShare: &sdkprotocol.KeyShareResult{ + KeyID: "wallet-1", + PublicKey: []byte{1, 2, 3}, + ShareBlob: []byte{9, 9, 9}, + }, + } + resultB := &sdkprotocol.Result{ + KeyShare: &sdkprotocol.KeyShareResult{ + KeyID: "wallet-1", + PublicKey: []byte{1, 2, 3}, + ShareBlob: []byte{8, 8, 8}, + }, + } + + hashA := canonicalOperationResultHash(OperationKeygen, resultA) + hashB := canonicalOperationResultHash(OperationKeygen, resultB) + if hashA == "" || hashB == "" { + t.Fatal("expected non-empty hashes") + } + if hashA != hashB { + t.Fatalf("expected equal hashes for keygen results with different share blobs, got %q != %q", hashA, hashB) + } +} + +func TestCanonicalOperationResultHashUsesFullSignaturePayload(t *testing.T) { + resultA := &sdkprotocol.Result{ + Signature: &sdkprotocol.SignatureResult{ + KeyID: "wallet-1", + Signature: []byte{1, 2, 3}, + }, + } + resultB := &sdkprotocol.Result{ + Signature: &sdkprotocol.SignatureResult{ + KeyID: "wallet-1", + Signature: []byte{1, 2, 4}, + }, + } + + hashA := canonicalOperationResultHash(OperationSign, resultA) + hashB := canonicalOperationResultHash(OperationSign, resultB) + if hashA == hashB { + t.Fatalf("expected different hashes for different signature payloads") + } + + // Guard against accidental normalization that removes signature bytes. + if bytes.Equal(resultA.Signature.Signature, resultB.Signature.Signature) { + t.Fatal("invalid test setup") + } +} diff --git a/internal/coordinator/runtime.go b/internal/coordinator/runtime.go index ae66bf5c..f0df7802 100644 --- a/internal/coordinator/runtime.go +++ b/internal/coordinator/runtime.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" + "github.com/fystack/mpcium/pkg/logger" "github.com/nats-io/nats.go" sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) @@ -21,11 +22,14 @@ func NewNATSRuntime(nc *nats.Conn, coord *Coordinator, presence PresenceView) *N } func (r *NATSRuntime) Start(ctx context.Context) error { + logger.Info("starting coordinator runtime subscriptions") + for _, op := range []Operation{OperationKeygen, OperationSign, OperationReshare} { op := op sub, err := r.nc.Subscribe(RequestSubject(op), func(msg *nats.Msg) { reply, err := r.coord.HandleRequest(ctx, op, msg.Data) if err != nil { + logger.Error("handle coordinator request failed", err, "operation", string(op)) reply = reject(ErrorCodeInternal, err.Error()) } if msg.Reply != "" { @@ -35,20 +39,25 @@ func (r *NATSRuntime) Start(ctx context.Context) error { if err != nil { return fmt.Errorf("subscribe request subject %s: %w", RequestSubject(op), err) } + logger.Info("subscribed coordinator request subject", "subject", RequestSubject(op)) r.subs = append(r.subs, sub) } eventSub, err := r.nc.Subscribe(AllSessionEventsSubject(), func(msg *nats.Msg) { - _ = r.coord.HandleSessionEvent(ctx, msg.Data) + if err := r.coord.HandleSessionEvent(ctx, msg.Data); err != nil { + logger.Error("handle session event failed", err) + } }) if err != nil { return fmt.Errorf("subscribe session events: %w", err) } + logger.Info("subscribed coordinator session events", "subject", AllSessionEventsSubject()) r.subs = append(r.subs, eventSub) presenceSub, err := r.nc.Subscribe(AllPresenceSubject(), func(msg *nats.Msg) { var event sdkprotocol.PresenceEvent if err := json.Unmarshal(msg.Data, &event); err != nil { + logger.Error("decode presence event failed", err) return } _ = r.presence.ApplyPresence(event) @@ -56,12 +65,14 @@ func (r *NATSRuntime) Start(ctx context.Context) error { if err != nil { return fmt.Errorf("subscribe presence events: %w", err) } + logger.Info("subscribed coordinator presence events", "subject", AllPresenceSubject()) r.subs = append(r.subs, presenceSub) return r.nc.Flush() } func (r *NATSRuntime) Stop() error { + logger.Info("stopping coordinator runtime subscriptions") for _, sub := range r.subs { if err := sub.Unsubscribe(); err != nil { return err diff --git a/internal/coordinator/signing.go b/internal/coordinator/signing.go index 6251377a..ef915dbb 100644 --- a/internal/coordinator/signing.go +++ b/internal/coordinator/signing.go @@ -53,6 +53,9 @@ func (Ed25519SessionEventVerifier) VerifySessionEvent(_ context.Context, session if !ok || len(pubKey) == 0 { return newCoordinatorError(ErrorCodeUnauthorized, "unknown participant public key") } + if len(pubKey) != ed25519.PublicKeySize { + return newCoordinatorError(ErrorCodeValidation, "invalid participant public key length") + } payload, err := sdkprotocol.SessionEventSigningBytes(event) if err != nil { return newCoordinatorError(ErrorCodeValidation, err.Error()) diff --git a/internal/coordinator/signing_test.go b/internal/coordinator/signing_test.go new file mode 100644 index 00000000..be004c00 --- /dev/null +++ b/internal/coordinator/signing_test.go @@ -0,0 +1,29 @@ +package coordinator + +import ( + "context" + "strings" + "testing" + + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +func TestEd25519SessionEventVerifierRejectsInvalidPublicKeyLength(t *testing.T) { + verifier := Ed25519SessionEventVerifier{} + session := &Session{ + ParticipantKeys: map[string][]byte{ + "peer-1": make([]byte, 64), + }, + } + event := &sdkprotocol.SessionEvent{ + ParticipantID: "peer-1", + } + + err := verifier.VerifySessionEvent(context.Background(), session, event) + if err == nil { + t.Fatal("expected error for invalid participant public key length") + } + if !strings.Contains(err.Error(), "invalid participant public key length") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/internal/cosigner/config.go b/internal/cosigner/config.go index e4353a94..4a521c85 100644 --- a/internal/cosigner/config.go +++ b/internal/cosigner/config.go @@ -1,6 +1,13 @@ package cosigner -import "time" +import ( + "encoding/hex" + "fmt" + "time" + + "github.com/mitchellh/mapstructure" + "github.com/spf13/viper" +) type Config struct { NodeID string @@ -13,3 +20,65 @@ type Config struct { PresenceInterval time.Duration TickInterval time.Duration } + +type fileConfig struct { + NATS natsConfig `mapstructure:"nats"` + Cosigner cosignerConfig `mapstructure:"cosigner"` +} + +type natsConfig struct { + URL string `mapstructure:"url"` +} + +type cosignerConfig struct { + NodeID string `mapstructure:"node_id"` + DataDir string `mapstructure:"data_dir"` + Coordinator coordinatorConfig `mapstructure:"coordinator"` + Identity identityConfig `mapstructure:"identity"` +} + +type coordinatorConfig struct { + ID string `mapstructure:"id"` + PublicKey string `mapstructure:"public_key_hex"` +} + +type identityConfig struct { + PrivateKey string `mapstructure:"private_key_hex"` +} + +func LoadConfig() (Config, error) { + var cfg fileConfig + if err := viper.Unmarshal(&cfg, viper.DecodeHook(mapstructure.StringToTimeDurationHookFunc())); err != nil { + return Config{}, fmt.Errorf("decode config: %w", err) + } + + coordinatorKey, err := decodeHexKey(cfg.Cosigner.Coordinator.PublicKey, "coordinator public key") + if err != nil { + return Config{}, err + } + + privateKey, err := decodeHexKey(cfg.Cosigner.Identity.PrivateKey, "identity private key") + if err != nil { + return Config{}, err + } + + return Config{ + NodeID: cfg.Cosigner.NodeID, + NATSURL: cfg.NATS.URL, + CoordinatorID: cfg.Cosigner.Coordinator.ID, + CoordinatorPublicKey: coordinatorKey, + IdentityPrivateKey: privateKey, + DataDir: cfg.Cosigner.DataDir, + MaxActiveSessions: 64, + PresenceInterval: 5 * time.Second, + TickInterval: 100 * time.Millisecond, + }, nil +} + +func decodeHexKey(value, name string) ([]byte, error) { + decoded, err := hex.DecodeString(value) + if err != nil { + return nil, fmt.Errorf("decode %s: %w", name, err) + } + return decoded, nil +} diff --git a/internal/cosigner/runtime.go b/internal/cosigner/runtime.go index 29ec8300..0aa7f380 100644 --- a/internal/cosigner/runtime.go +++ b/internal/cosigner/runtime.go @@ -6,9 +6,11 @@ import ( "encoding/json" "errors" "fmt" + "strings" "sync" "time" + "github.com/fystack/mpcium/pkg/logger" "github.com/nats-io/nats.go" "github.com/vietddude/mpcium-sdk/participant" sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" @@ -22,9 +24,15 @@ type Runtime struct { coordLookup *coordinatorLookup sessionsMu sync.RWMutex sessions map[string]*participant.ParticipantSession + sessionMeta map[string]sessionMeta subs []*nats.Subscription } +type sessionMeta struct { + protocol string + action string +} + func NewRuntime(cfg Config) (*Runtime, error) { if cfg.NodeID == "" { return nil, errors.New("node_id is required") @@ -67,7 +75,8 @@ func NewRuntime(cfg Config) (*Runtime, error) { coordLookup: &coordinatorLookup{keys: map[string]ed25519.PublicKey{ cfg.CoordinatorID: append([]byte(nil), cfg.CoordinatorPublicKey...), }}, - sessions: map[string]*participant.ParticipantSession{}, + sessions: map[string]*participant.ParticipantSession{}, + sessionMeta: map[string]sessionMeta{}, }, nil } @@ -85,6 +94,7 @@ func (r *Runtime) Close() error { } func (r *Runtime) Run(ctx context.Context) error { + logger.Info("cosigner runtime started", "node_id", r.cfg.NodeID) if err := r.subscribe(); err != nil { return err } @@ -99,6 +109,7 @@ func (r *Runtime) Run(ctx context.Context) error { for { select { case <-ctx.Done(): + logger.Info("cosigner runtime stopping", "node_id", r.cfg.NodeID) _ = r.publishPresence(sdkprotocol.PresenceStatusOffline) return nil case <-tick.C: @@ -115,19 +126,25 @@ func (r *Runtime) Run(ctx context.Context) error { func (r *Runtime) subscribe() error { controlSub, err := r.nc.Subscribe(controlSubject(r.cfg.NodeID), func(msg *nats.Msg) { - _ = r.handleControl(msg.Data) + if err := r.handleControl(msg.Data); err != nil { + logger.Error("handle control message failed", err) + } }) if err != nil { return err } + logger.Info("subscribed control subject") r.subs = append(r.subs, controlSub) p2pSub, err := r.nc.Subscribe(p2pWildcardSubject(r.cfg.NodeID), func(msg *nats.Msg) { - _ = r.handlePeer(msg.Data) + if err := r.handlePeer(msg.Data); err != nil { + logger.Error("handle peer message failed", err) + } }) if err != nil { return err } + logger.Info("subscribed p2p subject") r.subs = append(r.subs, p2pSub) return r.nc.Flush() @@ -143,11 +160,29 @@ func (r *Runtime) handleControl(raw []byte) error { } if msg.SessionStart != nil { - return r.startSession(&msg) - } + meta := sessionMeta{ + protocol: protocolLabel(msg.SessionStart.Protocol), + action: actionLabel(msg.SessionStart.Operation), + } + logger.Info("cosigner received session start", + "session_id", msg.SessionID, + "action", meta.action, + ) + return r.startSession(&msg, meta) + } + meta := r.getSessionMeta(msg.SessionID) + logger.Debug("cosigner received control message", + "node_id", r.cfg.NodeID, + "session_id", msg.SessionID, + "sequence", msg.Sequence, + "control_type", controlType(&msg), + "protocol", meta.protocol, + "action", meta.action, + ) session := r.getSession(msg.SessionID) if session == nil { - return fmt.Errorf("unknown session %s", msg.SessionID) + logger.Warn("ignoring control for unknown session", "session_id", msg.SessionID) + return nil } effects, err := session.HandleControl(&msg) if err != nil { @@ -156,7 +191,7 @@ func (r *Runtime) handleControl(raw []byte) error { return r.publishEffects(effects) } -func (r *Runtime) startSession(msg *sdkprotocol.ControlMessage) error { +func (r *Runtime) startSession(msg *sdkprotocol.ControlMessage, meta sessionMeta) error { if len(r.sessions) >= r.cfg.MaxActiveSessions { return errors.New("max active sessions reached") } @@ -185,7 +220,9 @@ func (r *Runtime) startSession(msg *sdkprotocol.ControlMessage) error { } r.sessionsMu.Lock() r.sessions[msg.SessionID] = sess + r.sessionMeta[msg.SessionID] = meta r.sessionsMu.Unlock() + logger.Info("cosigner started session", "session_id", msg.SessionID, "action", meta.action) effects, err := sess.Start() if err != nil { @@ -199,9 +236,16 @@ func (r *Runtime) handlePeer(raw []byte) error { if err := json.Unmarshal(raw, &msg); err != nil { return err } + logger.Debug("cosigner received peer message", + "node_id", r.cfg.NodeID, + "session_id", msg.SessionID, + "from_participant", msg.FromParticipantID, + "phase", string(msg.Phase), + ) session := r.getSession(msg.SessionID) if session == nil { - return fmt.Errorf("unknown session %s", msg.SessionID) + logger.Warn("ignoring peer message for unknown session", "session_id", msg.SessionID) + return nil } effects, err := session.HandlePeer(&msg) if err != nil { @@ -234,6 +278,28 @@ func (r *Runtime) tickSessions() error { } func (r *Runtime) publishEffects(effects participant.Effects) error { + meta := r.metaFromEffects(effects) + if effects.Cleanup != nil { + outcome := "finished" + if effects.Result == nil { + outcome = "failed" + } + logger.Info("cosigner ended session", "session_id", effects.Cleanup.SessionID, "outcome", outcome) + } + if len(effects.SessionEvents) > 0 { + logger.Debug("cosigner publishing session events", + "node_id", r.cfg.NodeID, + "session_events", len(effects.SessionEvents), + "protocol", meta.protocol, + "action", meta.action, + ) + } + if len(effects.PeerMessages) > 0 { + logger.Debug("cosigner publishing peer messages", + "node_id", r.cfg.NodeID, + "peer_messages", len(effects.PeerMessages), + ) + } for _, peerMsg := range effects.PeerMessages { raw, err := json.Marshal(peerMsg) if err != nil { @@ -253,11 +319,76 @@ func (r *Runtime) publishEffects(effects participant.Effects) error { } } if effects.Cleanup != nil && effects.Cleanup.DropArtifacts { + r.dropSessionMeta(effects.Cleanup.SessionID) _ = r.stores.DeleteSessionArtifacts(effects.Cleanup.SessionID) } return nil } +func (r *Runtime) getSessionMeta(sessionID string) sessionMeta { + r.sessionsMu.RLock() + defer r.sessionsMu.RUnlock() + if meta, ok := r.sessionMeta[sessionID]; ok { + return meta + } + return sessionMeta{protocol: "unknown", action: "unknown"} +} + +func (r *Runtime) metaFromEffects(effects participant.Effects) sessionMeta { + if len(effects.SessionEvents) > 0 && effects.SessionEvents[0] != nil { + return r.getSessionMeta(effects.SessionEvents[0].SessionID) + } + if len(effects.PeerMessages) > 0 && effects.PeerMessages[0] != nil { + return r.getSessionMeta(effects.PeerMessages[0].SessionID) + } + return sessionMeta{protocol: "unknown", action: "unknown"} +} + +func (r *Runtime) dropSessionMeta(sessionID string) { + r.sessionsMu.Lock() + defer r.sessionsMu.Unlock() + delete(r.sessionMeta, sessionID) + delete(r.sessions, sessionID) +} + +func controlType(msg *sdkprotocol.ControlMessage) string { + switch { + case msg == nil: + return "unknown" + case msg.KeyExchange != nil: + return "key_exchange_begin" + case msg.MPCBegin != nil: + return "mpc_begin" + case msg.SessionAbort != nil: + return "session_abort" + case msg.SessionStart != nil: + return "session_start" + default: + return "unknown" + } +} + +func protocolLabel(protocol sdkprotocol.ProtocolType) string { + value := strings.TrimSpace(string(protocol)) + if value == "" || value == string(sdkprotocol.ProtocolTypeUnspecified) { + return "unknown" + } + return strings.ToLower(value) +} + +func actionLabel(operation sdkprotocol.OperationType) string { + switch operation { + case sdkprotocol.OperationTypeKeygen: + return "keygen" + case sdkprotocol.OperationTypeSign: + return "sign" + case sdkprotocol.OperationTypeReshare: + return "reshare" + default: + return "unknown" + } +} + func (r *Runtime) publishPresence(status sdkprotocol.PresenceStatus) error { event := sdkprotocol.PresenceEvent{ PeerID: r.cfg.NodeID, diff --git a/pkg/coordinatorclient/client.go b/pkg/coordinatorclient/client.go new file mode 100644 index 00000000..f58b8872 --- /dev/null +++ b/pkg/coordinatorclient/client.go @@ -0,0 +1,208 @@ +package coordinatorclient + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/nats-io/nats.go" + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +const ( + topicPrefix = "mpc.v1" + requestKeygenSubject = topicPrefix + ".request.keygen" +) + +type Client struct { + nc *nats.Conn + timeout time.Duration +} + +type Config struct { + NATSURL string + Timeout time.Duration +} + +type KeygenParticipant struct { + ID string + IdentityPublicKey []byte +} + +type KeygenRequest struct { + Protocol sdkprotocol.ProtocolType + Threshold uint32 + WalletID string + Participants []KeygenParticipant +} + +func New(cfg Config) (*Client, error) { + if cfg.NATSURL == "" { + cfg.NATSURL = nats.DefaultURL + } + if cfg.Timeout <= 0 { + cfg.Timeout = 5 * time.Second + } + + nc, err := nats.Connect(cfg.NATSURL) + if err != nil { + return nil, fmt.Errorf("connect to NATS: %w", err) + } + + return &Client{ + nc: nc, + timeout: cfg.Timeout, + }, nil +} + +func (c *Client) Close() { + if c == nil || c.nc == nil { + return + } + c.nc.Close() +} + +func (c *Client) PublishPresence(ctx context.Context, peerID string) error { + if peerID == "" { + return fmt.Errorf("peerID is required") + } + + event := &sdkprotocol.PresenceEvent{ + PeerID: peerID, + Status: sdkprotocol.PresenceStatusOnline, + Transport: sdkprotocol.TransportTypeNATS, + LastSeenUnixMs: time.Now().UTC().UnixMilli(), + } + + payload, err := json.Marshal(event) + if err != nil { + return fmt.Errorf("marshal presence event: %w", err) + } + + subject := fmt.Sprintf("%s.peer.%s.presence", topicPrefix, peerID) + if err := c.nc.Publish(subject, payload); err != nil { + return fmt.Errorf("publish presence: %w", err) + } + + return c.nc.FlushWithContext(ctx) +} + +func (c *Client) RequestKeygen(ctx context.Context, req KeygenRequest) (*sdkprotocol.RequestAccepted, error) { + if err := validateKeygenRequest(req); err != nil { + return nil, err + } + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, c.timeout) + defer cancel() + } + + msg := &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "tmp", // coordinator replaces this value when accepting request + Protocol: normalizeProtocol(req.Protocol), + Operation: sdkprotocol.OperationTypeKeygen, + Threshold: req.Threshold, + Participants: mapParticipants(req.Participants), + Keygen: &sdkprotocol.KeygenPayload{ + KeyID: req.WalletID, + }, + }, + } + + payload, err := json.Marshal(msg) + if err != nil { + return nil, fmt.Errorf("marshal keygen request: %w", err) + } + + respMsg, err := c.nc.RequestWithContext(ctx, requestKeygenSubject, payload) + if err != nil { + return nil, fmt.Errorf("request keygen: %w", err) + } + + var accepted sdkprotocol.RequestAccepted + if err := json.Unmarshal(respMsg.Data, &accepted); err == nil && accepted.Accepted { + return &accepted, nil + } + + var rejected sdkprotocol.RequestRejected + if err := json.Unmarshal(respMsg.Data, &rejected); err == nil && !rejected.Accepted { + return nil, fmt.Errorf("coordinator rejected request (%s): %s", rejected.ErrorCode, rejected.ErrorMessage) + } + + return nil, fmt.Errorf("unexpected coordinator response: %s", string(respMsg.Data)) +} + +func normalizeProtocol(protocol sdkprotocol.ProtocolType) sdkprotocol.ProtocolType { + if string(protocol) == "" { + return sdkprotocol.ProtocolTypeUnspecified + } + return protocol +} + +func (c *Client) WaitSessionResult(ctx context.Context, sessionID string) (*sdkprotocol.Result, error) { + if sessionID == "" { + return nil, fmt.Errorf("sessionID is required") + } + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, c.timeout) + defer cancel() + } + + subject := fmt.Sprintf("%s.session.%s.result", topicPrefix, sessionID) + sub, err := c.nc.SubscribeSync(subject) + if err != nil { + return nil, fmt.Errorf("subscribe session result: %w", err) + } + defer sub.Unsubscribe() + + if err := c.nc.FlushWithContext(ctx); err != nil { + return nil, fmt.Errorf("flush subscribe: %w", err) + } + + msg, err := sub.NextMsgWithContext(ctx) + if err != nil { + return nil, fmt.Errorf("wait session result: %w", err) + } + + var result *sdkprotocol.Result + if err := json.Unmarshal(msg.Data, &result); err != nil { + return nil, fmt.Errorf("decode session result: %w", err) + } + return result, nil +} + +func validateKeygenRequest(req KeygenRequest) error { + if req.WalletID == "" { + return fmt.Errorf("walletID is required") + } + if len(req.Participants) == 0 { + return fmt.Errorf("participants are required") + } + if req.Threshold < 1 || int(req.Threshold) >= len(req.Participants) { + return fmt.Errorf("invalid threshold %d for %d participants", req.Threshold, len(req.Participants)) + } + for _, participant := range req.Participants { + if participant.ID == "" { + return fmt.Errorf("participant ID is required") + } + if len(participant.IdentityPublicKey) == 0 { + return fmt.Errorf("identity public key is required for participant %q", participant.ID) + } + } + return nil +} + +func mapParticipants(participants []KeygenParticipant) []*sdkprotocol.SessionParticipant { + mapped := make([]*sdkprotocol.SessionParticipant, 0, len(participants)) + for _, participant := range participants { + mapped = append(mapped, &sdkprotocol.SessionParticipant{ + ParticipantID: participant.ID, + PartyKey: []byte(participant.ID), + IdentityPublicKey: participant.IdentityPublicKey, + }) + } + return mapped +} From b166d804e6e0f36bb8a01be6613e8b95d4bd070d Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 17 Apr 2026 10:36:31 +0700 Subject: [PATCH 04/23] Refactor cosigner runtime to enhance configuration management and identity handling. Introduce validation and default application for configuration parameters. Implement transport abstraction for NATS messaging, improving session management and peer communication. Add new interfaces for storage management of preparams, shares, and session artifacts. --- internal/cosigner/config.go | 44 ++++++++- internal/cosigner/identity.go | 32 +++++++ internal/cosigner/runtime.go | 161 +++++++++++++-------------------- internal/cosigner/storage.go | 23 +++++ internal/cosigner/transport.go | 54 +++++++++++ 5 files changed, 211 insertions(+), 103 deletions(-) create mode 100644 internal/cosigner/transport.go diff --git a/internal/cosigner/config.go b/internal/cosigner/config.go index 4a521c85..6530e8be 100644 --- a/internal/cosigner/config.go +++ b/internal/cosigner/config.go @@ -1,6 +1,7 @@ package cosigner import ( + "crypto/ed25519" "encoding/hex" "fmt" "time" @@ -62,17 +63,19 @@ func LoadConfig() (Config, error) { return Config{}, err } - return Config{ + runtimeCfg := Config{ NodeID: cfg.Cosigner.NodeID, NATSURL: cfg.NATS.URL, CoordinatorID: cfg.Cosigner.Coordinator.ID, CoordinatorPublicKey: coordinatorKey, IdentityPrivateKey: privateKey, DataDir: cfg.Cosigner.DataDir, - MaxActiveSessions: 64, - PresenceInterval: 5 * time.Second, - TickInterval: 100 * time.Millisecond, - }, nil + } + runtimeCfg.applyDefaults() + if err := runtimeCfg.Validate(); err != nil { + return Config{}, err + } + return runtimeCfg, nil } func decodeHexKey(value, name string) ([]byte, error) { @@ -82,3 +85,34 @@ func decodeHexKey(value, name string) ([]byte, error) { } return decoded, nil } + +func (cfg *Config) applyDefaults() { + if cfg.MaxActiveSessions <= 0 { + cfg.MaxActiveSessions = 10 + } + if cfg.PresenceInterval <= 0 { + cfg.PresenceInterval = 5 * time.Second + } + if cfg.TickInterval <= 0 { + cfg.TickInterval = 100 * time.Millisecond + } +} + +func (cfg Config) Validate() error { + if cfg.NodeID == "" { + return fmt.Errorf("node_id is required") + } + if cfg.NATSURL == "" { + return fmt.Errorf("nats_url is required") + } + if cfg.CoordinatorID == "" || len(cfg.CoordinatorPublicKey) != ed25519.PublicKeySize { + return fmt.Errorf("valid coordinator key is required") + } + if len(cfg.IdentityPrivateKey) != ed25519.PrivateKeySize { + return fmt.Errorf("valid identity private key is required") + } + if cfg.DataDir == "" { + return fmt.Errorf("data_dir is required") + } + return nil +} diff --git a/internal/cosigner/identity.go b/internal/cosigner/identity.go index a196a293..9ea8ce48 100644 --- a/internal/cosigner/identity.go +++ b/internal/cosigner/identity.go @@ -11,6 +11,18 @@ type localIdentity struct { privateKey ed25519.PrivateKey } +func NewLocalIdentity(nodeID string, privateKey []byte) (*localIdentity, error) { + if nodeID == "" { + return nil, fmt.Errorf("node_id is required") + } + if len(privateKey) != ed25519.PrivateKeySize { + return nil, fmt.Errorf("invalid identity private key size") + } + private := ed25519.PrivateKey(append([]byte(nil), privateKey...)) + public := private.Public().(ed25519.PublicKey) + return &localIdentity{participantID: nodeID, publicKey: public, privateKey: private}, nil +} + func (i *localIdentity) ParticipantID() string { return i.participantID } func (i *localIdentity) PublicKey() ed25519.PublicKey { return i.publicKey @@ -21,6 +33,14 @@ func (i *localIdentity) Sign(message []byte) ([]byte, error) { type peerLookup struct{ keys map[string]ed25519.PublicKey } +func NewPeerLookup(keys map[string]ed25519.PublicKey) *peerLookup { + cloned := make(map[string]ed25519.PublicKey, len(keys)) + for id, key := range keys { + cloned[id] = append([]byte(nil), key...) + } + return &peerLookup{keys: cloned} +} + func (l *peerLookup) LookupParticipant(participantID string) (ed25519.PublicKey, error) { key, ok := l.keys[participantID] if !ok { @@ -31,6 +51,18 @@ func (l *peerLookup) LookupParticipant(participantID string) (ed25519.PublicKey, type coordinatorLookup struct{ keys map[string]ed25519.PublicKey } +func NewCoordinatorLookup(coordinatorID string, publicKey []byte) (*coordinatorLookup, error) { + if coordinatorID == "" { + return nil, fmt.Errorf("coordinator_id is required") + } + if len(publicKey) != ed25519.PublicKeySize { + return nil, fmt.Errorf("invalid coordinator public key size") + } + return &coordinatorLookup{keys: map[string]ed25519.PublicKey{ + coordinatorID: append([]byte(nil), publicKey...), + }}, nil +} + func (l *coordinatorLookup) LookupCoordinator(coordinatorID string) (ed25519.PublicKey, error) { key, ok := l.keys[coordinatorID] if !ok { diff --git a/internal/cosigner/runtime.go b/internal/cosigner/runtime.go index 0aa7f380..2e528459 100644 --- a/internal/cosigner/runtime.go +++ b/internal/cosigner/runtime.go @@ -5,27 +5,25 @@ import ( "crypto/ed25519" "encoding/json" "errors" - "fmt" "strings" "sync" "time" "github.com/fystack/mpcium/pkg/logger" - "github.com/nats-io/nats.go" "github.com/vietddude/mpcium-sdk/participant" sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) type Runtime struct { cfg Config - nc *nats.Conn - stores *badgerStores + transport Transport + stores Stores identity *localIdentity coordLookup *coordinatorLookup sessionsMu sync.RWMutex sessions map[string]*participant.ParticipantSession sessionMeta map[string]sessionMeta - subs []*nats.Subscription + subs []Subscription } type sessionMeta struct { @@ -34,47 +32,40 @@ type sessionMeta struct { } func NewRuntime(cfg Config) (*Runtime, error) { - if cfg.NodeID == "" { - return nil, errors.New("node_id is required") - } - if cfg.NATSURL == "" { - return nil, errors.New("nats_url is required") - } - if cfg.CoordinatorID == "" || len(cfg.CoordinatorPublicKey) != ed25519.PublicKeySize { - return nil, errors.New("valid coordinator key is required") - } - if len(cfg.IdentityPrivateKey) != ed25519.PrivateKeySize { - return nil, errors.New("valid identity private key is required") - } - if cfg.MaxActiveSessions <= 0 { - cfg.MaxActiveSessions = 64 + transport, err := NewNATSTransport(cfg.NATSURL) + if err != nil { + return nil, err } - if cfg.PresenceInterval <= 0 { - cfg.PresenceInterval = 5 * time.Second + return NewRuntimeWithTransport(cfg, transport) +} + +func NewRuntimeWithTransport(cfg Config, transport Transport) (*Runtime, error) { + if transport == nil { + return nil, errors.New("transport is required") } - if cfg.TickInterval <= 0 { - cfg.TickInterval = 100 * time.Millisecond + stores, err := newBadgerStores(cfg.DataDir) + if err != nil { + transport.Close() + return nil, err } - - nc, err := nats.Connect(cfg.NATSURL) + identity, err := NewLocalIdentity(cfg.NodeID, cfg.IdentityPrivateKey) if err != nil { - return nil, fmt.Errorf("connect nats: %w", err) + transport.Close() + _ = stores.Close() + return nil, err } - stores, err := newBadgerStores(cfg.DataDir) + coordLookup, err := NewCoordinatorLookup(cfg.CoordinatorID, cfg.CoordinatorPublicKey) if err != nil { - nc.Close() + transport.Close() + _ = stores.Close() return nil, err } - private := ed25519.PrivateKey(cfg.IdentityPrivateKey) - public := private.Public().(ed25519.PublicKey) return &Runtime{ - cfg: cfg, - nc: nc, - stores: stores, - identity: &localIdentity{participantID: cfg.NodeID, publicKey: public, privateKey: private}, - coordLookup: &coordinatorLookup{keys: map[string]ed25519.PublicKey{ - cfg.CoordinatorID: append([]byte(nil), cfg.CoordinatorPublicKey...), - }}, + cfg: cfg, + transport: transport, + stores: stores, + identity: identity, + coordLookup: coordLookup, sessions: map[string]*participant.ParticipantSession{}, sessionMeta: map[string]sessionMeta{}, }, nil @@ -84,8 +75,8 @@ func (r *Runtime) Close() error { for _, sub := range r.subs { _ = sub.Unsubscribe() } - if r.nc != nil { - r.nc.Close() + if r.transport != nil { + r.transport.Close() } if r.stores != nil { return r.stores.Close() @@ -125,29 +116,29 @@ func (r *Runtime) Run(ctx context.Context) error { } func (r *Runtime) subscribe() error { - controlSub, err := r.nc.Subscribe(controlSubject(r.cfg.NodeID), func(msg *nats.Msg) { - if err := r.handleControl(msg.Data); err != nil { + controlSub, err := r.transport.Subscribe(controlSubject(r.cfg.NodeID), func(raw []byte) { + if err := r.handleControl(raw); err != nil { logger.Error("handle control message failed", err) } }) if err != nil { return err } - logger.Info("subscribed control subject") + logger.Info("subscribed control subject", "subject", controlSubject(r.cfg.NodeID)) r.subs = append(r.subs, controlSub) - p2pSub, err := r.nc.Subscribe(p2pWildcardSubject(r.cfg.NodeID), func(msg *nats.Msg) { - if err := r.handlePeer(msg.Data); err != nil { + p2pSub, err := r.transport.Subscribe(p2pWildcardSubject(r.cfg.NodeID), func(raw []byte) { + if err := r.handlePeer(raw); err != nil { logger.Error("handle peer message failed", err) } }) if err != nil { return err } - logger.Info("subscribed p2p subject") + logger.Info("subscribed p2p subject", "subject", p2pWildcardSubject(r.cfg.NodeID)) r.subs = append(r.subs, p2pSub) - return r.nc.Flush() + return r.transport.Flush() } func (r *Runtime) handleControl(raw []byte) error { @@ -184,11 +175,11 @@ func (r *Runtime) handleControl(raw []byte) error { logger.Warn("ignoring control for unknown session", "session_id", msg.SessionID) return nil } - effects, err := session.HandleControl(&msg) + actions, err := session.HandleControl(&msg) if err != nil { return err } - return r.publishEffects(effects) + return r.dispatchActions(actions) } func (r *Runtime) startSession(msg *sdkprotocol.ControlMessage, meta sessionMeta) error { @@ -209,7 +200,7 @@ func (r *Runtime) startSession(msg *sdkprotocol.ControlMessage, meta sessionMeta Start: msg.SessionStart, LocalParticipantID: r.cfg.NodeID, Identity: r.identity, - Peers: &peerLookup{keys: peerKeys}, + Peers: NewPeerLookup(peerKeys), Coordinator: r.coordLookup, Preparams: r.stores, Shares: r.stores, @@ -224,11 +215,11 @@ func (r *Runtime) startSession(msg *sdkprotocol.ControlMessage, meta sessionMeta r.sessionsMu.Unlock() logger.Info("cosigner started session", "session_id", msg.SessionID, "action", meta.action) - effects, err := sess.Start() + actions, err := sess.Start() if err != nil { return err } - return r.publishEffects(effects) + return r.dispatchActions(actions) } func (r *Runtime) handlePeer(raw []byte) error { @@ -247,11 +238,11 @@ func (r *Runtime) handlePeer(raw []byte) error { logger.Warn("ignoring peer message for unknown session", "session_id", msg.SessionID) return nil } - effects, err := session.HandlePeer(&msg) + actions, err := session.HandlePeer(&msg) if err != nil { return err } - return r.publishEffects(effects) + return r.dispatchActions(actions) } func (r *Runtime) tickSessions() error { @@ -266,61 +257,40 @@ func (r *Runtime) tickSessions() error { if session == nil { continue } - effects, err := session.Tick(time.Now()) + actions, err := session.Tick(time.Now()) if err != nil { return err } - if err := r.publishEffects(effects); err != nil { + if err := r.dispatchActions(actions); err != nil { return err } } return nil } -func (r *Runtime) publishEffects(effects participant.Effects) error { - meta := r.metaFromEffects(effects) - if effects.Cleanup != nil { - outcome := "finished" - if effects.Result == nil { - outcome = "failed" - } - logger.Info("cosigner ended session", "session_id", effects.Cleanup.SessionID, "outcome", outcome) - } - if len(effects.SessionEvents) > 0 { - logger.Debug("cosigner publishing session events", - "node_id", r.cfg.NodeID, - "session_events", len(effects.SessionEvents), - "protocol", meta.protocol, - "action", meta.action, - ) - } - if len(effects.PeerMessages) > 0 { - logger.Debug("cosigner publishing peer messages", - "node_id", r.cfg.NodeID, - "peer_messages", len(effects.PeerMessages), - ) - } - for _, peerMsg := range effects.PeerMessages { +func (r *Runtime) dispatchActions(actions participant.Actions) error { + logger.Debug("dispatching actions", "actions", actions) + for _, peerMsg := range actions.PeerMessages { raw, err := json.Marshal(peerMsg) if err != nil { return err } - if err := r.nc.Publish(p2pSubject(peerMsg.ToParticipantID, peerMsg.SessionID), raw); err != nil { + if err := r.transport.Publish(p2pSubject(peerMsg.ToParticipantID, peerMsg.SessionID), raw); err != nil { return err } } - for _, event := range effects.SessionEvents { + for _, event := range actions.SessionEvents { raw, err := json.Marshal(event) if err != nil { return err } - if err := r.nc.Publish(sessionEventSubject(event.SessionID), raw); err != nil { + if err := r.transport.Publish(sessionEventSubject(event.SessionID), raw); err != nil { return err } } - if effects.Cleanup != nil && effects.Cleanup.DropArtifacts { - r.dropSessionMeta(effects.Cleanup.SessionID) - _ = r.stores.DeleteSessionArtifacts(effects.Cleanup.SessionID) + if actions.Cleanup != nil && actions.Cleanup.DropArtifacts { + r.dropSessionMeta(actions.Cleanup.SessionID) + _ = r.stores.DeleteSessionArtifacts(actions.Cleanup.SessionID) } return nil } @@ -334,16 +304,6 @@ func (r *Runtime) getSessionMeta(sessionID string) sessionMeta { return sessionMeta{protocol: "unknown", action: "unknown"} } -func (r *Runtime) metaFromEffects(effects participant.Effects) sessionMeta { - if len(effects.SessionEvents) > 0 && effects.SessionEvents[0] != nil { - return r.getSessionMeta(effects.SessionEvents[0].SessionID) - } - if len(effects.PeerMessages) > 0 && effects.PeerMessages[0] != nil { - return r.getSessionMeta(effects.PeerMessages[0].SessionID) - } - return sessionMeta{protocol: "unknown", action: "unknown"} -} - func (r *Runtime) dropSessionMeta(sessionID string) { r.sessionsMu.Lock() defer r.sessionsMu.Unlock() @@ -390,20 +350,25 @@ func actionLabel(operation sdkprotocol.OperationType) string { } func (r *Runtime) publishPresence(status sdkprotocol.PresenceStatus) error { + transportType := r.transport.ProtocolType() + connectionPrefix := strings.ToLower(string(transportType)) + if connectionPrefix == "" || transportType == sdkprotocol.TransportTypeUnspecified { + connectionPrefix = "transport" + } event := sdkprotocol.PresenceEvent{ PeerID: r.cfg.NodeID, Status: status, - Transport: sdkprotocol.TransportTypeNATS, + Transport: transportType, LastSeenUnixMs: time.Now().UTC().UnixMilli(), } if status == sdkprotocol.PresenceStatusOnline { - event.ConnectionID = "nats:" + r.cfg.NodeID + event.ConnectionID = connectionPrefix + ":" + r.cfg.NodeID } raw, err := json.Marshal(event) if err != nil { return err } - return r.nc.Publish(presenceSubject(r.cfg.NodeID), raw) + return r.transport.Publish(presenceSubject(r.cfg.NodeID), raw) } func (r *Runtime) getSession(sessionID string) *participant.ParticipantSession { diff --git a/internal/cosigner/storage.go b/internal/cosigner/storage.go index 9e206f5e..babe5feb 100644 --- a/internal/cosigner/storage.go +++ b/internal/cosigner/storage.go @@ -8,6 +8,29 @@ import ( sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) +type PreparamsStore interface { + LoadPreparams(protocolType sdkprotocol.ProtocolType, keyID string) ([]byte, error) + SavePreparams(protocolType sdkprotocol.ProtocolType, keyID string, preparams []byte) error +} + +type SharesStore interface { + LoadShare(protocolType sdkprotocol.ProtocolType, keyID string) ([]byte, error) + SaveShare(protocolType sdkprotocol.ProtocolType, keyID string, share []byte) error +} + +type SessionArtifactsStore interface { + LoadSessionArtifacts(sessionID string) ([]byte, error) + SaveSessionArtifacts(sessionID string, artifact []byte) error + DeleteSessionArtifacts(sessionID string) error +} + +type Stores interface { + PreparamsStore + SharesStore + SessionArtifactsStore + Close() error +} + type badgerStores struct { db *badger.DB } diff --git a/internal/cosigner/transport.go b/internal/cosigner/transport.go new file mode 100644 index 00000000..9c478052 --- /dev/null +++ b/internal/cosigner/transport.go @@ -0,0 +1,54 @@ +package cosigner + +import ( + "fmt" + + "github.com/nats-io/nats.go" + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +type Subscription interface { + Unsubscribe() error +} + +type Transport interface { + Subscribe(subject string, handler func([]byte)) (Subscription, error) + Publish(subject string, payload []byte) error + Flush() error + Close() + ProtocolType() sdkprotocol.TransportType +} + +type natsTransport struct { + nc *nats.Conn +} + +func NewNATSTransport(url string) (Transport, error) { + nc, err := nats.Connect(url) + if err != nil { + return nil, fmt.Errorf("connect nats: %w", err) + } + return &natsTransport{nc: nc}, nil +} + +func (t *natsTransport) Subscribe(subject string, handler func([]byte)) (Subscription, error) { + return t.nc.Subscribe(subject, func(msg *nats.Msg) { + handler(msg.Data) + }) +} + +func (t *natsTransport) Publish(subject string, payload []byte) error { + return t.nc.Publish(subject, payload) +} + +func (t *natsTransport) Flush() error { + return t.nc.Flush() +} + +func (t *natsTransport) Close() { + t.nc.Close() +} + +func (t *natsTransport) ProtocolType() sdkprotocol.TransportType { + return sdkprotocol.TransportTypeNATS +} From d00f0d8d850e02ca74c63ed819a87b3b49da6f62 Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 17 Apr 2026 14:54:10 +0700 Subject: [PATCH 05/23] Enhance logger tests to verify output includes test file name and excludes source file name. Update Error function to skip an additional stack frame for accurate caller reporting. --- pkg/logger/logger.go | 4 +++- pkg/logger/logger_test.go | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 034711e7..5f12e17f 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -116,7 +116,9 @@ func Error(msg string, err error, keyValues ...interface{}) { ctx = ctx.Interface(key, value) } - ctx.Caller().Stack().Err(err).Msg(msg) + // Skip one additional frame so caller points to the code using logger.Error, + // not this wrapper function itself. + ctx.Caller(1).Stack().Err(err).Msg(msg) } // Fatal logs a fatal message and exits the program. diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go index ab41fc3f..5fda0134 100644 --- a/pkg/logger/logger_test.go +++ b/pkg/logger/logger_test.go @@ -38,6 +38,8 @@ func TestError_WithError(t *testing.T) { assert.Contains(t, output, "test error message") assert.Contains(t, output, "level\":\"error\"") assert.Contains(t, output, "test error") + assert.Contains(t, output, "logger_test.go") + assert.NotContains(t, output, "pkg/logger/logger.go") } func TestError_WithoutError(t *testing.T) { From f8fb62193457e3ec741e6edbca90e028877928a8 Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 17 Apr 2026 15:27:24 +0700 Subject: [PATCH 06/23] Refactor cosigner to support multiple relay providers (NATS and MQTT). Introduce Relay interface for unified messaging operations, enhancing configuration management and session handling. Update runtime to utilize relay abstraction, improving flexibility and maintainability. --- internal/cosigner/config.go | 77 +++++++++------ internal/cosigner/relay.go | 30 ++++++ internal/cosigner/relay_mqtt.go | 100 +++++++++++++++++++ internal/cosigner/relay_nats.go | 45 +++++++++ internal/cosigner/runtime.go | 164 ++++++++++++++++++++++++++------ internal/cosigner/storage.go | 4 +- internal/cosigner/transport.go | 54 ----------- 7 files changed, 361 insertions(+), 113 deletions(-) create mode 100644 internal/cosigner/relay.go create mode 100644 internal/cosigner/relay_mqtt.go create mode 100644 internal/cosigner/relay_nats.go delete mode 100644 internal/cosigner/transport.go diff --git a/internal/cosigner/config.go b/internal/cosigner/config.go index 6530e8be..b09bf477 100644 --- a/internal/cosigner/config.go +++ b/internal/cosigner/config.go @@ -10,9 +10,18 @@ import ( "github.com/spf13/viper" ) +type RelayProvider string + +const ( + RelayProviderNATS RelayProvider = "nats" + RelayProviderMQTT RelayProvider = "mqtt" +) + type Config struct { + RelayProvider RelayProvider NodeID string NATSURL string + MQTT mqttConfig CoordinatorID string CoordinatorPublicKey []byte IdentityPrivateKey []byte @@ -22,29 +31,23 @@ type Config struct { TickInterval time.Duration } +// Flat keys for compact config style. type fileConfig struct { - NATS natsConfig `mapstructure:"nats"` - Cosigner cosignerConfig `mapstructure:"cosigner"` + RelayProvider RelayProvider `mapstructure:"relay_provider"` + NATSURL string `mapstructure:"nats_url"` + MQTT mqttConfig `mapstructure:"mqtt"` + NodeID string `mapstructure:"node_id"` + DataDir string `mapstructure:"data_dir"` + CoordinatorID string `mapstructure:"coordinator_id"` + CoordinatorPublicKeyHex string `mapstructure:"coordinator_public_key_hex"` + IdentityPrivateKeyHex string `mapstructure:"identity_private_key_hex"` } -type natsConfig struct { - URL string `mapstructure:"url"` -} - -type cosignerConfig struct { - NodeID string `mapstructure:"node_id"` - DataDir string `mapstructure:"data_dir"` - Coordinator coordinatorConfig `mapstructure:"coordinator"` - Identity identityConfig `mapstructure:"identity"` -} - -type coordinatorConfig struct { - ID string `mapstructure:"id"` - PublicKey string `mapstructure:"public_key_hex"` -} - -type identityConfig struct { - PrivateKey string `mapstructure:"private_key_hex"` +type mqttConfig struct { + Broker string `mapstructure:"broker"` + ClientID string `mapstructure:"client_id"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` } func LoadConfig() (Config, error) { @@ -52,24 +55,25 @@ func LoadConfig() (Config, error) { if err := viper.Unmarshal(&cfg, viper.DecodeHook(mapstructure.StringToTimeDurationHookFunc())); err != nil { return Config{}, fmt.Errorf("decode config: %w", err) } - - coordinatorKey, err := decodeHexKey(cfg.Cosigner.Coordinator.PublicKey, "coordinator public key") + coordinatorKey, err := decodeHexKey(cfg.CoordinatorPublicKeyHex, "coordinator public key") if err != nil { return Config{}, err } - privateKey, err := decodeHexKey(cfg.Cosigner.Identity.PrivateKey, "identity private key") + privateKey, err := decodeHexKey(cfg.IdentityPrivateKeyHex, "identity private key") if err != nil { return Config{}, err } runtimeCfg := Config{ - NodeID: cfg.Cosigner.NodeID, - NATSURL: cfg.NATS.URL, - CoordinatorID: cfg.Cosigner.Coordinator.ID, + RelayProvider: cfg.RelayProvider, + NodeID: cfg.NodeID, + NATSURL: cfg.NATSURL, + MQTT: cfg.MQTT, + CoordinatorID: cfg.CoordinatorID, CoordinatorPublicKey: coordinatorKey, IdentityPrivateKey: privateKey, - DataDir: cfg.Cosigner.DataDir, + DataDir: cfg.DataDir, } runtimeCfg.applyDefaults() if err := runtimeCfg.Validate(); err != nil { @@ -87,6 +91,9 @@ func decodeHexKey(value, name string) ([]byte, error) { } func (cfg *Config) applyDefaults() { + if cfg.RelayProvider == "" { + cfg.RelayProvider = RelayProviderNATS + } if cfg.MaxActiveSessions <= 0 { cfg.MaxActiveSessions = 10 } @@ -102,8 +109,20 @@ func (cfg Config) Validate() error { if cfg.NodeID == "" { return fmt.Errorf("node_id is required") } - if cfg.NATSURL == "" { - return fmt.Errorf("nats_url is required") + switch cfg.RelayProvider { + case RelayProviderNATS: + if cfg.NATSURL == "" { + return fmt.Errorf("nats_url is required for relay provider nats") + } + case RelayProviderMQTT: + if cfg.MQTT.Broker == "" { + return fmt.Errorf("mqtt.broker is required for relay provider mqtt") + } + if cfg.MQTT.ClientID == "" { + return fmt.Errorf("mqtt.client_id is required for relay provider mqtt") + } + default: + return fmt.Errorf("unsupported relay provider: %s", cfg.RelayProvider) } if cfg.CoordinatorID == "" || len(cfg.CoordinatorPublicKey) != ed25519.PublicKeySize { return fmt.Errorf("valid coordinator key is required") diff --git a/internal/cosigner/relay.go b/internal/cosigner/relay.go new file mode 100644 index 00000000..ee969445 --- /dev/null +++ b/internal/cosigner/relay.go @@ -0,0 +1,30 @@ +package cosigner + +import ( + "fmt" + + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +type Subscription interface { + Unsubscribe() error +} + +type Relay interface { + Subscribe(subject string, handler func([]byte)) (Subscription, error) + Publish(subject string, payload []byte) error + Flush() error + Close() + ProtocolType() sdkprotocol.TransportType +} + +func NewRelayFromConfig(cfg Config) (Relay, error) { + switch cfg.RelayProvider { + case RelayProviderNATS: + return NewNATSRelay(cfg.NATSURL) + case RelayProviderMQTT: + return NewMQTTRelay(cfg.MQTT) + default: + return nil, fmt.Errorf("unsupported relay provider: %s", cfg.RelayProvider) + } +} diff --git a/internal/cosigner/relay_mqtt.go b/internal/cosigner/relay_mqtt.go new file mode 100644 index 00000000..ec07b2c9 --- /dev/null +++ b/internal/cosigner/relay_mqtt.go @@ -0,0 +1,100 @@ +package cosigner + +import ( + "fmt" + "strings" + "time" + + mqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/fystack/mpcium/pkg/logger" + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +const mqttOperationTimeout = 10 * time.Second + +type mqttRelay struct { + client mqtt.Client +} + +func NewMQTTRelay(cfg mqttConfig) (Relay, error) { + opts := mqtt.NewClientOptions() + opts.AddBroker(cfg.Broker) + opts.SetClientID(cfg.ClientID) + opts.SetUsername(cfg.Username) + opts.SetPassword(cfg.Password) + opts.SetCleanSession(true) + opts.SetAutoReconnect(true) + opts.SetOrderMatters(false) + + client := mqtt.NewClient(opts) + token := client.Connect() + if !token.WaitTimeout(mqttOperationTimeout) { + return nil, fmt.Errorf("connect mqtt timeout") + } + if err := token.Error(); err != nil { + return nil, fmt.Errorf("connect mqtt: %w", err) + } + return &mqttRelay{client: client}, nil +} + +func (r *mqttRelay) Subscribe(subject string, handler func([]byte)) (Subscription, error) { + topic := natsToMQTTTopic(subject) + logger.Info("relay mqtt subscribe", "subject", subject, "topic", topic) + token := r.client.Subscribe(topic, 1, func(_ mqtt.Client, msg mqtt.Message) { + logger.Debug("relay mqtt received message", "topic", msg.Topic(), "bytes", len(msg.Payload())) + handler(append([]byte(nil), msg.Payload()...)) + }) + if !token.WaitTimeout(mqttOperationTimeout) { + return nil, fmt.Errorf("subscribe mqtt timeout topic=%s", topic) + } + if err := token.Error(); err != nil { + return nil, fmt.Errorf("subscribe mqtt topic=%s: %w", topic, err) + } + return mqttSubscription{client: r.client, topic: topic}, nil +} + +func (r *mqttRelay) Publish(subject string, payload []byte) error { + topic := natsToMQTTTopic(subject) + logger.Debug("relay mqtt publish", "subject", subject, "topic", topic) + token := r.client.Publish(topic, 1, false, payload) + if !token.WaitTimeout(mqttOperationTimeout) { + return fmt.Errorf("publish mqtt timeout topic=%s", topic) + } + if err := token.Error(); err != nil { + return fmt.Errorf("publish mqtt topic=%s: %w", topic, err) + } + return nil +} + +func (r *mqttRelay) Flush() error { + return nil +} + +func (r *mqttRelay) Close() { + r.client.Disconnect(250) +} + +func (r *mqttRelay) ProtocolType() sdkprotocol.TransportType { + return sdkprotocol.TransportTypeMQTT +} + +type mqttSubscription struct { + client mqtt.Client + topic string +} + +func (s mqttSubscription) Unsubscribe() error { + if s.client == nil || s.topic == "" { + return nil + } + unsub := s.client.Unsubscribe(s.topic) + if !unsub.WaitTimeout(mqttOperationTimeout) { + return fmt.Errorf("unsubscribe mqtt timeout") + } + return unsub.Error() +} + +func natsToMQTTTopic(subject string) string { + replacer := strings.NewReplacer(".", "/", "*", "+") + return replacer.Replace(subject) +} diff --git a/internal/cosigner/relay_nats.go b/internal/cosigner/relay_nats.go new file mode 100644 index 00000000..6d9706f4 --- /dev/null +++ b/internal/cosigner/relay_nats.go @@ -0,0 +1,45 @@ +package cosigner + +import ( + "fmt" + + "github.com/fystack/mpcium/pkg/logger" + "github.com/nats-io/nats.go" + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +type NATSRelay struct { + nc *nats.Conn +} + +func NewNATSRelay(url string) (Relay, error) { + nc, err := nats.Connect(url) + if err != nil { + return nil, fmt.Errorf("connect nats: %w", err) + } + return &NATSRelay{nc: nc}, nil +} + +func (t *NATSRelay) Subscribe(subject string, handler func([]byte)) (Subscription, error) { + logger.Info("relay nats subscribe", "subject", subject) + return t.nc.Subscribe(subject, func(msg *nats.Msg) { + handler(msg.Data) + }) +} + +func (t *NATSRelay) Publish(subject string, payload []byte) error { + logger.Debug("relay nats publish", "subject", subject) + return t.nc.Publish(subject, payload) +} + +func (t *NATSRelay) Flush() error { + return t.nc.Flush() +} + +func (t *NATSRelay) Close() { + t.nc.Close() +} + +func (t *NATSRelay) ProtocolType() sdkprotocol.TransportType { + return sdkprotocol.TransportTypeNATS +} diff --git a/internal/cosigner/runtime.go b/internal/cosigner/runtime.go index 2e528459..2ece0dc7 100644 --- a/internal/cosigner/runtime.go +++ b/internal/cosigner/runtime.go @@ -5,6 +5,7 @@ import ( "crypto/ed25519" "encoding/json" "errors" + "fmt" "strings" "sync" "time" @@ -16,7 +17,7 @@ import ( type Runtime struct { cfg Config - transport Transport + relay Relay stores Stores identity *localIdentity coordLookup *coordinatorLookup @@ -32,37 +33,30 @@ type sessionMeta struct { } func NewRuntime(cfg Config) (*Runtime, error) { - transport, err := NewNATSTransport(cfg.NATSURL) + relay, err := NewRelayFromConfig(cfg) if err != nil { return nil, err } - return NewRuntimeWithTransport(cfg, transport) -} - -func NewRuntimeWithTransport(cfg Config, transport Transport) (*Runtime, error) { - if transport == nil { - return nil, errors.New("transport is required") - } - stores, err := newBadgerStores(cfg.DataDir) + stores, err := newBadgerStores(cfg.DataDir, cfg.NodeID) if err != nil { - transport.Close() + relay.Close() return nil, err } identity, err := NewLocalIdentity(cfg.NodeID, cfg.IdentityPrivateKey) if err != nil { - transport.Close() + relay.Close() _ = stores.Close() return nil, err } coordLookup, err := NewCoordinatorLookup(cfg.CoordinatorID, cfg.CoordinatorPublicKey) if err != nil { - transport.Close() + relay.Close() _ = stores.Close() return nil, err } return &Runtime{ cfg: cfg, - transport: transport, + relay: relay, stores: stores, identity: identity, coordLookup: coordLookup, @@ -75,8 +69,8 @@ func (r *Runtime) Close() error { for _, sub := range r.subs { _ = sub.Unsubscribe() } - if r.transport != nil { - r.transport.Close() + if r.relay != nil { + r.relay.Close() } if r.stores != nil { return r.stores.Close() @@ -101,7 +95,7 @@ func (r *Runtime) Run(ctx context.Context) error { select { case <-ctx.Done(): logger.Info("cosigner runtime stopping", "node_id", r.cfg.NodeID) - _ = r.publishPresence(sdkprotocol.PresenceStatusOffline) + r.publishPresenceOnShutdown() return nil case <-tick.C: if err := r.tickSessions(); err != nil { @@ -116,7 +110,7 @@ func (r *Runtime) Run(ctx context.Context) error { } func (r *Runtime) subscribe() error { - controlSub, err := r.transport.Subscribe(controlSubject(r.cfg.NodeID), func(raw []byte) { + controlSub, err := r.relay.Subscribe(controlSubject(r.cfg.NodeID), func(raw []byte) { if err := r.handleControl(raw); err != nil { logger.Error("handle control message failed", err) } @@ -124,10 +118,9 @@ func (r *Runtime) subscribe() error { if err != nil { return err } - logger.Info("subscribed control subject", "subject", controlSubject(r.cfg.NodeID)) r.subs = append(r.subs, controlSub) - p2pSub, err := r.transport.Subscribe(p2pWildcardSubject(r.cfg.NodeID), func(raw []byte) { + p2pSub, err := r.relay.Subscribe(p2pWildcardSubject(r.cfg.NodeID), func(raw []byte) { if err := r.handlePeer(raw); err != nil { logger.Error("handle peer message failed", err) } @@ -135,10 +128,9 @@ func (r *Runtime) subscribe() error { if err != nil { return err } - logger.Info("subscribed p2p subject", "subject", p2pWildcardSubject(r.cfg.NodeID)) r.subs = append(r.subs, p2pSub) - return r.transport.Flush() + return r.relay.Flush() } func (r *Runtime) handleControl(raw []byte) error { @@ -147,6 +139,21 @@ func (r *Runtime) handleControl(raw []byte) error { return err } if err := sdkprotocol.ValidateControlMessage(&msg); err != nil { + if !hasControlBody(&msg) { + logger.Warn("ignoring control message without body") + return nil + } + logger.Error("invalid control message received", err, + "node_id", r.cfg.NodeID, + "session_id", msg.SessionID, + "sequence", msg.Sequence, + "coordinator_id", msg.CoordinatorID, + "has_session_start", msg.SessionStart != nil, + "has_key_exchange", msg.KeyExchange != nil, + "has_mpc_begin", msg.MPCBegin != nil, + "has_session_abort", msg.SessionAbort != nil, + "raw_control_json", string(raw), + ) return err } @@ -175,8 +182,35 @@ func (r *Runtime) handleControl(raw []byte) error { logger.Warn("ignoring control for unknown session", "session_id", msg.SessionID) return nil } + if msg.SessionAbort != nil { + // Current SDK participant session doesn't handle SessionAbort control messages. + // Treat abort as terminal, clean up local session state, and stop processing. + logger.Warn("cosigner received session abort", + "node_id", r.cfg.NodeID, + "session_id", msg.SessionID, + "reason", msg.SessionAbort.Reason, + "detail", msg.SessionAbort.Detail, + ) + logger.Info("cosigner session ended", + "node_id", r.cfg.NodeID, + "session_id", msg.SessionID, + "outcome", "aborted", + "reason", msg.SessionAbort.Reason, + ) + r.dropSessionMeta(msg.SessionID) + _ = r.stores.DeleteSessionArtifacts(msg.SessionID) + return nil + } actions, err := session.HandleControl(&msg) if err != nil { + logger.Error("session handle control failed", err, + "node_id", r.cfg.NodeID, + "session_id", msg.SessionID, + "sequence", msg.Sequence, + "coordinator_id", msg.CoordinatorID, + "control_type", controlType(&msg), + "raw_control_json", string(raw), + ) return err } return r.dispatchActions(actions) @@ -269,26 +303,43 @@ func (r *Runtime) tickSessions() error { } func (r *Runtime) dispatchActions(actions participant.Actions) error { - logger.Debug("dispatching actions", "actions", actions) for _, peerMsg := range actions.PeerMessages { raw, err := json.Marshal(peerMsg) if err != nil { return err } - if err := r.transport.Publish(p2pSubject(peerMsg.ToParticipantID, peerMsg.SessionID), raw); err != nil { + if err := r.relay.Publish(p2pSubject(peerMsg.ToParticipantID, peerMsg.SessionID), raw); err != nil { return err } } for _, event := range actions.SessionEvents { - raw, err := json.Marshal(event) + sanitized, err := sanitizeAndResignSessionEvent(event, r.cfg.IdentityPrivateKey) if err != nil { return err } - if err := r.transport.Publish(sessionEventSubject(event.SessionID), raw); err != nil { + raw, err := json.Marshal(sanitized) + if err != nil { + return err + } + if err := r.relay.Publish(sessionEventSubject(sanitized.SessionID), raw); err != nil { return err } } if actions.Cleanup != nil && actions.Cleanup.DropArtifacts { + outcome := "cleanup" + if actions.Result != nil { + switch { + case actions.Result.KeyShare != nil: + outcome = "completed_keygen" + case actions.Result.Signature != nil: + outcome = "completed_sign" + } + } + logger.Info("cosigner session ended", + "node_id", r.cfg.NodeID, + "session_id", actions.Cleanup.SessionID, + "outcome", outcome, + ) r.dropSessionMeta(actions.Cleanup.SessionID) _ = r.stores.DeleteSessionArtifacts(actions.Cleanup.SessionID) } @@ -328,6 +379,38 @@ func controlType(msg *sdkprotocol.ControlMessage) string { } } +func sanitizeSessionEvent(event *sdkprotocol.SessionEvent) *sdkprotocol.SessionEvent { + if event == nil || event.SessionCompleted == nil || event.SessionCompleted.Result == nil || event.SessionCompleted.Result.KeyShare == nil { + return event + } + clone := *event + completed := *event.SessionCompleted + result := *event.SessionCompleted.Result + keyShare := *event.SessionCompleted.Result.KeyShare + // Never publish secret share material over relay topics. + keyShare.ShareBlob = nil + result.KeyShare = &keyShare + completed.Result = &result + clone.SessionCompleted = &completed + return &clone +} + +func sanitizeAndResignSessionEvent(event *sdkprotocol.SessionEvent, privateKey []byte) (*sdkprotocol.SessionEvent, error) { + sanitized := sanitizeSessionEvent(event) + if sanitized == nil || sanitized == event { + return event, nil + } + if len(privateKey) != ed25519.PrivateKeySize { + return nil, fmt.Errorf("invalid identity private key size: %d", len(privateKey)) + } + payload, err := sdkprotocol.SessionEventSigningBytes(sanitized) + if err != nil { + return nil, err + } + sanitized.Signature = ed25519.Sign(ed25519.PrivateKey(privateKey), payload) + return sanitized, nil +} + func protocolLabel(protocol sdkprotocol.ProtocolType) string { value := strings.TrimSpace(string(protocol)) if value == "" || value == string(sdkprotocol.ProtocolTypeUnspecified) { @@ -350,7 +433,7 @@ func actionLabel(operation sdkprotocol.OperationType) string { } func (r *Runtime) publishPresence(status sdkprotocol.PresenceStatus) error { - transportType := r.transport.ProtocolType() + transportType := r.relay.ProtocolType() connectionPrefix := strings.ToLower(string(transportType)) if connectionPrefix == "" || transportType == sdkprotocol.TransportTypeUnspecified { connectionPrefix = "transport" @@ -368,7 +451,7 @@ func (r *Runtime) publishPresence(status sdkprotocol.PresenceStatus) error { if err != nil { return err } - return r.transport.Publish(presenceSubject(r.cfg.NodeID), raw) + return r.relay.Publish(presenceSubject(r.cfg.NodeID), raw) } func (r *Runtime) getSession(sessionID string) *participant.ParticipantSession { @@ -391,3 +474,28 @@ func (r *Runtime) verifyControlSignature(msg *sdkprotocol.ControlMessage) error } return nil } + +func (r *Runtime) publishPresenceOnShutdown() { + done := make(chan error, 1) + go func() { + done <- r.publishPresence(sdkprotocol.PresenceStatusOffline) + }() + select { + case err := <-done: + if err != nil { + logger.Warn("failed to publish offline presence", "error", err) + } + case <-time.After(500 * time.Millisecond): + logger.Warn("timed out publishing offline presence") + } +} + +func hasControlBody(msg *sdkprotocol.ControlMessage) bool { + if msg == nil { + return false + } + return msg.SessionStart != nil || + msg.KeyExchange != nil || + msg.MPCBegin != nil || + msg.SessionAbort != nil +} diff --git a/internal/cosigner/storage.go b/internal/cosigner/storage.go index babe5feb..a801891e 100644 --- a/internal/cosigner/storage.go +++ b/internal/cosigner/storage.go @@ -35,8 +35,8 @@ type badgerStores struct { db *badger.DB } -func newBadgerStores(dataDir string) (*badgerStores, error) { - opts := badger.DefaultOptions(filepath.Join(dataDir, "node-v1-badger")) +func newBadgerStores(dataDir string, nodeID string) (*badgerStores, error) { + opts := badger.DefaultOptions(filepath.Join(dataDir, nodeID)) opts.Logger = nil db, err := badger.Open(opts) if err != nil { diff --git a/internal/cosigner/transport.go b/internal/cosigner/transport.go deleted file mode 100644 index 9c478052..00000000 --- a/internal/cosigner/transport.go +++ /dev/null @@ -1,54 +0,0 @@ -package cosigner - -import ( - "fmt" - - "github.com/nats-io/nats.go" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" -) - -type Subscription interface { - Unsubscribe() error -} - -type Transport interface { - Subscribe(subject string, handler func([]byte)) (Subscription, error) - Publish(subject string, payload []byte) error - Flush() error - Close() - ProtocolType() sdkprotocol.TransportType -} - -type natsTransport struct { - nc *nats.Conn -} - -func NewNATSTransport(url string) (Transport, error) { - nc, err := nats.Connect(url) - if err != nil { - return nil, fmt.Errorf("connect nats: %w", err) - } - return &natsTransport{nc: nc}, nil -} - -func (t *natsTransport) Subscribe(subject string, handler func([]byte)) (Subscription, error) { - return t.nc.Subscribe(subject, func(msg *nats.Msg) { - handler(msg.Data) - }) -} - -func (t *natsTransport) Publish(subject string, payload []byte) error { - return t.nc.Publish(subject, payload) -} - -func (t *natsTransport) Flush() error { - return t.nc.Flush() -} - -func (t *natsTransport) Close() { - t.nc.Close() -} - -func (t *natsTransport) ProtocolType() sdkprotocol.TransportType { - return sdkprotocol.TransportTypeNATS -} From 5d1d03aa5708c11b9bad45460098f3dc0e44308b Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 17 Apr 2026 15:33:14 +0700 Subject: [PATCH 07/23] Add initial implementation of MQTT relay with NATS integration. Introduce configuration management for relay settings, including credential loading and topic mapping. Implement runtime for handling MQTT and NATS messaging, enhancing session management and peer communication. --- cmd/mpcium-relay/main.go | 57 ++++++ internal/relay/auth.go | 56 ++++++ internal/relay/config.go | 122 +++++++++++++ internal/relay/runtime.go | 376 ++++++++++++++++++++++++++++++++++++++ internal/relay/topics.go | 151 +++++++++++++++ 5 files changed, 762 insertions(+) create mode 100644 cmd/mpcium-relay/main.go create mode 100644 internal/relay/auth.go create mode 100644 internal/relay/config.go create mode 100644 internal/relay/runtime.go create mode 100644 internal/relay/topics.go diff --git a/cmd/mpcium-relay/main.go b/cmd/mpcium-relay/main.go new file mode 100644 index 00000000..4dd984a4 --- /dev/null +++ b/cmd/mpcium-relay/main.go @@ -0,0 +1,57 @@ +package main + +import ( + "context" + "os" + "os/signal" + "syscall" + + "github.com/fystack/mpcium/internal/relay" + "github.com/fystack/mpcium/pkg/config" + "github.com/fystack/mpcium/pkg/logger" + "github.com/urfave/cli/v3" +) + +const relayConfigPath = "relay.config.yaml" + +func main() { + logger.Init(os.Getenv("ENVIRONMENT"), false) + + cmd := &cli.Command{ + Name: "mpcium-relay", + Usage: "Run MQTT relay runtime", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "config", + Aliases: []string{"c"}, + Usage: "Path to relay config file", + Value: relayConfigPath, + }, + }, + Action: run, + } + + if err := cmd.Run(context.Background(), os.Args); err != nil { + logger.Error("relay exited with error", err) + os.Exit(1) + } +} + +func run(ctx context.Context, c *cli.Command) error { + configPath := c.String("config") + config.InitViperConfig(configPath) + cfg, err := relay.LoadConfig() + if err != nil { + return err + } + + runtime, err := relay.NewRuntime(cfg) + if err != nil { + return err + } + defer runtime.Close() + + ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) + defer stop() + return runtime.Run(ctx) +} diff --git a/internal/relay/auth.go b/internal/relay/auth.go new file mode 100644 index 00000000..e32e94d3 --- /dev/null +++ b/internal/relay/auth.go @@ -0,0 +1,56 @@ +package relay + +import ( + "bufio" + "fmt" + "os" + "strings" +) + +type credentialsStore struct { + values map[string]string +} + +func loadCredentials(path string) (*credentialsStore, error) { + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open credentials file: %w", err) + } + defer file.Close() + + values := map[string]string{} + scanner := bufio.NewScanner(file) + line := 0 + for scanner.Scan() { + line++ + raw := strings.TrimSpace(scanner.Text()) + if raw == "" || strings.HasPrefix(raw, "#") { + continue + } + parts := strings.SplitN(raw, ":", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid credentials format at line %d", line) + } + username := strings.TrimSpace(parts[0]) + password := strings.TrimSpace(parts[1]) + if username == "" || password == "" { + return nil, fmt.Errorf("invalid credentials format at line %d", line) + } + values[username] = password + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("read credentials file: %w", err) + } + if len(values) == 0 { + return nil, fmt.Errorf("credentials file has no entries") + } + return &credentialsStore{values: values}, nil +} + +func (s *credentialsStore) check(username, password string) bool { + if s == nil { + return false + } + expected, ok := s.values[username] + return ok && expected == password +} diff --git a/internal/relay/config.go b/internal/relay/config.go new file mode 100644 index 00000000..df6d70e4 --- /dev/null +++ b/internal/relay/config.go @@ -0,0 +1,122 @@ +package relay + +import ( + "fmt" + "strings" + + "github.com/spf13/viper" +) + +type RuntimeConfig struct { + NATS NATSConfig `mapstructure:"nats"` + MQTT MQTTConfig `mapstructure:"relay.mqtt"` + Bridge BridgeConfig `mapstructure:"relay.bridge"` + Presence PresenceConfig `mapstructure:"relay.presence"` +} + +type NATSConfig struct { + URL string `mapstructure:"url"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + TLS *TLSConfig `mapstructure:"tls"` +} + +type TLSConfig struct { + ClientCert string `mapstructure:"client_cert"` + ClientKey string `mapstructure:"client_key"` + CACert string `mapstructure:"ca_cert"` +} + +type MQTTConfig struct { + ListenAddress string `mapstructure:"listen_address"` + UsernamePasswordFile string `mapstructure:"username_password_file"` +} + +type BridgeConfig struct { + NATSPrefix string `mapstructure:"nats_prefix"` + MQTTPrefix string `mapstructure:"mqtt_prefix"` + MQTTQoS byte `mapstructure:"mqtt_qos"` + OriginHeader string `mapstructure:"origin_header"` +} + +type PresenceConfig struct { + EmitConnectDisconnect bool `mapstructure:"emit_connect_disconnect"` +} + +func LoadConfig() (RuntimeConfig, error) { + setDefaults() + + var cfg RuntimeConfig + if err := viper.Unmarshal(&cfg); err != nil { + return RuntimeConfig{}, fmt.Errorf("decode relay config: %w", err) + } + + cfg.normalize() + + if err := cfg.Validate(); err != nil { + return RuntimeConfig{}, err + } + + return cfg, nil +} + +func setDefaults() { + viper.SetDefault("relay.mqtt.listen_address", ":1883") + viper.SetDefault("relay.bridge.nats_prefix", "mpc.v1") + viper.SetDefault("relay.bridge.mqtt_prefix", "mpc/v1") + viper.SetDefault("relay.bridge.mqtt_qos", 1) + viper.SetDefault("relay.bridge.origin_header", "X-MPCIUM-Relay-Origin") + viper.SetDefault("relay.presence.emit_connect_disconnect", true) +} + +func (cfg *RuntimeConfig) normalize() { + cfg.NATS.URL = strings.TrimSpace(cfg.NATS.URL) + cfg.NATS.Username = strings.TrimSpace(cfg.NATS.Username) + cfg.NATS.Password = strings.TrimSpace(cfg.NATS.Password) + + if cfg.NATS.TLS != nil { + cfg.NATS.TLS.ClientCert = strings.TrimSpace(cfg.NATS.TLS.ClientCert) + cfg.NATS.TLS.ClientKey = strings.TrimSpace(cfg.NATS.TLS.ClientKey) + cfg.NATS.TLS.CACert = strings.TrimSpace(cfg.NATS.TLS.CACert) + } + + cfg.MQTT.ListenAddress = strings.TrimSpace(cfg.MQTT.ListenAddress) + cfg.MQTT.UsernamePasswordFile = strings.TrimSpace(cfg.MQTT.UsernamePasswordFile) + + cfg.Bridge.NATSPrefix = strings.TrimSpace(cfg.Bridge.NATSPrefix) + cfg.Bridge.MQTTPrefix = strings.TrimSpace(cfg.Bridge.MQTTPrefix) + cfg.Bridge.OriginHeader = strings.TrimSpace(cfg.Bridge.OriginHeader) +} + +func (cfg RuntimeConfig) Validate() error { + if cfg.NATS.URL == "" { + return fmt.Errorf("nats.url is required") + } + if cfg.MQTT.ListenAddress == "" { + return fmt.Errorf("relay.mqtt.listen_address is required") + } + if cfg.MQTT.UsernamePasswordFile == "" { + return fmt.Errorf("relay.mqtt.username_password_file is required") + } + if cfg.Bridge.NATSPrefix == "" { + return fmt.Errorf("relay.bridge.nats_prefix is required") + } + if cfg.Bridge.MQTTPrefix == "" { + return fmt.Errorf("relay.bridge.mqtt_prefix is required") + } + if cfg.Bridge.OriginHeader == "" { + return fmt.Errorf("relay.bridge.origin_header is required") + } + if cfg.Bridge.MQTTQoS > 2 { + return fmt.Errorf("relay.bridge.mqtt_qos must be 0, 1, or 2") + } + if cfg.NATS.TLS != nil { + if cfg.NATS.TLS.ClientCert == "" { + return fmt.Errorf("nats.tls.client_cert is required when nats.tls is set") + } + if cfg.NATS.TLS.ClientKey == "" { + return fmt.Errorf("nats.tls.client_key is required when nats.tls is set") + } + } + return nil +} diff --git a/internal/relay/runtime.go b/internal/relay/runtime.go new file mode 100644 index 00000000..e1617f4b --- /dev/null +++ b/internal/relay/runtime.go @@ -0,0 +1,376 @@ +package relay + +import ( + "context" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "log/slog" + "os" + "strings" + "sync" + "time" + + "github.com/fystack/mpcium/pkg/logger" + mqtt "github.com/mochi-mqtt/server/v2" + "github.com/mochi-mqtt/server/v2/listeners" + "github.com/mochi-mqtt/server/v2/packets" + "github.com/nats-io/nats.go" + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +const mqttOriginValue = "mqtt" + +type Runtime struct { + cfg RuntimeConfig + nc *nats.Conn + mqttServer *mqtt.Server + mapper topicMapper + credentials *credentialsStore + subs []*nats.Subscription + subsMu sync.Mutex + echoMu sync.Mutex + recentEcho map[string]time.Time + closeOnce sync.Once + closeErr error +} + +func NewRuntime(cfg RuntimeConfig) (*Runtime, error) { + credentials, err := loadCredentials(cfg.MQTT.UsernamePasswordFile) + if err != nil { + return nil, err + } + + nc, err := connectNATS(cfg.NATS) + if err != nil { + return nil, err + } + + mochiLevel := new(slog.LevelVar) + mochiLevel.Set(slog.LevelError) + mochiLogger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: mochiLevel})) + + server := mqtt.New(&mqtt.Options{ + InlineClient: true, + Logger: mochiLogger, + }) + r := &Runtime{ + cfg: cfg, + nc: nc, + mqttServer: server, + mapper: newTopicMapper(cfg.Bridge.NATSPrefix, cfg.Bridge.MQTTPrefix), + credentials: credentials, + recentEcho: map[string]time.Time{}, + } + + hook := &relayHook{runtime: r} + if err := r.mqttServer.AddHook(hook, nil); err != nil { + _ = nc.Drain() + return nil, fmt.Errorf("add relay hook: %w", err) + } + + tcp := listeners.NewTCP(listeners.Config{ID: "mpcium-relay", Address: cfg.MQTT.ListenAddress}) + if err := r.mqttServer.AddListener(tcp); err != nil { + _ = nc.Drain() + return nil, fmt.Errorf("add mqtt listener: %w", err) + } + + return r, nil +} + +func (r *Runtime) Run(ctx context.Context) error { + if err := r.subscribeNATS(); err != nil { + return err + } + if err := r.subscribeMQTTInline(); err != nil { + return err + } + if err := r.mqttServer.Serve(); err != nil { + return fmt.Errorf("mqtt server stopped: %w", err) + } + + logger.Info("relay runtime started", "mqtt_listen", r.cfg.MQTT.ListenAddress, "nats_url", r.cfg.NATS.URL) + + <-ctx.Done() + return r.Close() +} + +func (r *Runtime) Close() error { + r.closeOnce.Do(func() { + r.subsMu.Lock() + for _, sub := range r.subs { + _ = sub.Unsubscribe() + } + r.subs = nil + r.subsMu.Unlock() + + if r.mqttServer != nil { + if err := r.mqttServer.Close(); err != nil { + r.closeErr = err + } + } + if r.nc != nil && !r.nc.IsClosed() { + if err := r.nc.Drain(); err != nil { + r.nc.Close() + } + } + }) + return r.closeErr +} + +func (r *Runtime) subscribeNATS() error { + for _, filter := range []string{r.mapper.natsControlFilter(), r.mapper.natsP2PFilter()} { + filter := filter + logger.Info("relay subscribed NATS filter", "filter", filter) + sub, err := r.nc.Subscribe(filter, func(msg *nats.Msg) { + if strings.EqualFold(msg.Header.Get(r.cfg.Bridge.OriginHeader), mqttOriginValue) { + return + } + topic, ok := r.mapper.natsToMQTT(msg.Subject) + if !ok { + return + } + logger.Debug("relay bridge NATS->MQTT", "subject", msg.Subject, "topic", topic, "bytes", len(msg.Data)) + r.markNATSEcho(topic, msg.Data) + if err := r.mqttServer.Publish(topic, msg.Data, false, r.cfg.Bridge.MQTTQoS); err != nil { + logger.Error("relay publish NATS->MQTT failed", err, "subject", msg.Subject, "topic", topic) + } + }) + if err != nil { + return fmt.Errorf("subscribe nats filter %s: %w", filter, err) + } + r.subsMu.Lock() + r.subs = append(r.subs, sub) + r.subsMu.Unlock() + } + + if err := r.nc.Flush(); err != nil { + return fmt.Errorf("flush nats subscriptions: %w", err) + } + return nil +} + +func (r *Runtime) subscribeMQTTInline() error { + for idx, filter := range []string{r.mapper.mqttP2PFilter(), r.mapper.mqttSessionEventFilter(), r.mapper.mqttPresenceFilter()} { + filter := filter + subID := idx + 1 + logger.Info("relay subscribed MQTT filter", "filter", filter, "sub_id", subID) + if err := r.mqttServer.Subscribe(filter, subID, func(cl *mqtt.Client, _ packets.Subscription, pk packets.Packet) { + clientID := "unknown" + if cl != nil { + clientID = cl.ID + } + if cl == nil { + logger.Debug("relay mqtt callback without client context", "topic", pk.TopicName) + } + if pk.TopicName == "" { + logger.Warn("relay mqtt callback empty topic", "client_id", clientID) + return + } + if r.isRecentNATSEcho(pk.TopicName, pk.Payload) { + logger.Debug("relay skipping echoed NATS->MQTT message", "topic", pk.TopicName, "bytes", len(pk.Payload)) + return + } + subject, ok := r.mapper.mqttToNATS(pk.TopicName) + if !ok { + logger.Warn("relay mqtt topic rejected by mapper", "client_id", clientID, "topic", pk.TopicName) + return + } + logger.Debug("relay bridge MQTT->NATS", "client_id", clientID, "topic", pk.TopicName, "subject", subject, "bytes", len(pk.Payload)) + msg := &nats.Msg{ + Subject: subject, + Data: append([]byte(nil), pk.Payload...), + Header: nats.Header{}, + } + msg.Header.Set(r.cfg.Bridge.OriginHeader, mqttOriginValue) + if err := r.nc.PublishMsg(msg); err != nil { + logger.Error("relay publish MQTT->NATS failed", err, "topic", pk.TopicName, "subject", subject) + } + }); err != nil { + return fmt.Errorf("subscribe mqtt inline filter %s: %w", filter, err) + } + } + return nil +} + +func (r *Runtime) publishPresence(peerID string, status sdkprotocol.PresenceStatus) { + if !r.cfg.Presence.EmitConnectDisconnect { + return + } + event := sdkprotocol.PresenceEvent{ + PeerID: peerID, + Status: status, + Transport: sdkprotocol.TransportTypeMQTT, + LastSeenUnixMs: time.Now().UTC().UnixMilli(), + } + if status == sdkprotocol.PresenceStatusOnline { + event.ConnectionID = "mqtt:" + peerID + } + raw, err := json.Marshal(event) + if err != nil { + logger.Error("relay marshal presence failed", err, "peer_id", peerID) + return + } + subject := r.mapper.natsPresenceSubject(peerID) + if err := r.nc.Publish(subject, raw); err != nil { + logger.Error("relay publish presence failed", err, "subject", subject, "peer_id", peerID) + } +} + +type relayHook struct { + mqtt.HookBase + runtime *Runtime +} + +func (h *relayHook) ID() string { + return "mpcium-relay" +} + +func (h *relayHook) Provides(b byte) bool { + supported := []byte{mqtt.OnConnectAuthenticate, mqtt.OnACLCheck, mqtt.OnSessionEstablished, mqtt.OnDisconnect} + for _, item := range supported { + if item == b { + return true + } + } + return false +} + +func (h *relayHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool { + if cl == nil { + return false + } + username := string(pk.Connect.Username) + password := string(pk.Connect.Password) + if username == "" || username != cl.ID { + logger.Warn("relay mqtt auth rejected", "client_id", cl.ID, "reason", "username must equal client_id") + return false + } + ok := h.runtime.credentials.check(username, password) + if !ok { + logger.Warn("relay mqtt auth rejected", "client_id", cl.ID, "reason", "bad username or password") + } + return ok +} + +func (h *relayHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool { + if cl == nil { + return false + } + if write { + allowed := h.runtime.mapper.allowMQTTWrite(topic) + logger.Debug("relay mqtt acl check", "client_id", cl.ID, "write", true, "topic", topic, "allowed", allowed) + return allowed + } + allowed := h.runtime.mapper.allowMQTTRead(cl.ID, topic) + logger.Debug("relay mqtt acl check", "client_id", cl.ID, "write", false, "topic", topic, "allowed", allowed) + return allowed +} + +func (h *relayHook) OnConnect(cl *mqtt.Client, _ packets.Packet) error { + // Keep hook for compatibility, but only treat a client as online after + // session establishment to avoid logging "connected" on auth failures. + return nil +} + +func (h *relayHook) OnSessionEstablished(cl *mqtt.Client, _ packets.Packet) { + if cl == nil || cl.ID == mqtt.InlineClientId { + return + } + h.runtime.publishPresence(cl.ID, sdkprotocol.PresenceStatusOnline) + logger.Info("relay mqtt client connected", "client_id", cl.ID) +} + +func (h *relayHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) { + if cl == nil || cl.ID == mqtt.InlineClientId { + return + } + h.runtime.publishPresence(cl.ID, sdkprotocol.PresenceStatusOffline) + logger.Info("relay mqtt client disconnected", "client_id", cl.ID, "expire", expire) +} + +func connectNATS(cfg NATSConfig) (*nats.Conn, error) { + opts := []nats.Option{ + nats.MaxReconnects(-1), + nats.ReconnectWait(2 * time.Second), + } + if cfg.Username != "" { + opts = append(opts, nats.UserInfo(cfg.Username, cfg.Password)) + } + if cfg.TLS != nil { + tlsCfg, err := buildTLSConfig(cfg.TLS) + if err != nil { + return nil, err + } + opts = append(opts, nats.Secure(tlsCfg)) + } + nc, err := nats.Connect(cfg.URL, opts...) + if err != nil { + return nil, fmt.Errorf("connect nats: %w", err) + } + return nc, nil +} + +func buildTLSConfig(cfg *TLSConfig) (*tls.Config, error) { + tlsCfg := &tls.Config{MinVersion: tls.VersionTLS12} + if cfg.CACert != "" { + caPem, err := os.ReadFile(cfg.CACert) + if err != nil { + return nil, fmt.Errorf("read nats ca cert: %w", err) + } + pool := x509.NewCertPool() + if ok := pool.AppendCertsFromPEM(caPem); !ok { + return nil, fmt.Errorf("parse nats ca cert") + } + tlsCfg.RootCAs = pool + } + if cfg.ClientCert != "" || cfg.ClientKey != "" { + if cfg.ClientCert == "" || cfg.ClientKey == "" { + return nil, fmt.Errorf("both nats tls client_cert and client_key are required") + } + cert, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey) + if err != nil { + return nil, fmt.Errorf("load nats client cert: %w", err) + } + tlsCfg.Certificates = []tls.Certificate{cert} + } + return tlsCfg, nil +} + +func echoKey(topic string, payload []byte) string { + sum := sha256.Sum256(payload) + return fmt.Sprintf("%s|%x", topic, sum[:]) +} + +func (r *Runtime) markNATSEcho(topic string, payload []byte) { + key := echoKey(topic, payload) + now := time.Now() + r.echoMu.Lock() + r.recentEcho[key] = now + // Best-effort cleanup of stale markers. + for k, ts := range r.recentEcho { + if now.Sub(ts) > 3*time.Second { + delete(r.recentEcho, k) + } + } + r.echoMu.Unlock() +} + +func (r *Runtime) isRecentNATSEcho(topic string, payload []byte) bool { + key := echoKey(topic, payload) + now := time.Now() + r.echoMu.Lock() + defer r.echoMu.Unlock() + ts, ok := r.recentEcho[key] + if !ok { + return false + } + if now.Sub(ts) > 3*time.Second { + delete(r.recentEcho, key) + return false + } + delete(r.recentEcho, key) + return true +} diff --git a/internal/relay/topics.go b/internal/relay/topics.go new file mode 100644 index 00000000..11be1aeb --- /dev/null +++ b/internal/relay/topics.go @@ -0,0 +1,151 @@ +package relay + +import ( + "fmt" + "strings" +) + +const ( + natsControlSuffix = ".peer.*.control" + natsP2PSuffix = ".peer.*.session.*.p2p" + + mqttP2PFilterSuffix = "/peer/+/session/+/p2p" + mqttEventFilterSuffix = "/session/+/event" + mqttPresenceFilterSuffix = "/peer/+/presence" +) + +type topicMapper struct { + natsPrefix string + mqttPrefix string +} + +func newTopicMapper(natsPrefix, mqttPrefix string) topicMapper { + return topicMapper{ + natsPrefix: strings.Trim(natsPrefix, "."), + mqttPrefix: strings.Trim(mqttPrefix, "/"), + } +} + +func (m topicMapper) natsToMQTT(subject string) (string, bool) { + subject = strings.TrimSpace(subject) + if subject == "" { + return "", false + } + if !strings.HasPrefix(subject, m.natsPrefix+".") { + return "", false + } + if !m.allowNATSBridgeSubject(subject) { + return "", false + } + trimmed := strings.TrimPrefix(subject, m.natsPrefix+".") + return m.mqttPrefix + "/" + strings.ReplaceAll(trimmed, ".", "/"), true +} + +func (m topicMapper) mqttToNATS(topic string) (string, bool) { + topic = strings.Trim(strings.TrimSpace(topic), "/") + if topic == "" { + return "", false + } + if !strings.HasPrefix(topic, m.mqttPrefix+"/") { + return "", false + } + if !m.allowMQTTBridgeTopic(topic) { + return "", false + } + trimmed := strings.TrimPrefix(topic, m.mqttPrefix+"/") + return m.natsPrefix + "." + strings.ReplaceAll(trimmed, "/", "."), true +} + +func (m topicMapper) natsControlFilter() string { + return m.natsPrefix + natsControlSuffix +} + +func (m topicMapper) natsP2PFilter() string { + return m.natsPrefix + natsP2PSuffix +} + +func (m topicMapper) mqttP2PFilter() string { + return m.mqttPrefix + mqttP2PFilterSuffix +} + +func (m topicMapper) mqttSessionEventFilter() string { + return m.mqttPrefix + mqttEventFilterSuffix +} + +func (m topicMapper) mqttPresenceFilter() string { + return m.mqttPrefix + mqttPresenceFilterSuffix +} + +func (m topicMapper) natsPresenceSubject(peerID string) string { + return fmt.Sprintf("%s.peer.%s.presence", m.natsPrefix, peerID) +} + +func (m topicMapper) allowNATSBridgeSubject(subject string) bool { + parts := strings.Split(subject, ".") + prefix := strings.Split(m.natsPrefix, ".") + if len(parts) < len(prefix)+3 { + return false + } + for i := range prefix { + if parts[i] != prefix[i] { + return false + } + } + rel := parts[len(prefix):] + if len(rel) == 3 && rel[0] == "peer" && rel[2] == "control" { + return rel[1] != "" + } + if len(rel) == 5 && rel[0] == "peer" && rel[2] == "session" && rel[4] == "p2p" { + return rel[1] != "" && rel[3] != "" + } + return false +} + +func (m topicMapper) allowMQTTBridgeTopic(topic string) bool { + parts := strings.Split(strings.Trim(topic, "/"), "/") + prefix := strings.Split(m.mqttPrefix, "/") + if len(parts) < len(prefix)+3 { + return false + } + for i := range prefix { + if parts[i] != prefix[i] { + return false + } + } + rel := parts[len(prefix):] + if len(rel) == 5 && rel[0] == "peer" && rel[2] == "session" && rel[4] == "p2p" { + return rel[1] != "" && rel[3] != "" + } + if len(rel) == 3 && rel[0] == "peer" && rel[2] == "presence" { + return rel[1] != "" + } + if len(rel) == 3 && rel[0] == "session" && rel[2] == "event" { + return rel[1] != "" + } + return false +} + +func (m topicMapper) allowMQTTRead(clientID, topic string) bool { + parts := strings.Split(strings.Trim(topic, "/"), "/") + prefix := strings.Split(m.mqttPrefix, "/") + if len(parts) < len(prefix)+3 { + return false + } + for i := range prefix { + if parts[i] != prefix[i] { + return false + } + } + rel := parts[len(prefix):] + if len(rel) == 3 && rel[0] == "peer" && rel[2] == "control" { + return rel[1] == clientID + } + if len(rel) == 5 && rel[0] == "peer" && rel[2] == "session" && rel[4] == "p2p" { + return rel[1] == clientID + } + return false +} + +func (m topicMapper) allowMQTTWrite(topic string) bool { + return m.allowMQTTBridgeTopic(topic) +} From ac85deabb2f3ddb118e6e91d4e7f71f732bc388a Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 17 Apr 2026 15:42:38 +0700 Subject: [PATCH 08/23] Refactor NATS configuration in cosigner to use structured natsConfig type. Enhance NATS relay initialization with TLS support and credential management. Update validation logic to ensure required fields are checked for NATS relay provider. --- cosigner.config.yaml | 28 ++++++++++++------- cosigner2.config.yaml | 26 +++++++++++------- internal/cosigner/config.go | 41 ++++++++++++++++++++++++---- internal/cosigner/relay.go | 2 +- internal/cosigner/relay_nats.go | 48 +++++++++++++++++++++++++++++++-- relay.config.yaml | 20 ++++++++++++++ 6 files changed, 138 insertions(+), 27 deletions(-) create mode 100644 relay.config.yaml diff --git a/cosigner.config.yaml b/cosigner.config.yaml index 57af339b..7ef841f3 100644 --- a/cosigner.config.yaml +++ b/cosigner.config.yaml @@ -1,11 +1,21 @@ +relay_provider: nats nats: - url: nats://127.0.0.1:4222 + url: "nats://127.0.0.1:4222" + # username: "" + # password: "" + # tls: + # client_cert: "" + # client_key: "" + # ca_cert: "" +# mqtt: +# broker: "tcp://localhost:1883" +# client_id: "cosigner-1" +# username: "" +# password: "" -cosigner: - node_id: peer-node-01 - coordinator: - id: coordinator-01 - public_key_hex: "b64ca8ec459081a299aecc2b2b5d555265b15ddfd29e792ddd08bedb418bdd0d" - identity: - private_key_hex: "b14d168636008a9c766a6c231c182446e4b636cd2116817a89d068ffb5cc49e456a47a1103b610d6c85bf23ddb1f78ff6404f7c6f170d46441a268e105873cc4" - data_dir: node-v1-data +node_id: peer-node-01 +data_dir: node-v1-data +identity_private_key_hex: "b14d168636008a9c766a6c231c182446e4b636cd2116817a89d068ffb5cc49e456a47a1103b610d6c85bf23ddb1f78ff6404f7c6f170d46441a268e105873cc4" + +coordinator_id: coordinator-01 +coordinator_public_key_hex: "b64ca8ec459081a299aecc2b2b5d555265b15ddfd29e792ddd08bedb418bdd0d" diff --git a/cosigner2.config.yaml b/cosigner2.config.yaml index c4ac3f2b..f8261343 100644 --- a/cosigner2.config.yaml +++ b/cosigner2.config.yaml @@ -1,11 +1,17 @@ -nats: - url: nats://127.0.0.1:4222 +relay: + provider: mqtt + mqtt: + broker: tcp://127.0.0.1:1883 + client_id: peer-node-02 + username: peer-node-02 + password: peer-node-02 -cosigner: - node_id: peer-node-02 - coordinator: - id: coordinator-01 - public_key_hex: "b64ca8ec459081a299aecc2b2b5d555265b15ddfd29e792ddd08bedb418bdd0d" - identity: - private_key_hex: "a96d8c0de1b5682740f6487b13dc7477aaa739b900c6f5c3db737ca019163efad9034dd84e0dd10a57d6a09a8267b217051d5f121ff52fca66c2b485be16ae02" - data_dir: node-v1-data-02 +# nats: +# url: nats://127.0.0.1:4222 + +node_id: peer-node-02 +data_dir: node-v1-data-02 +identity_private_key_hex: "a96d8c0de1b5682740f6487b13dc7477aaa739b900c6f5c3db737ca019163efad9034dd84e0dd10a57d6a09a8267b217051d5f121ff52fca66c2b485be16ae02" + +coordinator_id: coordinator-01 +coordinator_public_key_hex: "b64ca8ec459081a299aecc2b2b5d555265b15ddfd29e792ddd08bedb418bdd0d" diff --git a/internal/cosigner/config.go b/internal/cosigner/config.go index b09bf477..5396ca51 100644 --- a/internal/cosigner/config.go +++ b/internal/cosigner/config.go @@ -4,6 +4,7 @@ import ( "crypto/ed25519" "encoding/hex" "fmt" + "strings" "time" "github.com/mitchellh/mapstructure" @@ -20,7 +21,7 @@ const ( type Config struct { RelayProvider RelayProvider NodeID string - NATSURL string + NATS natsConfig MQTT mqttConfig CoordinatorID string CoordinatorPublicKey []byte @@ -34,7 +35,7 @@ type Config struct { // Flat keys for compact config style. type fileConfig struct { RelayProvider RelayProvider `mapstructure:"relay_provider"` - NATSURL string `mapstructure:"nats_url"` + NATS natsConfig `mapstructure:"nats"` MQTT mqttConfig `mapstructure:"mqtt"` NodeID string `mapstructure:"node_id"` DataDir string `mapstructure:"data_dir"` @@ -43,6 +44,19 @@ type fileConfig struct { IdentityPrivateKeyHex string `mapstructure:"identity_private_key_hex"` } +type natsConfig struct { + URL string `mapstructure:"url"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + TLS *tlsConfig `mapstructure:"tls"` +} + +type tlsConfig struct { + ClientCert string `mapstructure:"client_cert"` + ClientKey string `mapstructure:"client_key"` + CACert string `mapstructure:"ca_cert"` +} + type mqttConfig struct { Broker string `mapstructure:"broker"` ClientID string `mapstructure:"client_id"` @@ -68,7 +82,7 @@ func LoadConfig() (Config, error) { runtimeCfg := Config{ RelayProvider: cfg.RelayProvider, NodeID: cfg.NodeID, - NATSURL: cfg.NATSURL, + NATS: cfg.NATS, MQTT: cfg.MQTT, CoordinatorID: cfg.CoordinatorID, CoordinatorPublicKey: coordinatorKey, @@ -103,6 +117,15 @@ func (cfg *Config) applyDefaults() { if cfg.TickInterval <= 0 { cfg.TickInterval = 100 * time.Millisecond } + + cfg.NATS.URL = strings.TrimSpace(cfg.NATS.URL) + cfg.NATS.Username = strings.TrimSpace(cfg.NATS.Username) + cfg.NATS.Password = strings.TrimSpace(cfg.NATS.Password) + if cfg.NATS.TLS != nil { + cfg.NATS.TLS.ClientCert = strings.TrimSpace(cfg.NATS.TLS.ClientCert) + cfg.NATS.TLS.ClientKey = strings.TrimSpace(cfg.NATS.TLS.ClientKey) + cfg.NATS.TLS.CACert = strings.TrimSpace(cfg.NATS.TLS.CACert) + } } func (cfg Config) Validate() error { @@ -111,8 +134,16 @@ func (cfg Config) Validate() error { } switch cfg.RelayProvider { case RelayProviderNATS: - if cfg.NATSURL == "" { - return fmt.Errorf("nats_url is required for relay provider nats") + if cfg.NATS.URL == "" { + return fmt.Errorf("nats.url is required for relay provider nats") + } + if cfg.NATS.TLS != nil { + if cfg.NATS.TLS.ClientCert == "" { + return fmt.Errorf("nats.tls.client_cert is required when nats.tls is set") + } + if cfg.NATS.TLS.ClientKey == "" { + return fmt.Errorf("nats.tls.client_key is required when nats.tls is set") + } } case RelayProviderMQTT: if cfg.MQTT.Broker == "" { diff --git a/internal/cosigner/relay.go b/internal/cosigner/relay.go index ee969445..99f3c36c 100644 --- a/internal/cosigner/relay.go +++ b/internal/cosigner/relay.go @@ -21,7 +21,7 @@ type Relay interface { func NewRelayFromConfig(cfg Config) (Relay, error) { switch cfg.RelayProvider { case RelayProviderNATS: - return NewNATSRelay(cfg.NATSURL) + return NewNATSRelay(cfg.NATS) case RelayProviderMQTT: return NewMQTTRelay(cfg.MQTT) default: diff --git a/internal/cosigner/relay_nats.go b/internal/cosigner/relay_nats.go index 6d9706f4..9c957e41 100644 --- a/internal/cosigner/relay_nats.go +++ b/internal/cosigner/relay_nats.go @@ -1,7 +1,11 @@ package cosigner import ( + "crypto/tls" + "crypto/x509" "fmt" + "os" + "time" "github.com/fystack/mpcium/pkg/logger" "github.com/nats-io/nats.go" @@ -12,14 +16,54 @@ type NATSRelay struct { nc *nats.Conn } -func NewNATSRelay(url string) (Relay, error) { - nc, err := nats.Connect(url) +func NewNATSRelay(cfg natsConfig) (Relay, error) { + opts := []nats.Option{ + nats.MaxReconnects(-1), + nats.ReconnectWait(2 * time.Second), + } + if cfg.Username != "" { + opts = append(opts, nats.UserInfo(cfg.Username, cfg.Password)) + } + if cfg.TLS != nil { + tlsCfg, err := buildNATSTLSConfig(cfg.TLS) + if err != nil { + return nil, err + } + opts = append(opts, nats.Secure(tlsCfg)) + } + nc, err := nats.Connect(cfg.URL, opts...) if err != nil { return nil, fmt.Errorf("connect nats: %w", err) } return &NATSRelay{nc: nc}, nil } +func buildNATSTLSConfig(cfg *tlsConfig) (*tls.Config, error) { + tlsCfg := &tls.Config{MinVersion: tls.VersionTLS12} + if cfg.CACert != "" { + caPEM, err := os.ReadFile(cfg.CACert) + if err != nil { + return nil, fmt.Errorf("read nats ca cert: %w", err) + } + pool := x509.NewCertPool() + if ok := pool.AppendCertsFromPEM(caPEM); !ok { + return nil, fmt.Errorf("parse nats ca cert") + } + tlsCfg.RootCAs = pool + } + if cfg.ClientCert != "" || cfg.ClientKey != "" { + if cfg.ClientCert == "" || cfg.ClientKey == "" { + return nil, fmt.Errorf("both nats tls client_cert and client_key are required") + } + cert, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey) + if err != nil { + return nil, fmt.Errorf("load nats client cert: %w", err) + } + tlsCfg.Certificates = []tls.Certificate{cert} + } + return tlsCfg, nil +} + func (t *NATSRelay) Subscribe(subject string, handler func([]byte)) (Subscription, error) { logger.Info("relay nats subscribe", "subject", subject) return t.nc.Subscribe(subject, func(msg *nats.Msg) { diff --git a/relay.config.yaml b/relay.config.yaml new file mode 100644 index 00000000..5b77339b --- /dev/null +++ b/relay.config.yaml @@ -0,0 +1,20 @@ +nats: + url: nats://127.0.0.1:4222 + # username: "" + # password: "" + # tls: + # client_cert: "/path/to/client.crt" + # client_key: "/path/to/client.key" + # ca_cert: "/path/to/ca.crt" + +relay: + mqtt: + listen_address: ":1883" + username_password_file: ./relay.credentials + bridge: + nats_prefix: mpc.v1 + mqtt_prefix: mpc/v1 + mqtt_qos: 1 + origin_header: X-MPCIUM-Relay-Origin + presence: + emit_connect_disconnect: true From 033c2ee0be8435a9420e23f709d94a28495c227a Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 17 Apr 2026 16:17:14 +0700 Subject: [PATCH 09/23] Refactor cosigner configuration to enhance MQTT relay management and introduce ECDSA preparams handling. Update .gitignore to include new configuration files and dependencies. Improve session management with default values for configuration parameters. --- .gitignore | 4 ++- cosigner2.config.yaml | 13 ++++----- go.mod | 8 ++++++ go.sum | 7 +++++ internal/cosigner/config.go | 12 ++++++-- internal/cosigner/runtime.go | 53 ++++++++++++++++++++++++++++++++++++ internal/cosigner/storage.go | 34 +++++++++++++++++------ internal/relay/config.go | 21 ++++++++++---- 8 files changed, 127 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index 3768ac41..0f6aa76b 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,6 @@ config.yaml .vagrant .chain_code .gomodcache -.codex \ No newline at end of file +.codex +coordinator-snapshots/ +relay.credentials \ No newline at end of file diff --git a/cosigner2.config.yaml b/cosigner2.config.yaml index f8261343..eed5293a 100644 --- a/cosigner2.config.yaml +++ b/cosigner2.config.yaml @@ -1,10 +1,9 @@ -relay: - provider: mqtt - mqtt: - broker: tcp://127.0.0.1:1883 - client_id: peer-node-02 - username: peer-node-02 - password: peer-node-02 +relay_provider: mqtt +mqtt: + broker: tcp://127.0.0.1:1883 + client_id: peer-node-02 + username: peer-node-02 + password: peer-node-02 # nats: # url: nats://127.0.0.1:4222 diff --git a/go.mod b/go.mod index 38d1262d..b1341112 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,14 @@ require ( golang.org/x/term v0.40.0 ) +require ( + github.com/eclipse/paho.mqtt.golang v1.5.1 // indirect + github.com/gorilla/websocket v1.5.3 // indirect + github.com/mochi-mqtt/server/v2 v2.7.9 // indirect + github.com/rs/xid v1.6.0 // indirect + golang.org/x/sync v0.19.0 // indirect +) + require ( filippo.io/hpke v0.4.0 // indirect github.com/agl/ed25519 v0.0.0-20200225211852-fd4d107ace12 // indirect diff --git a/go.sum b/go.sum index 91c05d3d..b6a1d170 100644 --- a/go.sum +++ b/go.sum @@ -109,6 +109,8 @@ github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa5 github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/eclipse/paho.mqtt.golang v1.5.1 h1:/VSOv3oDLlpqR2Epjn1Q7b2bSTplJIeV2ISgCl2W7nE= +github.com/eclipse/paho.mqtt.golang v1.5.1/go.mod h1:1/yJCneuyOoCOzKSsOTUc0AJfpsItBGWvYpBLimhArU= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= @@ -165,6 +167,8 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/consul/api v1.33.2 h1:Q6mE0WZsUTJerlnl9TuXzqrtZ0cKdOCsxcZhj5mKbMs= github.com/hashicorp/consul/api v1.33.2/go.mod h1:K3yoL/vnIBcQV/25NeMZVokRvPPERiqp2Udtr4xAfhs= github.com/hashicorp/consul/sdk v0.17.1 h1:LumAh8larSXmXw2wvw/lK5ZALkJ2wK8VRwWMLVV5M5c= @@ -253,6 +257,8 @@ github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mochi-mqtt/server/v2 v2.7.9 h1:y0g4vrSLAag7T07l2oCzOa/+nKVLoazKEWAArwqBNYI= +github.com/mochi-mqtt/server/v2 v2.7.9/go.mod h1:lZD3j35AVNqJL5cezlnSkuG05c0FCHSsfAKSPBOSbqc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -316,6 +322,7 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1 github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= diff --git a/internal/cosigner/config.go b/internal/cosigner/config.go index 5396ca51..379c9730 100644 --- a/internal/cosigner/config.go +++ b/internal/cosigner/config.go @@ -18,6 +18,12 @@ const ( RelayProviderMQTT RelayProvider = "mqtt" ) +const ( + DefaultMaxActiveSessions = 5 + DefaultPresenceInterval = 5 * time.Second + DefaultTickInterval = 100 * time.Millisecond +) + type Config struct { RelayProvider RelayProvider NodeID string @@ -109,13 +115,13 @@ func (cfg *Config) applyDefaults() { cfg.RelayProvider = RelayProviderNATS } if cfg.MaxActiveSessions <= 0 { - cfg.MaxActiveSessions = 10 + cfg.MaxActiveSessions = DefaultMaxActiveSessions } if cfg.PresenceInterval <= 0 { - cfg.PresenceInterval = 5 * time.Second + cfg.PresenceInterval = DefaultPresenceInterval } if cfg.TickInterval <= 0 { - cfg.TickInterval = 100 * time.Millisecond + cfg.TickInterval = DefaultTickInterval } cfg.NATS.URL = strings.TrimSpace(cfg.NATS.URL) diff --git a/internal/cosigner/runtime.go b/internal/cosigner/runtime.go index 2ece0dc7..773b9952 100644 --- a/internal/cosigner/runtime.go +++ b/internal/cosigner/runtime.go @@ -1,8 +1,10 @@ package cosigner import ( + "bytes" "context" "crypto/ed25519" + "encoding/gob" "encoding/json" "errors" "fmt" @@ -10,6 +12,7 @@ import ( "sync" "time" + ecdsaKeygen "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" "github.com/fystack/mpcium/pkg/logger" "github.com/vietddude/mpcium-sdk/participant" sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" @@ -32,6 +35,8 @@ type sessionMeta struct { action string } +const bootstrapPreparamsSlot = "bootstrap" + func NewRuntime(cfg Config) (*Runtime, error) { relay, err := NewRelayFromConfig(cfg) if err != nil { @@ -80,6 +85,9 @@ func (r *Runtime) Close() error { func (r *Runtime) Run(ctx context.Context) error { logger.Info("cosigner runtime started", "node_id", r.cfg.NodeID) + if err := r.ensureECDSAPreparams(); err != nil { + return err + } if err := r.subscribe(); err != nil { return err } @@ -109,6 +117,51 @@ func (r *Runtime) Run(ctx context.Context) error { } } +func (r *Runtime) ensureECDSAPreparams() error { + activeSlot, err := r.stores.LoadActivePreparamsSlot(sdkprotocol.ProtocolTypeECDSA) + if err != nil { + return fmt.Errorf("load active ecdsa preparams slot: %w", err) + } + if activeSlot != "" { + existing, loadErr := r.stores.LoadPreparamsSlot(sdkprotocol.ProtocolTypeECDSA, activeSlot) + if loadErr != nil { + return fmt.Errorf("load ecdsa preparams slot %q: %w", activeSlot, loadErr) + } + if len(existing) > 0 { + logger.Info("cosigner preparams ready", "protocol", "ecdsa", "source", "store", "slot", activeSlot) + return nil + } + logger.Warn("active ecdsa preparams slot is empty; regenerating", "slot", activeSlot) + } + + logger.Info("cosigner preparams missing; generating", "protocol", "ecdsa") + startedAt := time.Now() + preparams, err := ecdsaKeygen.GeneratePreParams(5 * time.Minute) + if err != nil { + return fmt.Errorf("generate ecdsa preparams: %w", err) + } + blob, err := encodeECDSAPreparams(preparams) + if err != nil { + return fmt.Errorf("encode ecdsa preparams: %w", err) + } + if err := r.stores.SavePreparamsSlot(sdkprotocol.ProtocolTypeECDSA, bootstrapPreparamsSlot, blob); err != nil { + return fmt.Errorf("save ecdsa preparams slot %q: %w", bootstrapPreparamsSlot, err) + } + if err := r.stores.SaveActivePreparamsSlot(sdkprotocol.ProtocolTypeECDSA, bootstrapPreparamsSlot); err != nil { + return fmt.Errorf("save active ecdsa preparams slot: %w", err) + } + logger.Info("cosigner preparams generated", "protocol", "ecdsa", "slot", bootstrapPreparamsSlot, "elapsed", time.Since(startedAt).Round(time.Millisecond)) + return nil +} + +func encodeECDSAPreparams(data *ecdsaKeygen.LocalPreParams) ([]byte, error) { + var buffer bytes.Buffer + if err := gob.NewEncoder(&buffer).Encode(data); err != nil { + return nil, err + } + return buffer.Bytes(), nil +} + func (r *Runtime) subscribe() error { controlSub, err := r.relay.Subscribe(controlSubject(r.cfg.NodeID), func(raw []byte) { if err := r.handleControl(raw); err != nil { diff --git a/internal/cosigner/storage.go b/internal/cosigner/storage.go index a801891e..52cc3903 100644 --- a/internal/cosigner/storage.go +++ b/internal/cosigner/storage.go @@ -9,8 +9,10 @@ import ( ) type PreparamsStore interface { - LoadPreparams(protocolType sdkprotocol.ProtocolType, keyID string) ([]byte, error) - SavePreparams(protocolType sdkprotocol.ProtocolType, keyID string, preparams []byte) error + LoadPreparamsSlot(protocolType sdkprotocol.ProtocolType, slot string) ([]byte, error) + SavePreparamsSlot(protocolType sdkprotocol.ProtocolType, slot string, preparams []byte) error + LoadActivePreparamsSlot(protocolType sdkprotocol.ProtocolType) (string, error) + SaveActivePreparamsSlot(protocolType sdkprotocol.ProtocolType, slot string) error } type SharesStore interface { @@ -52,12 +54,24 @@ func (s *badgerStores) Close() error { return s.db.Close() } -func (s *badgerStores) LoadPreparams(protocolType sdkprotocol.ProtocolType, keyID string) ([]byte, error) { - return s.load(keyPreparams(protocolType, keyID)) +func (s *badgerStores) LoadPreparamsSlot(protocolType sdkprotocol.ProtocolType, slot string) ([]byte, error) { + return s.load(keyPreparamsSlot(protocolType, slot)) } -func (s *badgerStores) SavePreparams(protocolType sdkprotocol.ProtocolType, keyID string, preparams []byte) error { - return s.save(keyPreparams(protocolType, keyID), preparams) +func (s *badgerStores) SavePreparamsSlot(protocolType sdkprotocol.ProtocolType, slot string, preparams []byte) error { + return s.save(keyPreparamsSlot(protocolType, slot), preparams) +} + +func (s *badgerStores) LoadActivePreparamsSlot(protocolType sdkprotocol.ProtocolType) (string, error) { + value, err := s.load(keyPreparamsActiveSlot(protocolType)) + if err != nil { + return "", err + } + return string(value), nil +} + +func (s *badgerStores) SaveActivePreparamsSlot(protocolType sdkprotocol.ProtocolType, slot string) error { + return s.save(keyPreparamsActiveSlot(protocolType), []byte(slot)) } func (s *badgerStores) LoadShare(protocolType sdkprotocol.ProtocolType, keyID string) ([]byte, error) { @@ -107,8 +121,12 @@ func (s *badgerStores) save(key string, value []byte) error { }) } -func keyPreparams(protocolType sdkprotocol.ProtocolType, keyID string) string { - return fmt.Sprintf("preparams:%s:%s", protocolType, keyID) +func keyPreparamsSlot(protocolType sdkprotocol.ProtocolType, slot string) string { + return fmt.Sprintf("preparams:%s:%s", protocolType, slot) +} + +func keyPreparamsActiveSlot(protocolType sdkprotocol.ProtocolType) string { + return fmt.Sprintf("preparams:%s:active_slot", protocolType) } func keyShare(protocolType sdkprotocol.ProtocolType, keyID string) string { diff --git a/internal/relay/config.go b/internal/relay/config.go index df6d70e4..03c1697d 100644 --- a/internal/relay/config.go +++ b/internal/relay/config.go @@ -8,10 +8,10 @@ import ( ) type RuntimeConfig struct { - NATS NATSConfig `mapstructure:"nats"` - MQTT MQTTConfig `mapstructure:"relay.mqtt"` - Bridge BridgeConfig `mapstructure:"relay.bridge"` - Presence PresenceConfig `mapstructure:"relay.presence"` + NATS NATSConfig + MQTT MQTTConfig + Bridge BridgeConfig + Presence PresenceConfig } type NATSConfig struct { @@ -47,8 +47,17 @@ func LoadConfig() (RuntimeConfig, error) { setDefaults() var cfg RuntimeConfig - if err := viper.Unmarshal(&cfg); err != nil { - return RuntimeConfig{}, fmt.Errorf("decode relay config: %w", err) + if err := viper.UnmarshalKey("nats", &cfg.NATS); err != nil { + return RuntimeConfig{}, fmt.Errorf("decode relay config nats: %w", err) + } + if err := viper.UnmarshalKey("relay.mqtt", &cfg.MQTT); err != nil { + return RuntimeConfig{}, fmt.Errorf("decode relay config relay.mqtt: %w", err) + } + if err := viper.UnmarshalKey("relay.bridge", &cfg.Bridge); err != nil { + return RuntimeConfig{}, fmt.Errorf("decode relay config relay.bridge: %w", err) + } + if err := viper.UnmarshalKey("relay.presence", &cfg.Presence); err != nil { + return RuntimeConfig{}, fmt.Errorf("decode relay config relay.presence: %w", err) } cfg.normalize() From e6980cfc26aa166dbf510b99cecdb5762ec3eb1e Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 17 Apr 2026 16:21:49 +0700 Subject: [PATCH 10/23] Enhance coordinator functionality by enabling detailed logging during initialization and improving key generation request handling. Introduce UUID-based wallet ID generation and update validation logic to ensure proper session management. Refactor configuration management to apply default values and validate required fields, enhancing overall robustness. --- cmd/mpcium-coordinator/main.go | 2 +- examples/coordinatorclient-keygen/main.go | 14 +- internal/coordinator/config.go | 54 ++++- internal/coordinator/coordinator.go | 137 ++++++++---- internal/coordinator/coordinator_test.go | 259 ++++++++++++++++++++++ internal/coordinator/keyinfo.go | 18 +- internal/coordinator/presence.go | 2 +- internal/coordinator/result_hash_test.go | 60 ----- internal/coordinator/signing_test.go | 29 --- internal/coordinator/store.go | 1 - 10 files changed, 430 insertions(+), 146 deletions(-) delete mode 100644 internal/coordinator/result_hash_test.go delete mode 100644 internal/coordinator/signing_test.go diff --git a/cmd/mpcium-coordinator/main.go b/cmd/mpcium-coordinator/main.go index 5de34807..51515cf4 100644 --- a/cmd/mpcium-coordinator/main.go +++ b/cmd/mpcium-coordinator/main.go @@ -18,7 +18,7 @@ import ( const coordinatorConfigPath = "coordinator.config.yaml" func main() { - logger.Init(os.Getenv("ENVIRONMENT"), false) + logger.Init(os.Getenv("ENVIRONMENT"), true) cmd := &cli.Command{ Name: "mpcium-coordinator", diff --git a/examples/coordinatorclient-keygen/main.go b/examples/coordinatorclient-keygen/main.go index 58a30101..c4ad1287 100644 --- a/examples/coordinatorclient-keygen/main.go +++ b/examples/coordinatorclient-keygen/main.go @@ -9,6 +9,7 @@ import ( "time" "github.com/fystack/mpcium/pkg/coordinatorclient" + "github.com/google/uuid" sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) @@ -33,16 +34,7 @@ func main() { }, } - for _, participant := range participants { - presenceCtx, cancelPresence := context.WithTimeout(context.Background(), 5*time.Second) - if err := client.PublishPresence(presenceCtx, participant.ID); err != nil { - cancelPresence() - log.Fatalf("publish presence for %s: %v", participant.ID, err) - } - cancelPresence() - } - - const walletID = "wallet_demo_001" + walletID := "wallet_" + uuid.New().String() runKeygenForProtocol(client, participants, walletID, sdkprotocol.ProtocolTypeECDSA) runKeygenForProtocol(client, participants, walletID, sdkprotocol.ProtocolTypeEdDSA) } @@ -80,7 +72,7 @@ func runKeygenForProtocol(client *coordinatorclient.Client, participants []coord }) cancelRequest() if err != nil { - log.Fatalf("request keygen (%s): %v", protocol, err) + log.Fatalf("request keygen (%s): %v (verify both cosigners are online and publishing real presence)", protocol, err) } acceptedAt := time.Now() diff --git a/internal/coordinator/config.go b/internal/coordinator/config.go index fce88b3d..3d898405 100644 --- a/internal/coordinator/config.go +++ b/internal/coordinator/config.go @@ -8,6 +8,11 @@ import ( "github.com/spf13/viper" ) +const ( + DefaultSessionTTL = 120 * time.Second + DefaultTickInterval = time.Second +) + type fileConfig struct { NATS natsConfig `mapstructure:"nats"` Coordinator coordinatorConfig `mapstructure:"coordinator"` @@ -62,7 +67,52 @@ func (cfg coordinatorConfig) runtimeConfig(natsURL string) RuntimeConfig { ID: cfg.ID, PrivateKeyHex: cfg.PrivateKeyHex, SnapshotDir: cfg.SnapshotDir, - DefaultSessionTTL: 120 * time.Second, - TickInterval: time.Second, + DefaultSessionTTL: DefaultSessionTTL, + TickInterval: DefaultTickInterval, + } +} + +type CoordinatorConfig struct { + CoordinatorID string + Signer Signer + EventVerifier SessionEventVerifier + Store *MemorySessionStore + KeyInfoStore *MemoryKeyInfoStore + Presence PresenceView + Controls ControlPublisher + Results ResultPublisher + DefaultSessionTTL time.Duration + Now func() time.Time +} + +func applyDefaults(cfg CoordinatorConfig) CoordinatorConfig { + if cfg.Now == nil { + cfg.Now = func() time.Time { return time.Now().UTC() } + } + if cfg.DefaultSessionTTL <= 0 { + cfg.DefaultSessionTTL = 120 * time.Second } + return cfg +} + +func (cfg CoordinatorConfig) Validate() error { + if cfg.CoordinatorID == "" { + return fmt.Errorf("coordinator ID is required") + } + if cfg.Signer == nil { + return fmt.Errorf("signer is required") + } + if cfg.Store == nil { + return fmt.Errorf("session store is required") + } + if cfg.Presence == nil { + return fmt.Errorf("presence view is required") + } + if cfg.Controls == nil { + return fmt.Errorf("control publisher is required") + } + if cfg.Results == nil { + return fmt.Errorf("result publisher is required") + } + return nil } diff --git a/internal/coordinator/coordinator.go b/internal/coordinator/coordinator.go index 6a6d1dcf..f7c8600c 100644 --- a/internal/coordinator/coordinator.go +++ b/internal/coordinator/coordinator.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "sort" "strings" "time" @@ -14,19 +15,6 @@ import ( sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) -type CoordinatorConfig struct { - CoordinatorID string - Signer Signer - EventVerifier SessionEventVerifier - Store *MemorySessionStore - KeyInfoStore *MemoryKeyInfoStore - Presence PresenceView - Controls ControlPublisher - Results ResultPublisher - DefaultSessionTTL time.Duration - Now func() time.Time -} - type Coordinator struct { id string signer Signer @@ -41,30 +29,11 @@ type Coordinator struct { } func NewCoordinator(cfg CoordinatorConfig) (*Coordinator, error) { - if cfg.CoordinatorID == "" { - return nil, fmt.Errorf("coordinator ID is required") - } - if cfg.Signer == nil { - return nil, fmt.Errorf("signer is required") - } - if cfg.Store == nil { - return nil, fmt.Errorf("session store is required") - } - if cfg.Presence == nil { - return nil, fmt.Errorf("presence view is required") - } - if cfg.Controls == nil { - return nil, fmt.Errorf("control publisher is required") - } - if cfg.Results == nil { - return nil, fmt.Errorf("result publisher is required") - } - if cfg.Now == nil { - cfg.Now = func() time.Time { return time.Now().UTC() } - } - if cfg.DefaultSessionTTL <= 0 { - cfg.DefaultSessionTTL = 120 * time.Second + cfg = applyDefaults(cfg) + if err := cfg.Validate(); err != nil { + return nil, err } + return &Coordinator{ id: cfg.CoordinatorID, signer: cfg.Signer, @@ -92,12 +61,21 @@ func (c *Coordinator) HandleRequest(ctx context.Context, op Operation, raw []byt protocols := []sdkprotocol.ProtocolType{sdkprotocol.ProtocolTypeECDSA, sdkprotocol.ProtocolTypeEdDSA} sessionIDs := make([]string, 0, len(protocols)) var firstAccepted *sdkprotocol.RequestAccepted + var firstErr error for _, protocol := range protocols { cloned := cloneSessionStart(req.SessionStart) cloned.Protocol = protocol accepted, err := c.acceptRequest(ctx, op, &sdkprotocol.ControlMessage{SessionStart: cloned}) if err != nil { + var coordErr *CoordinatorError + if AsCoordinatorError(err, &coordErr) && coordErr.Code == ErrorCodeConflict { + // Allow fanout to continue: one protocol might already exist while the other doesn't. + if firstErr == nil { + firstErr = err + } + continue + } return rejectFromError(err), nil } sessionIDs = append(sessionIDs, accepted.SessionID) @@ -105,6 +83,12 @@ func (c *Coordinator) HandleRequest(ctx context.Context, op Operation, raw []byt firstAccepted = accepted } } + if firstAccepted == nil { + if firstErr != nil { + return rejectFromError(firstErr), nil + } + return reject(ErrorCodeConflict, "no keygen sessions created"), nil + } logger.Info("coordinator expanded keygen request without protocol", "operation", string(op), @@ -186,7 +170,7 @@ func (c *Coordinator) HandleSessionEvent(ctx context.Context, raw []byte) error if err := json.Unmarshal(raw, &event); err != nil { return newCoordinatorError(ErrorCodeInvalidJSON, "invalid JSON session event") } - if err := sdkprotocol.ValidateSessionEvent(&event); err != nil { + if err := validateSessionEventCompat(&event); err != nil { return newCoordinatorError(ErrorCodeValidation, err.Error()) } @@ -256,6 +240,41 @@ func (c *Coordinator) HandleSessionEvent(ctx context.Context, raw []byte) error return c.store.Save(ctx, session) } +func validateSessionEventCompat(event *sdkprotocol.SessionEvent) error { + if event == nil { + return sdkprotocol.ValidateSessionEvent(event) + } + if err := sdkprotocol.ValidateSessionEvent(event); err == nil { + return nil + } + // Compatibility: allow keygen completion events without share_blob. + if event.SessionCompleted == nil || event.SessionCompleted.Result == nil || event.SessionCompleted.Result.KeyShare == nil || len(event.SessionCompleted.Result.KeyShare.ShareBlob) > 0 { + return sdkprotocol.ValidateSessionEvent(event) + } + clone := cloneSessionEventForValidation(event) + clone.SessionCompleted.Result.KeyShare.ShareBlob = []byte{0} + return sdkprotocol.ValidateSessionEvent(clone) +} + +func cloneSessionEventForValidation(event *sdkprotocol.SessionEvent) *sdkprotocol.SessionEvent { + clone := *event + if event.SessionCompleted != nil { + completed := *event.SessionCompleted + clone.SessionCompleted = &completed + if event.SessionCompleted.Result != nil { + result := *event.SessionCompleted.Result + clone.SessionCompleted.Result = &result + if event.SessionCompleted.Result.KeyShare != nil { + keyShare := *event.SessionCompleted.Result.KeyShare + keyShare.PublicKey = append([]byte(nil), event.SessionCompleted.Result.KeyShare.PublicKey...) + keyShare.ShareBlob = append([]byte(nil), event.SessionCompleted.Result.KeyShare.ShareBlob...) + clone.SessionCompleted.Result.KeyShare = &keyShare + } + } + } + return &clone +} + func (c *Coordinator) Tick(ctx context.Context) (int, error) { now := c.now() expired := 0 @@ -302,6 +321,16 @@ func (c *Coordinator) validateRequest(ctx context.Context, op Operation, msg *sd return newCoordinatorError(ErrorCodeUnavailable, "participant is offline") } } + if op == OperationKeygen && c.keyInfoStore != nil { + walletID := keygenWalletID(start) + if walletID == "" { + return newCoordinatorError(ErrorCodeValidation, "wallet_id is required") + } + protocol := string(start.Protocol) + if _, exists := c.keyInfoStore.Get(walletID, protocol); exists { + return newCoordinatorError(ErrorCodeConflict, "wallet key already exists") + } + } return nil } @@ -331,6 +360,9 @@ func (c *Coordinator) advance(ctx context.Context, session *Session, event *sdkp if err != nil { return c.failSession(ctx, session, ErrorCodeResultHashMismatch, err.Error()) } + if err := c.persistKeyInfoIfNeeded(session, result); err != nil { + return c.failSession(ctx, session, ErrorCodeInternal, err.Error()) + } now := c.now() session.State = SessionCompleted session.ResultHash = resultHash @@ -346,6 +378,37 @@ func (c *Coordinator) advance(ctx context.Context, session *Session, event *sdkp return nil } +func (c *Coordinator) persistKeyInfoIfNeeded(session *Session, result *sdkprotocol.Result) error { + if c.keyInfoStore == nil || session == nil || result == nil || session.Op != OperationKeygen || result.KeyShare == nil { + return nil + } + walletID := result.KeyShare.KeyID + if walletID == "" { + walletID = keygenWalletID(session.Start) + } + if walletID == "" { + return fmt.Errorf("missing wallet id in keygen result") + } + participantIDs := make([]string, 0, len(session.Participants)) + for _, participant := range session.Participants { + if participant == nil || participant.ParticipantID == "" { + continue + } + participantIDs = append(participantIDs, participant.ParticipantID) + } + sort.Strings(participantIDs) + info := KeyInfo{ + WalletID: walletID, + KeyType: string(session.Start.Protocol), + Threshold: int(session.Start.Threshold), + Participants: participantIDs, + PublicKey: append([]byte(nil), result.KeyShare.PublicKey...), + CreatedAt: c.now().UTC().Format(time.RFC3339Nano), + } + c.keyInfoStore.Save(info) + return nil +} + func (c *Coordinator) fanOutSessionStart(ctx context.Context, session *Session) error { msg := &sdkprotocol.ControlMessage{ SessionID: session.ID, diff --git a/internal/coordinator/coordinator_test.go b/internal/coordinator/coordinator_test.go index 92c13c09..5d13fdba 100644 --- a/internal/coordinator/coordinator_test.go +++ b/internal/coordinator/coordinator_test.go @@ -133,6 +133,264 @@ func TestLifecycleCompletesSignAndPublishesResult(t *testing.T) { } } +func TestLifecycleCompletesKeygenWithoutShareBlob(t *testing.T) { + ctx := context.Background() + coord, _, results, fixtures := newTestCoordinator(t) + markOnline(t, coord.presence, fixtures["p1"].pub, "p1") + markOnline(t, coord.presence, fixtures["p2"].pub, "p2") + + keygenReq := &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "client-supplied", + Protocol: sdkprotocol.ProtocolTypeEdDSA, + Operation: sdkprotocol.OperationTypeKeygen, + Threshold: 1, + Participants: []*sdkprotocol.SessionParticipant{ + {ParticipantID: "p1", PartyKey: []byte("p1"), IdentityPublicKey: fixtures["p1"].pub}, + {ParticipantID: "p2", PartyKey: []byte("p2"), IdentityPublicKey: fixtures["p2"].pub}, + }, + Keygen: &sdkprotocol.KeygenPayload{KeyID: "wallet_demo_001"}, + }, + } + rawReply, err := coord.HandleRequest(ctx, OperationKeygen, mustJSON(t, keygenReq)) + if err != nil { + t.Fatal(err) + } + var reply sdkprotocol.RequestAccepted + if err := json.Unmarshal(rawReply, &reply); err != nil { + t.Fatal(err) + } + + result := &sdkprotocol.Result{ + KeyShare: &sdkprotocol.KeyShareResult{ + KeyID: "wallet_demo_001", + PublicKey: []byte("pub"), + }, + } + for _, participant := range []string{"p1", "p2"} { + emitSignedEvent(t, coord, reply.SessionID, fixtures, participant, &sdkprotocol.SessionEvent{PeerJoined: &sdkprotocol.PeerJoined{ParticipantID: participant}}) + emitSignedEvent(t, coord, reply.SessionID, fixtures, participant, &sdkprotocol.SessionEvent{PeerReady: &sdkprotocol.PeerReady{ParticipantID: participant}}) + } + for _, participant := range []string{"p1", "p2"} { + emitSignedEvent(t, coord, reply.SessionID, fixtures, participant, &sdkprotocol.SessionEvent{PeerKeyExchangeDone: &sdkprotocol.PeerKeyExchangeDone{ParticipantID: participant}}) + } + for _, participant := range []string{"p1", "p2"} { + emitSignedEvent(t, coord, reply.SessionID, fixtures, participant, &sdkprotocol.SessionEvent{SessionCompleted: &sdkprotocol.SessionCompleted{Result: result}}) + } + + published := results.results[reply.SessionID] + if published == nil || published.KeyShare == nil { + t.Fatalf("missing published keygen result") + } + if len(published.KeyShare.ShareBlob) != 0 { + t.Fatalf("share blob should not be required/published") + } +} + +func TestHandleRequestRejectsDuplicateWalletIDAfterCompletedKeygen(t *testing.T) { + ctx := context.Background() + coord, _, _, fixtures := newTestCoordinator(t) + markOnline(t, coord.presence, fixtures["p1"].pub, "p1") + markOnline(t, coord.presence, fixtures["p2"].pub, "p2") + + requestForWallet := func(walletID string, protocol sdkprotocol.ProtocolType) *sdkprotocol.ControlMessage { + return &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "client-supplied", + Protocol: protocol, + Operation: sdkprotocol.OperationTypeKeygen, + Threshold: 1, + Participants: []*sdkprotocol.SessionParticipant{ + {ParticipantID: "p1", PartyKey: []byte("p1"), IdentityPublicKey: fixtures["p1"].pub}, + {ParticipantID: "p2", PartyKey: []byte("p2"), IdentityPublicKey: fixtures["p2"].pub}, + }, + Keygen: &sdkprotocol.KeygenPayload{KeyID: walletID}, + }, + } + } + + rawReply, err := coord.HandleRequest(ctx, OperationKeygen, mustJSON(t, requestForWallet("wallet_demo_001", sdkprotocol.ProtocolTypeEdDSA))) + if err != nil { + t.Fatal(err) + } + var accepted sdkprotocol.RequestAccepted + if err := json.Unmarshal(rawReply, &accepted); err != nil { + t.Fatal(err) + } + if !accepted.Accepted || accepted.SessionID == "" { + t.Fatalf("unexpected keygen accepted reply: %+v", accepted) + } + + keygenResult := &sdkprotocol.Result{ + KeyShare: &sdkprotocol.KeyShareResult{ + KeyID: "wallet_demo_001", + PublicKey: []byte("pub"), + }, + } + for _, participant := range []string{"p1", "p2"} { + emitSignedEvent(t, coord, accepted.SessionID, fixtures, participant, &sdkprotocol.SessionEvent{PeerJoined: &sdkprotocol.PeerJoined{ParticipantID: participant}}) + emitSignedEvent(t, coord, accepted.SessionID, fixtures, participant, &sdkprotocol.SessionEvent{PeerReady: &sdkprotocol.PeerReady{ParticipantID: participant}}) + } + for _, participant := range []string{"p1", "p2"} { + emitSignedEvent(t, coord, accepted.SessionID, fixtures, participant, &sdkprotocol.SessionEvent{PeerKeyExchangeDone: &sdkprotocol.PeerKeyExchangeDone{ParticipantID: participant}}) + } + for _, participant := range []string{"p1", "p2"} { + emitSignedEvent(t, coord, accepted.SessionID, fixtures, participant, &sdkprotocol.SessionEvent{SessionCompleted: &sdkprotocol.SessionCompleted{Result: keygenResult}}) + } + + rawDupReply, err := coord.HandleRequest(ctx, OperationKeygen, mustJSON(t, requestForWallet("wallet_demo_001", sdkprotocol.ProtocolTypeEdDSA))) + if err != nil { + t.Fatal(err) + } + var rejected sdkprotocol.RequestRejected + if err := json.Unmarshal(rawDupReply, &rejected); err != nil { + t.Fatal(err) + } + if rejected.Accepted { + t.Fatalf("expected duplicate keygen request to be rejected") + } + if rejected.ErrorCode != ErrorCodeConflict { + t.Fatalf("error code = %s, want %s", rejected.ErrorCode, ErrorCodeConflict) + } + + rawOtherProtocolReply, err := coord.HandleRequest(ctx, OperationKeygen, mustJSON(t, requestForWallet("wallet_demo_001", sdkprotocol.ProtocolTypeECDSA))) + if err != nil { + t.Fatal(err) + } + var acceptedOtherProtocol sdkprotocol.RequestAccepted + if err := json.Unmarshal(rawOtherProtocolReply, &acceptedOtherProtocol); err != nil { + t.Fatal(err) + } + if !acceptedOtherProtocol.Accepted { + t.Fatalf("expected same wallet id with different protocol to be accepted") + } +} + +func TestHandleRequestKeygenWithoutProtocolCreatesBothSessions(t *testing.T) { + ctx := context.Background() + coord, _, _, fixtures := newTestCoordinator(t) + markOnline(t, coord.presence, fixtures["p1"].pub, "p1") + markOnline(t, coord.presence, fixtures["p2"].pub, "p2") + + req := &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "client-supplied", + Protocol: sdkprotocol.ProtocolTypeUnspecified, + Operation: sdkprotocol.OperationTypeKeygen, + Threshold: 1, + Participants: []*sdkprotocol.SessionParticipant{ + {ParticipantID: "p1", PartyKey: []byte("p1"), IdentityPublicKey: fixtures["p1"].pub}, + {ParticipantID: "p2", PartyKey: []byte("p2"), IdentityPublicKey: fixtures["p2"].pub}, + }, + Keygen: &sdkprotocol.KeygenPayload{KeyID: "wallet_dual_protocol"}, + }, + } + + rawReply, err := coord.HandleRequest(ctx, OperationKeygen, mustJSON(t, req)) + if err != nil { + t.Fatal(err) + } + var accepted sdkprotocol.RequestAccepted + if err := json.Unmarshal(rawReply, &accepted); err != nil { + t.Fatal(err) + } + if !accepted.Accepted { + t.Fatalf("expected request accepted") + } + active := coord.store.ListActive(ctx) + if len(active) != 2 { + t.Fatalf("expected 2 active sessions, got %d", len(active)) + } + seenProtocols := map[sdkprotocol.ProtocolType]bool{} + for _, session := range active { + seenProtocols[session.Start.Protocol] = true + } + if !seenProtocols[sdkprotocol.ProtocolTypeECDSA] || !seenProtocols[sdkprotocol.ProtocolTypeEdDSA] { + t.Fatalf("expected both ECDSA and EdDSA sessions, got %+v", seenProtocols) + } +} + +func TestHandleRequestSignWithoutProtocolRejected(t *testing.T) { + ctx := context.Background() + coord, _, _, fixtures := newTestCoordinator(t) + markOnline(t, coord.presence, fixtures["p1"].pub, "p1") + markOnline(t, coord.presence, fixtures["p2"].pub, "p2") + + req := &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "client-supplied", + Protocol: sdkprotocol.ProtocolTypeUnspecified, + Operation: sdkprotocol.OperationTypeSign, + Threshold: 1, + Participants: []*sdkprotocol.SessionParticipant{ + {ParticipantID: "p1", PartyKey: []byte("p1"), IdentityPublicKey: fixtures["p1"].pub}, + {ParticipantID: "p2", PartyKey: []byte("p2"), IdentityPublicKey: fixtures["p2"].pub}, + }, + Sign: &sdkprotocol.SignPayload{ + KeyID: "wallet-1", + SigningInput: []byte("message"), + }, + }, + } + + rawReply, err := coord.HandleRequest(ctx, OperationSign, mustJSON(t, req)) + if err != nil { + t.Fatal(err) + } + var rejected sdkprotocol.RequestRejected + if err := json.Unmarshal(rawReply, &rejected); err != nil { + t.Fatal(err) + } + if rejected.Accepted { + t.Fatalf("expected sign request without protocol to be rejected") + } + if rejected.ErrorCode != ErrorCodeValidation { + t.Fatalf("error code = %s, want %s", rejected.ErrorCode, ErrorCodeValidation) + } +} + +func TestNewCoordinator_AppliesDefaultNowAndTickDoesNotPanic(t *testing.T) { + store, err := NewMemorySessionStore(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + coord, err := NewCoordinator(CoordinatorConfig{ + CoordinatorID: "coordinator-1", + Signer: fakeSigner{}, + Store: store, + Presence: NewInMemoryPresenceView(), + Controls: &fakeControlPublisher{}, + Results: &fakeResultPublisher{}, + DefaultSessionTTL: 120 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + if coord.now == nil { + t.Fatalf("expected default now function") + } + if _, err := coord.Tick(context.Background()); err != nil { + t.Fatalf("tick returned error: %v", err) + } +} + +func TestNewCoordinator_RejectsInvalidConfig(t *testing.T) { + store, err := NewMemorySessionStore(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + _, err = NewCoordinator(CoordinatorConfig{ + CoordinatorID: "coordinator-1", + Store: store, + Presence: NewInMemoryPresenceView(), + Controls: &fakeControlPublisher{}, + Results: &fakeResultPublisher{}, + }) + if err == nil { + t.Fatalf("expected validation error") + } +} + type participantKey struct { pub ed25519.PublicKey priv ed25519.PrivateKey @@ -160,6 +418,7 @@ func newTestCoordinator(t *testing.T) (*Coordinator, *fakeControlPublisher, *fak Signer: fakeSigner{}, EventVerifier: Ed25519SessionEventVerifier{}, Store: store, + KeyInfoStore: NewMemoryKeyInfoStore(), Presence: NewInMemoryPresenceView(), Controls: controls, Results: results, diff --git a/internal/coordinator/keyinfo.go b/internal/coordinator/keyinfo.go index 40efc24b..ce68be41 100644 --- a/internal/coordinator/keyinfo.go +++ b/internal/coordinator/keyinfo.go @@ -31,14 +31,20 @@ func (s *MemoryKeyInfoStore) Save(info KeyInfo) { if info.CreatedAt == "" { info.CreatedAt = time.Now().UTC().Format(time.RFC3339Nano) } - s.infos[info.WalletID] = info + s.infos[keyInfoStoreKey(info.WalletID, info.KeyType)] = info } -func (s *MemoryKeyInfoStore) Get(walletID string) (KeyInfo, bool) { +func (s *MemoryKeyInfoStore) Get(walletID, keyType string) (KeyInfo, bool) { s.mu.RLock() defer s.mu.RUnlock() - info, ok := s.infos[walletID] - return info, ok + key := keyInfoStoreKey(walletID, keyType) + info, ok := s.infos[key] + if ok { + return info, true + } + // Backward compatibility for legacy snapshots without key type. + legacy, ok := s.infos[keyInfoStoreKey(walletID, "")] + return legacy, ok } func RestoreKeyInfoFromSnapshotStore(ctx context.Context, snapshots SnapshotStore, store *MemoryKeyInfoStore) error { @@ -54,3 +60,7 @@ func RestoreKeyInfoFromSnapshotStore(ctx context.Context, snapshots SnapshotStor } return nil } + +func keyInfoStoreKey(walletID, keyType string) string { + return walletID + "|" + keyType +} diff --git a/internal/coordinator/presence.go b/internal/coordinator/presence.go index d452a255..0ef95c31 100644 --- a/internal/coordinator/presence.go +++ b/internal/coordinator/presence.go @@ -31,7 +31,7 @@ func (p *InMemoryPresenceView) IsOnline(_ context.Context, peerID string) bool { if !ok { return false } - return event.Status == sdkprotocol.PresenceStatusOnline && event.Transport == sdkprotocol.TransportTypeNATS + return event.Status == sdkprotocol.PresenceStatusOnline } func (p *InMemoryPresenceView) ApplyPresence(event sdkprotocol.PresenceEvent) error { diff --git a/internal/coordinator/result_hash_test.go b/internal/coordinator/result_hash_test.go deleted file mode 100644 index 25b64f27..00000000 --- a/internal/coordinator/result_hash_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package coordinator - -import ( - "bytes" - "testing" - - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" -) - -func TestCanonicalOperationResultHashIgnoresKeygenShareBlob(t *testing.T) { - resultA := &sdkprotocol.Result{ - KeyShare: &sdkprotocol.KeyShareResult{ - KeyID: "wallet-1", - PublicKey: []byte{1, 2, 3}, - ShareBlob: []byte{9, 9, 9}, - }, - } - resultB := &sdkprotocol.Result{ - KeyShare: &sdkprotocol.KeyShareResult{ - KeyID: "wallet-1", - PublicKey: []byte{1, 2, 3}, - ShareBlob: []byte{8, 8, 8}, - }, - } - - hashA := canonicalOperationResultHash(OperationKeygen, resultA) - hashB := canonicalOperationResultHash(OperationKeygen, resultB) - if hashA == "" || hashB == "" { - t.Fatal("expected non-empty hashes") - } - if hashA != hashB { - t.Fatalf("expected equal hashes for keygen results with different share blobs, got %q != %q", hashA, hashB) - } -} - -func TestCanonicalOperationResultHashUsesFullSignaturePayload(t *testing.T) { - resultA := &sdkprotocol.Result{ - Signature: &sdkprotocol.SignatureResult{ - KeyID: "wallet-1", - Signature: []byte{1, 2, 3}, - }, - } - resultB := &sdkprotocol.Result{ - Signature: &sdkprotocol.SignatureResult{ - KeyID: "wallet-1", - Signature: []byte{1, 2, 4}, - }, - } - - hashA := canonicalOperationResultHash(OperationSign, resultA) - hashB := canonicalOperationResultHash(OperationSign, resultB) - if hashA == hashB { - t.Fatalf("expected different hashes for different signature payloads") - } - - // Guard against accidental normalization that removes signature bytes. - if bytes.Equal(resultA.Signature.Signature, resultB.Signature.Signature) { - t.Fatal("invalid test setup") - } -} diff --git a/internal/coordinator/signing_test.go b/internal/coordinator/signing_test.go deleted file mode 100644 index be004c00..00000000 --- a/internal/coordinator/signing_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package coordinator - -import ( - "context" - "strings" - "testing" - - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" -) - -func TestEd25519SessionEventVerifierRejectsInvalidPublicKeyLength(t *testing.T) { - verifier := Ed25519SessionEventVerifier{} - session := &Session{ - ParticipantKeys: map[string][]byte{ - "peer-1": make([]byte, 64), - }, - } - event := &sdkprotocol.SessionEvent{ - ParticipantID: "peer-1", - } - - err := verifier.VerifySessionEvent(context.Background(), session, event) - if err == nil { - t.Fatal("expected error for invalid participant public key length") - } - if !strings.Contains(err.Error(), "invalid participant public key length") { - t.Fatalf("unexpected error: %v", err) - } -} diff --git a/internal/coordinator/store.go b/internal/coordinator/store.go index 9d7a0cd2..d298ed1a 100644 --- a/internal/coordinator/store.go +++ b/internal/coordinator/store.go @@ -299,7 +299,6 @@ func cloneResult(result *sdkprotocol.Result) *sdkprotocol.Result { cloned := *result if result.KeyShare != nil { keyShare := *result.KeyShare - keyShare.ShareBlob = append([]byte(nil), result.KeyShare.ShareBlob...) keyShare.PublicKey = append([]byte(nil), result.KeyShare.PublicKey...) cloned.KeyShare = &keyShare } From 666c8a953af4cc0590cded05574f4257f15236f4 Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 17 Apr 2026 17:26:14 +0700 Subject: [PATCH 11/23] Add coordinator client for signing requests with EdDSA protocol. Implement request and validation logic for signing operations, including participant management and session handling. Enhance client functionality with detailed error logging and timeout management for improved reliability. --- examples/coordinatorclient-sign/main.go | 95 +++++++++++++++++++++++++ pkg/coordinatorclient/client.go | 75 +++++++++++++++++-- 2 files changed, 163 insertions(+), 7 deletions(-) create mode 100644 examples/coordinatorclient-sign/main.go diff --git a/examples/coordinatorclient-sign/main.go b/examples/coordinatorclient-sign/main.go new file mode 100644 index 00000000..d65b49ad --- /dev/null +++ b/examples/coordinatorclient-sign/main.go @@ -0,0 +1,95 @@ +package main + +import ( + "context" + "crypto/ed25519" + "encoding/hex" + "fmt" + "log" + "time" + + "github.com/fystack/mpcium/pkg/coordinatorclient" + sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" +) + +func main() { + client, err := coordinatorclient.New(coordinatorclient.Config{ + NATSURL: "nats://127.0.0.1:4222", + Timeout: 5 * time.Second, + }) + if err != nil { + log.Fatalf("create coordinator client: %v", err) + } + defer client.Close() + + participants := []coordinatorclient.SignParticipant{ + { + ID: "peer-node-01", + IdentityPublicKey: mustPublicKeyFromPrivateHex("b14d168636008a9c766a6c231c182446e4b636cd2116817a89d068ffb5cc49e456a47a1103b610d6c85bf23ddb1f78ff6404f7c6f170d46441a268e105873cc4"), + }, + { + ID: "peer-node-02", + IdentityPublicKey: mustPublicKeyFromPrivateHex("a96d8c0de1b5682740f6487b13dc7477aaa739b900c6f5c3db737ca019163efad9034dd84e0dd10a57d6a09a8267b217051d5f121ff52fca66c2b485be16ae02"), + }, + } + + walletID := "wallet_2834c034-489c-40f2-a237-2afd4a73bfaa" + message := []byte("deadbeef") + protocol := sdkprotocol.ProtocolTypeEdDSA + + requestCtx, cancelRequest := context.WithTimeout(context.Background(), 10*time.Second) + resp, err := client.RequestSign(requestCtx, coordinatorclient.SignRequest{ + Protocol: protocol, + Threshold: 1, + WalletID: walletID, + SigningInput: message, + Participants: participants, + }) + cancelRequest() + if err != nil { + log.Fatalf("request sign: %v (verify both cosigners are online and wallet ID exists for this protocol)", err) + } + acceptedAt := time.Now() + + resultCtx, cancelResult := context.WithTimeout(context.Background(), 2*time.Minute) + result, err := client.WaitSessionResult(resultCtx, resp.SessionID) + cancelResult() + if err != nil { + log.Fatalf("wait session result: %v (check both cosigners are running and session events are flowing)", err) + } + if result == nil || result.Signature == nil { + fmt.Printf("session_id=%s result=empty wait_seconds=%.3f\n", resp.SessionID, time.Since(acceptedAt).Seconds()) + return + } + + sig := result.Signature + fmt.Printf("session_id=%s key_id=%s wait_seconds=%.3f\n", resp.SessionID, sig.KeyID, time.Since(acceptedAt).Seconds()) + fmt.Printf("signature_hex=%s\n", hex.EncodeToString(sig.Signature)) + if len(sig.R) > 0 || len(sig.S) > 0 { + fmt.Printf("r_hex=%s\n", hex.EncodeToString(sig.R)) + fmt.Printf("s_hex=%s\n", hex.EncodeToString(sig.S)) + } +} + +func mustDecodeHex(value string) []byte { + decoded, err := hex.DecodeString(value) + if err != nil { + panic(err) + } + return decoded +} + +func mustPublicKeyFromPrivateHex(privateKeyHex string) []byte { + privateRaw := mustDecodeHex(privateKeyHex) + var private ed25519.PrivateKey + switch len(privateRaw) { + case ed25519.PrivateKeySize: + private = ed25519.PrivateKey(privateRaw) + case ed25519.SeedSize: + private = ed25519.NewKeyFromSeed(privateRaw) + default: + panic(fmt.Sprintf("invalid ed25519 private key length: %d", len(privateRaw))) + } + public := private.Public().(ed25519.PublicKey) + return append([]byte(nil), public...) +} diff --git a/pkg/coordinatorclient/client.go b/pkg/coordinatorclient/client.go index f58b8872..19cdc817 100644 --- a/pkg/coordinatorclient/client.go +++ b/pkg/coordinatorclient/client.go @@ -13,6 +13,7 @@ import ( const ( topicPrefix = "mpc.v1" requestKeygenSubject = topicPrefix + ".request.keygen" + requestSignSubject = topicPrefix + ".request.sign" ) type Client struct { @@ -30,6 +31,8 @@ type KeygenParticipant struct { IdentityPublicKey []byte } +type SignParticipant = KeygenParticipant + type KeygenRequest struct { Protocol sdkprotocol.ProtocolType Threshold uint32 @@ -37,6 +40,15 @@ type KeygenRequest struct { Participants []KeygenParticipant } +type SignRequest struct { + Protocol sdkprotocol.ProtocolType + Threshold uint32 + WalletID string + SigningInput []byte + Derivation *sdkprotocol.NonHardenedDerivation + Participants []SignParticipant +} + func New(cfg Config) (*Client, error) { if cfg.NATSURL == "" { cfg.NATSURL = nats.DefaultURL @@ -100,10 +112,10 @@ func (c *Client) RequestKeygen(ctx context.Context, req KeygenRequest) (*sdkprot msg := &sdkprotocol.ControlMessage{ SessionStart: &sdkprotocol.SessionStart{ - SessionID: "tmp", // coordinator replaces this value when accepting request - Protocol: normalizeProtocol(req.Protocol), - Operation: sdkprotocol.OperationTypeKeygen, - Threshold: req.Threshold, + SessionID: "tmp", // coordinator replaces this value when accepting request + Protocol: normalizeProtocol(req.Protocol), + Operation: sdkprotocol.OperationTypeKeygen, + Threshold: req.Threshold, Participants: mapParticipants(req.Participants), Keygen: &sdkprotocol.KeygenPayload{ KeyID: req.WalletID, @@ -111,14 +123,46 @@ func (c *Client) RequestKeygen(ctx context.Context, req KeygenRequest) (*sdkprot }, } + return c.requestSession(ctx, requestKeygenSubject, msg, "keygen") +} + +func (c *Client) RequestSign(ctx context.Context, req SignRequest) (*sdkprotocol.RequestAccepted, error) { + if err := validateSignRequest(req); err != nil { + return nil, err + } + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, c.timeout) + defer cancel() + } + + msg := &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "tmp", // coordinator replaces this value when accepting request + Protocol: normalizeProtocol(req.Protocol), + Operation: sdkprotocol.OperationTypeSign, + Threshold: req.Threshold, + Participants: mapParticipants(req.Participants), + Sign: &sdkprotocol.SignPayload{ + KeyID: req.WalletID, + SigningInput: append([]byte(nil), req.SigningInput...), + Derivation: req.Derivation, + }, + }, + } + + return c.requestSession(ctx, requestSignSubject, msg, "sign") +} + +func (c *Client) requestSession(ctx context.Context, subject string, msg *sdkprotocol.ControlMessage, action string) (*sdkprotocol.RequestAccepted, error) { payload, err := json.Marshal(msg) if err != nil { - return nil, fmt.Errorf("marshal keygen request: %w", err) + return nil, fmt.Errorf("marshal %s request: %w", action, err) } - respMsg, err := c.nc.RequestWithContext(ctx, requestKeygenSubject, payload) + respMsg, err := c.nc.RequestWithContext(ctx, subject, payload) if err != nil { - return nil, fmt.Errorf("request keygen: %w", err) + return nil, fmt.Errorf("request %s: %w", action, err) } var accepted sdkprotocol.RequestAccepted @@ -195,6 +239,23 @@ func validateKeygenRequest(req KeygenRequest) error { return nil } +func validateSignRequest(req SignRequest) error { + if req.Protocol == sdkprotocol.ProtocolTypeUnspecified || string(req.Protocol) == "" { + return fmt.Errorf("protocol is required") + } + if len(req.SigningInput) == 0 { + return fmt.Errorf("signingInput is required") + } + if err := validateKeygenRequest(KeygenRequest{ + Threshold: req.Threshold, + WalletID: req.WalletID, + Participants: req.Participants, + }); err != nil { + return err + } + return nil +} + func mapParticipants(participants []KeygenParticipant) []*sdkprotocol.SessionParticipant { mapped := make([]*sdkprotocol.SessionParticipant, 0, len(participants)) for _, participant := range participants { From 97c8d32ea898675ad7ef3ce1c6a055c121b0abc0 Mon Sep 17 00:00:00 2001 From: vietddude Date: Sat, 18 Apr 2026 11:39:28 +0700 Subject: [PATCH 12/23] Update dependencies in go.mod and go.sum; refactor key generation and signing examples to use hex decoding for public keys. Enhance error messages in coordinator validation logic for better clarity. Improve logging in cosigner runtime to include identity public key in startup messages. --- examples/coordinatorclient-keygen/main.go | 24 ++++++----------------- examples/coordinatorclient-sign/main.go | 8 ++++---- go.mod | 8 ++++---- go.sum | 14 +++++++------ internal/coordinator/coordinator.go | 2 +- internal/coordinator/coordinator_test.go | 3 +++ internal/cosigner/runtime.go | 3 ++- 7 files changed, 28 insertions(+), 34 deletions(-) diff --git a/examples/coordinatorclient-keygen/main.go b/examples/coordinatorclient-keygen/main.go index c4ad1287..2a68074a 100644 --- a/examples/coordinatorclient-keygen/main.go +++ b/examples/coordinatorclient-keygen/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "crypto/ed25519" "encoding/hex" "fmt" "log" @@ -26,11 +25,15 @@ func main() { participants := []coordinatorclient.KeygenParticipant{ { ID: "peer-node-01", - IdentityPublicKey: mustPublicKeyFromPrivateHex("b14d168636008a9c766a6c231c182446e4b636cd2116817a89d068ffb5cc49e456a47a1103b610d6c85bf23ddb1f78ff6404f7c6f170d46441a268e105873cc4"), + IdentityPublicKey: mustDecodeHex("56a47a1103b610d6c85bf23ddb1f78ff6404f7c6f170d46441a268e105873cc4"), }, { ID: "peer-node-02", - IdentityPublicKey: mustPublicKeyFromPrivateHex("a96d8c0de1b5682740f6487b13dc7477aaa739b900c6f5c3db737ca019163efad9034dd84e0dd10a57d6a09a8267b217051d5f121ff52fca66c2b485be16ae02"), + IdentityPublicKey: mustDecodeHex("d9034dd84e0dd10a57d6a09a8267b217051d5f121ff52fca66c2b485be16ae02"), + }, + { + ID: "mobile-sample-01", + IdentityPublicKey: mustDecodeHex("0c67697e3142c1c87dd8fa034fdfece14fc8ba00145bc0f123d6cd8bd33640e2"), }, } @@ -47,21 +50,6 @@ func mustDecodeHex(value string) []byte { return decoded } -func mustPublicKeyFromPrivateHex(privateKeyHex string) []byte { - privateRaw := mustDecodeHex(privateKeyHex) - var private ed25519.PrivateKey - switch len(privateRaw) { - case ed25519.PrivateKeySize: - private = ed25519.PrivateKey(privateRaw) - case ed25519.SeedSize: - private = ed25519.NewKeyFromSeed(privateRaw) - default: - panic(fmt.Sprintf("invalid ed25519 private key length: %d", len(privateRaw))) - } - public := private.Public().(ed25519.PublicKey) - return append([]byte(nil), public...) -} - func runKeygenForProtocol(client *coordinatorclient.Client, participants []coordinatorclient.KeygenParticipant, walletID string, protocol sdkprotocol.ProtocolType) { requestCtx, cancelRequest := context.WithTimeout(context.Background(), 10*time.Second) resp, err := client.RequestKeygen(requestCtx, coordinatorclient.KeygenRequest{ diff --git a/examples/coordinatorclient-sign/main.go b/examples/coordinatorclient-sign/main.go index d65b49ad..4622d143 100644 --- a/examples/coordinatorclient-sign/main.go +++ b/examples/coordinatorclient-sign/main.go @@ -25,15 +25,15 @@ func main() { participants := []coordinatorclient.SignParticipant{ { ID: "peer-node-01", - IdentityPublicKey: mustPublicKeyFromPrivateHex("b14d168636008a9c766a6c231c182446e4b636cd2116817a89d068ffb5cc49e456a47a1103b610d6c85bf23ddb1f78ff6404f7c6f170d46441a268e105873cc4"), + IdentityPublicKey: mustDecodeHex("56a47a1103b610d6c85bf23ddb1f78ff6404f7c6f170d46441a268e105873cc4"), }, { - ID: "peer-node-02", - IdentityPublicKey: mustPublicKeyFromPrivateHex("a96d8c0de1b5682740f6487b13dc7477aaa739b900c6f5c3db737ca019163efad9034dd84e0dd10a57d6a09a8267b217051d5f121ff52fca66c2b485be16ae02"), + ID: "mobile-sample-01", + IdentityPublicKey: mustDecodeHex("0c67697e3142c1c87dd8fa034fdfece14fc8ba00145bc0f123d6cd8bd33640e2"), }, } - walletID := "wallet_2834c034-489c-40f2-a237-2afd4a73bfaa" + walletID := "wallet_f8029c22-a222-4828-b135-8aacc021d716" message := []byte("deadbeef") protocol := sdkprotocol.ProtocolTypeEdDSA diff --git a/go.mod b/go.mod index b1341112..bd0280ec 100644 --- a/go.mod +++ b/go.mod @@ -14,9 +14,11 @@ require ( github.com/btcsuite/btcutil v1.0.2 github.com/decred/dcrd/dcrec/edwards/v2 v2.0.4 github.com/dgraph-io/badger/v4 v4.9.0 + github.com/eclipse/paho.mqtt.golang v1.5.1 github.com/google/uuid v1.6.0 github.com/hashicorp/consul/api v1.33.2 github.com/mitchellh/mapstructure v1.5.0 + github.com/mochi-mqtt/server/v2 v2.7.9 github.com/nats-io/nats.go v1.48.0 github.com/rs/zerolog v1.34.0 github.com/samber/lo v1.52.0 @@ -28,11 +30,9 @@ require ( ) require ( - github.com/eclipse/paho.mqtt.golang v1.5.1 // indirect github.com/gorilla/websocket v1.5.3 // indirect - github.com/mochi-mqtt/server/v2 v2.7.9 // indirect github.com/rs/xid v1.6.0 // indirect - golang.org/x/sync v0.19.0 // indirect + golang.org/x/sync v0.20.0 // indirect ) require ( @@ -104,7 +104,7 @@ require ( go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect golang.org/x/net v0.49.0 // indirect - golang.org/x/sys v0.42.0 // indirect + golang.org/x/sys v0.43.0 // indirect golang.org/x/text v0.34.0 // indirect google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index b6a1d170..c9bead4d 100644 --- a/go.sum +++ b/go.sum @@ -216,6 +216,8 @@ github.com/ipfs/go-log/v2 v2.9.0 h1:l4b06AwVXwldIzbVPZy5z7sKp9lHFTX0KWfTBCtHaOk= github.com/ipfs/go-log/v2 v2.9.0/go.mod h1:UhIYAwMV7Nb4ZmihUxfIRM2Istw/y9cAk3xaK+4Zs2c= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg= +github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -444,8 +446,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -479,8 +481,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= -golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -507,8 +509,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= -golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c= +golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/coordinator/coordinator.go b/internal/coordinator/coordinator.go index f7c8600c..543b5e72 100644 --- a/internal/coordinator/coordinator.go +++ b/internal/coordinator/coordinator.go @@ -318,7 +318,7 @@ func (c *Coordinator) validateRequest(ctx context.Context, op Operation, msg *sd return newCoordinatorError(ErrorCodeValidation, "party_key must equal participant_id bytes") } if !c.presence.IsOnline(ctx, participant.ParticipantID) { - return newCoordinatorError(ErrorCodeUnavailable, "participant is offline") + return newCoordinatorError(ErrorCodeUnavailable, fmt.Sprintf("participant %q is offline", participant.ParticipantID)) } } if op == OperationKeygen && c.keyInfoStore != nil { diff --git a/internal/coordinator/coordinator_test.go b/internal/coordinator/coordinator_test.go index 5d13fdba..c787ff85 100644 --- a/internal/coordinator/coordinator_test.go +++ b/internal/coordinator/coordinator_test.go @@ -97,6 +97,9 @@ func TestHandleRequestRejectsOfflineParticipant(t *testing.T) { if reply.ErrorCode != ErrorCodeUnavailable { t.Fatalf("error code = %s, want %s", reply.ErrorCode, ErrorCodeUnavailable) } + if reply.ErrorMessage != `participant "p1" is offline` { + t.Fatalf("error message = %q", reply.ErrorMessage) + } } func TestLifecycleCompletesSignAndPublishesResult(t *testing.T) { diff --git a/internal/cosigner/runtime.go b/internal/cosigner/runtime.go index 773b9952..58fd2f99 100644 --- a/internal/cosigner/runtime.go +++ b/internal/cosigner/runtime.go @@ -5,6 +5,7 @@ import ( "context" "crypto/ed25519" "encoding/gob" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -84,7 +85,7 @@ func (r *Runtime) Close() error { } func (r *Runtime) Run(ctx context.Context) error { - logger.Info("cosigner runtime started", "node_id", r.cfg.NodeID) + logger.Info("cosigner runtime started", "node_id", r.cfg.NodeID, "identity_public_key_hex", hex.EncodeToString(r.identity.PublicKey())) if err := r.ensureECDSAPreparams(); err != nil { return err } From 2027943963326713409b97eea737664b951487d8 Mon Sep 17 00:00:00 2001 From: vietddude Date: Sat, 18 Apr 2026 11:46:20 +0700 Subject: [PATCH 13/23] Add documentation for local mixed transport setup with NATS and MQTT cosigners. Include configuration details, run order, and troubleshooting steps for coordinator, relay, and cosigner nodes. --- docs/local-coordinator-relay-cosigners.md | 199 ++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 docs/local-coordinator-relay-cosigners.md diff --git a/docs/local-coordinator-relay-cosigners.md b/docs/local-coordinator-relay-cosigners.md new file mode 100644 index 00000000..fcb5534f --- /dev/null +++ b/docs/local-coordinator-relay-cosigners.md @@ -0,0 +1,199 @@ +# Local Coordinator, Relay, NATS Node, and MQTT Cosigner + +This guide runs a local mixed transport setup: + +- 1 NATS server +- 1 coordinator +- 1 relay that bridges NATS and MQTT +- 1 MPCium cosigner node over NATS: `peer-node-01` +- 1 cosigner over MQTT: `peer-node-02` + +The sample configs used by this guide are already in the repository: + +- `coordinator.config.yaml` +- `relay.config.yaml` +- `cosigner.config.yaml` +- `cosigner2.config.yaml` + +## Prerequisites + +Install and start a local NATS server on `127.0.0.1:4222`. + +```sh +nats-server +``` + +The relay listens on MQTT port `1883`. Make sure nothing else is using that port. + +## Config Overview + +`cosigner.config.yaml` runs `peer-node-01` through NATS: + +```yaml +relay_provider: nats +node_id: peer-node-01 +nats: + url: "nats://127.0.0.1:4222" +``` + +`cosigner2.config.yaml` runs `peer-node-02` through MQTT: + +```yaml +relay_provider: mqtt +node_id: peer-node-02 +mqtt: + broker: tcp://127.0.0.1:1883 + client_id: peer-node-02 + username: peer-node-02 + password: peer-node-02 +``` + +## MQTT Credentials + +Create `relay.credentials` in the repository root: + +```txt +mobile-sample-01:mobile-sample-01 +peer-node-02:peer-node-02 +``` + +The relay reads this file from `relay.config.yaml`: + +```yaml +relay: + mqtt: + username_password_file: ./relay.credentials +``` + +Each line is: + +```txt +username:password +``` + +The relay requires the MQTT username to match the MQTT client ID. For `cosigner2.config.yaml`, all three values are `peer-node-02`: + +```yaml +mqtt: + client_id: peer-node-02 + username: peer-node-02 + password: peer-node-02 +``` + +If the mobile sample connects through MQTT as `mobile-sample-01`, it must use: + +```txt +client_id: mobile-sample-01 +username: mobile-sample-01 +password: mobile-sample-01 +``` + +## Run Order + +Open one terminal per process. + +### 1. Coordinator + +```sh +go run ./cmd/mpcium-coordinator/main.go -c coordinator.config.yaml +``` + +Expected logs include coordinator request, presence, and session event subscriptions. + +### 2. Relay + +```sh +go run ./cmd/mpcium-relay/main.go -c relay.config.yaml +``` + +Expected logs include: + +```txt +relay subscribed NATS filter +relay subscribed MQTT filter +relay runtime started +``` + +### 3. NATS Cosigner Node + +```sh +go run ./cmd/mpcium-cosigner/main.go -c cosigner.config.yaml +``` + +Expected logs include: + +```txt +cosigner runtime started node_id=peer-node-01 +relay nats subscribe subject=mpc.v1.peer.peer-node-01.control +``` + +### 4. MQTT Cosigner + +```sh +go run ./cmd/mpcium-cosigner/main.go -c cosigner2.config.yaml +``` + +Expected logs include: + +```txt +cosigner runtime started node_id=peer-node-02 +relay mqtt subscribe subject=mpc.v1.peer.peer-node-02.control topic=mpc/v1/peer/peer-node-02/control +``` + +The relay should also log that `peer-node-02` connected. + +## Wait for Presence + +The coordinator keeps presence in memory. After starting or restarting the coordinator, relay, or MQTT cosigner, wait a few seconds before sending a keygen request. + +Each online participant must publish presence before the coordinator accepts a session. If you send a request too early, the coordinator can reject it with: + +```txt +coordinator rejected request (UNAVAILABLE): participant "peer-node-02" is offline +``` + +That means the session has not started yet. Wait for the cosigner heartbeat, then retry. + +## Run Keygen + +After both cosigners are online, run: + +```sh +go run ./examples/coordinatorclient-keygen +``` + +Expected output: + +```txt +protocol=ECDSA key_id=wallet_... session_id=sess_... wait_seconds=... +public_key_hex=... +protocol=EdDSA key_id=wallet_... session_id=sess_... wait_seconds=... +public_key_hex=... +``` + +## Troubleshooting + +If `peer-node-02` is offline: + +- Confirm the relay is running on `127.0.0.1:1883`. +- Confirm `cosigner2.config.yaml` uses `client_id`, `username`, and `password` set to `peer-node-02`. +- Confirm `relay.credentials` contains `peer-node-02:peer-node-02`. +- Restart the MQTT cosigner and wait for presence before retrying keygen. + +If `peer-node-01` is offline: + +- Confirm the NATS cosigner is running with `cosigner.config.yaml`. +- Confirm it can connect to `nats://127.0.0.1:4222`. +- Restart the NATS cosigner and wait for presence before retrying keygen. + +If the relay starts but MQTT traffic does not reach NATS: + +- Confirm `relay.config.yaml` uses `relay.bridge.nats_prefix: mpc.v1`. +- Confirm `relay.config.yaml` uses `relay.bridge.mqtt_prefix: mpc/v1`. +- Confirm the MQTT cosigner subscribes to `mpc/v1/peer/peer-node-02/control`. + +If keygen hangs after the request is accepted: + +- Check both cosigner logs for `cosigner received session start`. +- Check relay logs for `NATS->MQTT` and `MQTT->NATS` bridge logs. +- Make sure every participant in `examples/coordinatorclient-keygen/main.go` is running and online. From 6693f0392bc68b1f24d1b66db3c1205bbe3544c2 Mon Sep 17 00:00:00 2001 From: vietddude Date: Sat, 18 Apr 2026 11:55:42 +0700 Subject: [PATCH 14/23] Update SDK dependency references from vietddude to fystack in go.mod and various source files. This change ensures consistency in the usage of the mpcium-sdk across the project. --- examples/coordinatorclient-keygen/main.go | 2 +- examples/coordinatorclient-sign/main.go | 2 +- go.mod | 4 ++-- internal/coordinator/coordinator.go | 2 +- internal/coordinator/coordinator_test.go | 2 +- internal/coordinator/presence.go | 2 +- internal/coordinator/publisher.go | 2 +- internal/coordinator/runtime.go | 2 +- internal/coordinator/signing.go | 2 +- internal/coordinator/store.go | 2 +- internal/coordinator/types.go | 2 +- internal/cosigner/relay.go | 2 +- internal/cosigner/relay_mqtt.go | 2 +- internal/cosigner/relay_nats.go | 2 +- internal/cosigner/runtime.go | 4 ++-- internal/cosigner/storage.go | 2 +- internal/relay/runtime.go | 2 +- pkg/coordinatorclient/client.go | 2 +- 18 files changed, 20 insertions(+), 20 deletions(-) diff --git a/examples/coordinatorclient-keygen/main.go b/examples/coordinatorclient-keygen/main.go index 2a68074a..9e5604d4 100644 --- a/examples/coordinatorclient-keygen/main.go +++ b/examples/coordinatorclient-keygen/main.go @@ -7,9 +7,9 @@ import ( "log" "time" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" "github.com/fystack/mpcium/pkg/coordinatorclient" "github.com/google/uuid" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) func main() { diff --git a/examples/coordinatorclient-sign/main.go b/examples/coordinatorclient-sign/main.go index 4622d143..9e5d7c62 100644 --- a/examples/coordinatorclient-sign/main.go +++ b/examples/coordinatorclient-sign/main.go @@ -8,8 +8,8 @@ import ( "log" "time" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" "github.com/fystack/mpcium/pkg/coordinatorclient" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) func main() { diff --git a/go.mod b/go.mod index bd0280ec..c8dc56d4 100644 --- a/go.mod +++ b/go.mod @@ -94,7 +94,7 @@ require ( github.com/spf13/pflag v1.0.10 // indirect github.com/stretchr/objx v0.5.3 // indirect github.com/subosito/gotenv v1.6.0 // indirect - github.com/vietddude/mpcium-sdk v0.0.0 + github.com/fystack/mpcium-sdk v0.0.0 go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel v1.39.0 // indirect go.opentelemetry.io/otel/metric v1.39.0 // indirect @@ -114,4 +114,4 @@ replace github.com/agl/ed25519 => github.com/binance-chain/edwards25519 v0.0.0-2 replace github.com/bnb-chain/tss-lib/v2 => github.com/fystack/tss-lib/v2 v2.0.3 -replace github.com/vietddude/mpcium-sdk => ../sdk +replace github.com/fystack/mpcium-sdk => ../sdk diff --git a/internal/coordinator/coordinator.go b/internal/coordinator/coordinator.go index 543b5e72..65df9ff8 100644 --- a/internal/coordinator/coordinator.go +++ b/internal/coordinator/coordinator.go @@ -10,9 +10,9 @@ import ( "strings" "time" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" "github.com/fystack/mpcium/pkg/logger" "github.com/google/uuid" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) type Coordinator struct { diff --git a/internal/coordinator/coordinator_test.go b/internal/coordinator/coordinator_test.go index c787ff85..7e843584 100644 --- a/internal/coordinator/coordinator_test.go +++ b/internal/coordinator/coordinator_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" ) type fakeSigner struct{} diff --git a/internal/coordinator/presence.go b/internal/coordinator/presence.go index 0ef95c31..51224ba0 100644 --- a/internal/coordinator/presence.go +++ b/internal/coordinator/presence.go @@ -5,7 +5,7 @@ import ( "sync" "time" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" ) type PresenceView interface { diff --git a/internal/coordinator/publisher.go b/internal/coordinator/publisher.go index 8de0b6ba..18bf720c 100644 --- a/internal/coordinator/publisher.go +++ b/internal/coordinator/publisher.go @@ -5,8 +5,8 @@ import ( "encoding/json" "fmt" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" "github.com/nats-io/nats.go" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) type ControlPublisher interface { diff --git a/internal/coordinator/runtime.go b/internal/coordinator/runtime.go index f0df7802..ccc85148 100644 --- a/internal/coordinator/runtime.go +++ b/internal/coordinator/runtime.go @@ -5,9 +5,9 @@ import ( "encoding/json" "fmt" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" "github.com/fystack/mpcium/pkg/logger" "github.com/nats-io/nats.go" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) type NATSRuntime struct { diff --git a/internal/coordinator/signing.go b/internal/coordinator/signing.go index ef915dbb..15b86b21 100644 --- a/internal/coordinator/signing.go +++ b/internal/coordinator/signing.go @@ -6,7 +6,7 @@ import ( "encoding/hex" "fmt" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" ) type Signer interface { diff --git a/internal/coordinator/store.go b/internal/coordinator/store.go index d298ed1a..6678f4d5 100644 --- a/internal/coordinator/store.go +++ b/internal/coordinator/store.go @@ -10,7 +10,7 @@ import ( "strings" "sync" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" ) type SnapshotStore interface { diff --git a/internal/coordinator/types.go b/internal/coordinator/types.go index f85eb423..74969d33 100644 --- a/internal/coordinator/types.go +++ b/internal/coordinator/types.go @@ -3,7 +3,7 @@ package coordinator import ( "time" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" ) type Operation string diff --git a/internal/cosigner/relay.go b/internal/cosigner/relay.go index 99f3c36c..5dd3857d 100644 --- a/internal/cosigner/relay.go +++ b/internal/cosigner/relay.go @@ -3,7 +3,7 @@ package cosigner import ( "fmt" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" ) type Subscription interface { diff --git a/internal/cosigner/relay_mqtt.go b/internal/cosigner/relay_mqtt.go index ec07b2c9..80ad1606 100644 --- a/internal/cosigner/relay_mqtt.go +++ b/internal/cosigner/relay_mqtt.go @@ -6,8 +6,8 @@ import ( "time" mqtt "github.com/eclipse/paho.mqtt.golang" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" "github.com/fystack/mpcium/pkg/logger" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) const mqttOperationTimeout = 10 * time.Second diff --git a/internal/cosigner/relay_nats.go b/internal/cosigner/relay_nats.go index 9c957e41..127b3248 100644 --- a/internal/cosigner/relay_nats.go +++ b/internal/cosigner/relay_nats.go @@ -7,9 +7,9 @@ import ( "os" "time" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" "github.com/fystack/mpcium/pkg/logger" "github.com/nats-io/nats.go" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) type NATSRelay struct { diff --git a/internal/cosigner/runtime.go b/internal/cosigner/runtime.go index 58fd2f99..f2e88f36 100644 --- a/internal/cosigner/runtime.go +++ b/internal/cosigner/runtime.go @@ -14,9 +14,9 @@ import ( "time" ecdsaKeygen "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" + "github.com/fystack/mpcium-sdk/participant" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" "github.com/fystack/mpcium/pkg/logger" - "github.com/vietddude/mpcium-sdk/participant" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) type Runtime struct { diff --git a/internal/cosigner/storage.go b/internal/cosigner/storage.go index 52cc3903..dfbf1d51 100644 --- a/internal/cosigner/storage.go +++ b/internal/cosigner/storage.go @@ -5,7 +5,7 @@ import ( "path/filepath" "github.com/dgraph-io/badger/v4" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" ) type PreparamsStore interface { diff --git a/internal/relay/runtime.go b/internal/relay/runtime.go index e1617f4b..6160abd6 100644 --- a/internal/relay/runtime.go +++ b/internal/relay/runtime.go @@ -13,12 +13,12 @@ import ( "sync" "time" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" "github.com/fystack/mpcium/pkg/logger" mqtt "github.com/mochi-mqtt/server/v2" "github.com/mochi-mqtt/server/v2/listeners" "github.com/mochi-mqtt/server/v2/packets" "github.com/nats-io/nats.go" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) const mqttOriginValue = "mqtt" diff --git a/pkg/coordinatorclient/client.go b/pkg/coordinatorclient/client.go index 19cdc817..3bfd233a 100644 --- a/pkg/coordinatorclient/client.go +++ b/pkg/coordinatorclient/client.go @@ -6,8 +6,8 @@ import ( "fmt" "time" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" "github.com/nats-io/nats.go" - sdkprotocol "github.com/vietddude/mpcium-sdk/protocol" ) const ( From 10a2c9ba526cad69f4c64d73b2fd324e1a21e83a Mon Sep 17 00:00:00 2001 From: vietddude Date: Sat, 18 Apr 2026 11:58:15 +0700 Subject: [PATCH 15/23] Add local SDK replacement instructions for development. Include details on directory structure and `go.mod` replace directive for seamless integration of the SDK during local development. --- docs/local-coordinator-relay-cosigners.md | 35 +++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/docs/local-coordinator-relay-cosigners.md b/docs/local-coordinator-relay-cosigners.md index fcb5534f..69ac1ab4 100644 --- a/docs/local-coordinator-relay-cosigners.md +++ b/docs/local-coordinator-relay-cosigners.md @@ -25,6 +25,41 @@ nats-server The relay listens on MQTT port `1883`. Make sure nothing else is using that port. +## Local SDK Replace + +This repository imports the SDK as: + +```go +github.com/fystack/mpcium-sdk +``` + +For local development, keep the SDK repository next to this repository: + +```txt +work/ + mpcium/ + sdk/ +``` + +Then make sure `go.mod` contains this replace directive: + +```go +replace github.com/fystack/mpcium-sdk => ../sdk +``` + +You can check it with: + +```sh +grep 'github.com/fystack/mpcium-sdk => ../sdk' go.mod +``` + +If the SDK is somewhere else, update the replace path: + +```sh +go mod edit -replace github.com/fystack/mpcium-sdk=/absolute/path/to/sdk +go mod tidy +``` + ## Config Overview `cosigner.config.yaml` runs `peer-node-01` through NATS: From c40b6a403b155ea8ac6748a0640fc6107e2795e1 Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 22 Apr 2026 15:18:17 +0700 Subject: [PATCH 16/23] Refactor coordinator runtimes and add gRPC orchestration. Introduce composable runtime startup with optional gRPC orchestration transport, extend coordinator config/runtime plumbing, and support dual-protocol keygen result aggregation in coordinator state handling. Made-with: Cursor --- Makefile | 9 +- cmd/mpcium-coordinator/main.go | 25 ++- coordinator.config.yaml | 6 +- go.mod | 15 +- go.sum | 32 ++- internal/coordinator/README.md | 7 +- internal/coordinator/composite_runtime.go | 66 ++++++ internal/coordinator/config.go | 32 ++- internal/coordinator/config_test.go | 78 +++++++ internal/coordinator/coordinator.go | 176 ++++++++++++++- internal/coordinator/coordinator_test.go | 108 ++++++++++ internal/coordinator/grpc_runtime.go | 74 +++++++ .../{runtime.go => nats_runtime.go} | 0 internal/coordinator/orchestration_grpc.go | 203 ++++++++++++++++++ internal/coordinator/store.go | 3 + 15 files changed, 798 insertions(+), 36 deletions(-) create mode 100644 internal/coordinator/composite_runtime.go create mode 100644 internal/coordinator/config_test.go create mode 100644 internal/coordinator/grpc_runtime.go rename internal/coordinator/{runtime.go => nats_runtime.go} (100%) create mode 100644 internal/coordinator/orchestration_grpc.go diff --git a/Makefile b/Makefile index 6f94ebc0..b41d233e 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: all build clean mpcium mpc install reset test test-verbose test-coverage e2e-test e2e-clean cleanup-test-env +.PHONY: all build clean mpcium mpc install reset test test-verbose test-coverage e2e-test e2e-clean cleanup-test-env proto proto-tools BIN_DIR := bin @@ -83,6 +83,13 @@ endif test: go test ./... +proto-tools: + go install google.golang.org/protobuf/cmd/protoc-gen-go@latest + go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest + +proto: + $(MAKE) -C ../sdk proto + # Run tests with verbose output test-verbose: go test -v ./... diff --git a/cmd/mpcium-coordinator/main.go b/cmd/mpcium-coordinator/main.go index 51515cf4..c5148eaf 100644 --- a/cmd/mpcium-coordinator/main.go +++ b/cmd/mpcium-coordinator/main.go @@ -18,7 +18,7 @@ import ( const coordinatorConfigPath = "coordinator.config.yaml" func main() { - logger.Init(os.Getenv("ENVIRONMENT"), true) + logger.Init(os.Getenv("ENVIRONMENT"), false) cmd := &cli.Command{ Name: "mpcium-coordinator", @@ -92,22 +92,35 @@ func run(ctx context.Context, c *cli.Command) error { return err } - runtime := coordinator.NewNATSRuntime(nc, coord, presence) - if err := runtime.Start(ctx); err != nil { + natsRuntime := coordinator.NewNATSRuntime(nc, coord, presence) + composite := coordinator.NewCompositeRuntime(natsRuntime) + if cfg.GRPCEnabled { + composite = coordinator.NewCompositeRuntime( + natsRuntime, + coordinator.NewGRPCRuntime(cfg.GRPCListenAddr, coord, cfg.GRPCPollInterval), + ) + } + + if err := composite.Start(ctx); err != nil { return err } + defer func() { + if err := composite.Stop(); err != nil { + logger.Error("stop coordinator runtime failed", err) + } + }() - return runTickLoop(ctx, runtime, coord, cfg.TickInterval) + return runTickLoop(ctx, coord, cfg.TickInterval) } -func runTickLoop(ctx context.Context, runtime *coordinator.NATSRuntime, coord *coordinator.Coordinator, interval time.Duration) error { +func runTickLoop(ctx context.Context, coord *coordinator.Coordinator, interval time.Duration) error { ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ctx.Done(): - return runtime.Stop() + return nil case <-ticker.C: if _, err := coord.Tick(ctx); err != nil { logger.Error("coordinator tick error", err) diff --git a/coordinator.config.yaml b/coordinator.config.yaml index 1c649996..2ba33f05 100644 --- a/coordinator.config.yaml +++ b/coordinator.config.yaml @@ -1,8 +1,12 @@ nats: url: nats://127.0.0.1:4222 +grpc: + enabled: true + listen_addr: 127.0.0.1:50051 + poll_interval: 200ms + coordinator: id: coordinator-01 private_key_hex: "86ed171146e6003841f1686c0958b68ae84f9992974c2c6febfb9df7f424b3adb64ca8ec459081a299aecc2b2b5d555265b15ddfd29e792ddd08bedb418bdd0d" snapshot_dir: coordinator-snapshots - relay_available: true diff --git a/go.mod b/go.mod index c8dc56d4..547ab5e8 100644 --- a/go.mod +++ b/go.mod @@ -25,14 +25,16 @@ require ( github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 github.com/urfave/cli/v3 v3.6.2 - golang.org/x/crypto v0.48.0 - golang.org/x/term v0.40.0 + golang.org/x/crypto v0.50.0 + golang.org/x/term v0.42.0 + google.golang.org/grpc v1.80.0 ) require ( github.com/gorilla/websocket v1.5.3 // indirect github.com/rs/xid v1.6.0 // indirect golang.org/x/sync v0.20.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260120221211-b8f7ae30c516 // indirect ) require ( @@ -60,6 +62,8 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/fatih/color v1.18.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/fystack/mpcium-sdk v0.0.0 + github.com/fystack/mpcium-sdk/integrations/coordinator-grpc v0.0.0 github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-viper/mapstructure/v2 v2.5.0 // indirect @@ -94,7 +98,6 @@ require ( github.com/spf13/pflag v1.0.10 // indirect github.com/stretchr/objx v0.5.3 // indirect github.com/subosito/gotenv v1.6.0 // indirect - github.com/fystack/mpcium-sdk v0.0.0 go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel v1.39.0 // indirect go.opentelemetry.io/otel/metric v1.39.0 // indirect @@ -103,9 +106,9 @@ require ( go.uber.org/zap v1.27.1 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect - golang.org/x/net v0.49.0 // indirect + golang.org/x/net v0.53.0 // indirect golang.org/x/sys v0.43.0 // indirect - golang.org/x/text v0.34.0 // indirect + golang.org/x/text v0.36.0 // indirect google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) @@ -115,3 +118,5 @@ replace github.com/agl/ed25519 => github.com/binance-chain/edwards25519 v0.0.0-2 replace github.com/bnb-chain/tss-lib/v2 => github.com/fystack/tss-lib/v2 v2.0.3 replace github.com/fystack/mpcium-sdk => ../sdk + +replace github.com/fystack/mpcium-sdk/integrations/coordinator-grpc => ../sdk/integrations/coordinator-grpc diff --git a/go.sum b/go.sum index c9bead4d..30dfee1b 100644 --- a/go.sum +++ b/go.sum @@ -151,6 +151,8 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= @@ -379,6 +381,10 @@ go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= @@ -404,8 +410,8 @@ golang.org/x/crypto v0.0.0-20200115085410-6d4e4cb37c7d/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -414,8 +420,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= -golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= +golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= golang.org/x/net v0.0.0-20180719180050-a680a1efc54d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -434,8 +440,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -488,8 +494,8 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= -golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= -golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= +golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -497,8 +503,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= @@ -515,7 +521,13 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260120221211-b8f7ae30c516 h1:sNrWoksmOyF5bvJUcnmbeAmQi8baNhqg5IWaI3llQqU= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260120221211-b8f7ae30c516/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= +google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/internal/coordinator/README.md b/internal/coordinator/README.md index 8e92856d..6be48b30 100644 --- a/internal/coordinator/README.md +++ b/internal/coordinator/README.md @@ -4,7 +4,7 @@ This package implements the control-plane coordinator for the new MPC runtime. It is responsible for: -- request intake on versioned subjects (`keygen`, `sign`, `reshare`) +- request intake through NATS subjects or the gRPC client orchestration API (`keygen`, `sign`, `reshare`) - session creation and lifecycle state transitions - participant readiness and key exchange gating - control message fan-out to participants @@ -24,6 +24,7 @@ It is not responsible for: - `mpc.v1.request.keygen` - `mpc.v1.request.sign` - `mpc.v1.request.reshare` + or over the gRPC `CoordinatorOrchestration` service for client orchestration. 2. Validate request shape and participant constraints. 3. Create a new `session_id` and initial session state. 4. Fan out `session.start` control messages to each selected participant. @@ -36,6 +37,8 @@ It is not responsible for: core orchestration logic and state machine. - `NATSRuntime`: wiring from subjects to coordinator handlers. +- `GRPCRuntime`: + optional plaintext client API for submitting keygen/sign requests and waiting for terminal session results. - `MemorySessionStore`: in-memory session state. - `AtomicFileSnapshotStore`: @@ -45,6 +48,8 @@ It is not responsible for: - `NATSControlPublisher` / `NATSResultPublisher`: delivery adapters for control and result messages. +The gRPC API is client-facing only. Participant control fan-out, participant session events, presence, and result publishing still use NATS/relay transport. + ## Request Models The operation is determined by subject. Each operation has its own request struct: diff --git a/internal/coordinator/composite_runtime.go b/internal/coordinator/composite_runtime.go new file mode 100644 index 00000000..e6fbf6c5 --- /dev/null +++ b/internal/coordinator/composite_runtime.go @@ -0,0 +1,66 @@ +package coordinator + +import ( + "context" + "errors" + "sync" +) + +type Runtime interface { + Start(ctx context.Context) error + Stop() error +} + +type CompositeRuntime struct { + mu sync.Mutex + runtimes []Runtime + started []Runtime +} + +func NewCompositeRuntime(runtimes ...Runtime) *CompositeRuntime { + filtered := make([]Runtime, 0, len(runtimes)) + for _, r := range runtimes { + if r != nil { + filtered = append(filtered, r) + } + } + return &CompositeRuntime{runtimes: filtered} +} + +func (r *CompositeRuntime) Start(ctx context.Context) error { + r.mu.Lock() + defer r.mu.Unlock() + + if len(r.started) > 0 { + return errors.New("composite runtime already started") + } + + r.started = make([]Runtime, 0, len(r.runtimes)) + + for _, runtime := range r.runtimes { + if err := runtime.Start(ctx); err != nil { + r.stopLocked() + return err + } + r.started = append(r.started, runtime) + } + + return nil +} + +func (r *CompositeRuntime) Stop() error { + r.mu.Lock() + defer r.mu.Unlock() + return r.stopLocked() +} + +func (r *CompositeRuntime) stopLocked() error { + var errs []error + for i := len(r.started) - 1; i >= 0; i-- { + if err := r.started[i].Stop(); err != nil { + errs = append(errs, err) + } + } + r.started = nil + return errors.Join(errs...) +} diff --git a/internal/coordinator/config.go b/internal/coordinator/config.go index 3d898405..3a10cfdf 100644 --- a/internal/coordinator/config.go +++ b/internal/coordinator/config.go @@ -15,6 +15,7 @@ const ( type fileConfig struct { NATS natsConfig `mapstructure:"nats"` + GRPC grpcConfig `mapstructure:"grpc"` Coordinator coordinatorConfig `mapstructure:"coordinator"` } @@ -22,6 +23,12 @@ type natsConfig struct { URL string `mapstructure:"url"` } +type grpcConfig struct { + Enabled bool `mapstructure:"enabled"` + ListenAddr string `mapstructure:"listen_addr"` + PollInterval time.Duration `mapstructure:"poll_interval"` +} + type coordinatorConfig struct { ID string `mapstructure:"id"` PrivateKeyHex string `mapstructure:"private_key_hex"` @@ -30,6 +37,9 @@ type coordinatorConfig struct { type RuntimeConfig struct { NATSURL string + GRPCEnabled bool + GRPCListenAddr string + GRPCPollInterval time.Duration ID string PrivateKeyHex string SnapshotDir string @@ -39,16 +49,19 @@ type RuntimeConfig struct { func (cfg RuntimeConfig) Validate() error { if cfg.NATSURL == "" { - return fmt.Errorf("nats-url is required") + return fmt.Errorf("nats.url is required") } if cfg.ID == "" { - return fmt.Errorf("coordinator-id is required") + return fmt.Errorf("coordinator.id is required") } if cfg.PrivateKeyHex == "" { - return fmt.Errorf("coordinator-private-key-hex is required") + return fmt.Errorf("coordinator.private_key_hex is required") } if cfg.SnapshotDir == "" { - return fmt.Errorf("coordinator-snapshot-dir is required") + return fmt.Errorf("coordinator.snapshot_dir is required") + } + if cfg.GRPCEnabled && cfg.GRPCListenAddr == "" { + return fmt.Errorf("grpc.listen_addr is required when grpc is enabled") } return nil } @@ -58,12 +71,19 @@ func LoadRuntimeConfig() (RuntimeConfig, error) { if err := viper.Unmarshal(&cfg, viper.DecodeHook(mapstructure.StringToTimeDurationHookFunc())); err != nil { return RuntimeConfig{}, fmt.Errorf("decode config: %w", err) } - return cfg.Coordinator.runtimeConfig(cfg.NATS.URL), nil + return cfg.Coordinator.runtimeConfig(cfg.NATS.URL, cfg.GRPC), nil } -func (cfg coordinatorConfig) runtimeConfig(natsURL string) RuntimeConfig { +func (cfg coordinatorConfig) runtimeConfig(natsURL string, grpc grpcConfig) RuntimeConfig { + pollInterval := 200 * time.Millisecond + if grpc.PollInterval > 0 { + pollInterval = grpc.PollInterval + } return RuntimeConfig{ NATSURL: natsURL, + GRPCEnabled: grpc.Enabled, + GRPCListenAddr: grpc.ListenAddr, + GRPCPollInterval: pollInterval, ID: cfg.ID, PrivateKeyHex: cfg.PrivateKeyHex, SnapshotDir: cfg.SnapshotDir, diff --git a/internal/coordinator/config_test.go b/internal/coordinator/config_test.go new file mode 100644 index 00000000..68ad1388 --- /dev/null +++ b/internal/coordinator/config_test.go @@ -0,0 +1,78 @@ +package coordinator + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/spf13/viper" +) + +func TestLoadRuntimeConfigDecodesGRPCConfig(t *testing.T) { + t.Cleanup(viper.Reset) + configPath := writeCoordinatorConfig(t, ` +nats: + url: nats://127.0.0.1:4222 +grpc: + enabled: true + listen_addr: 127.0.0.1:50051 + poll_interval: 250ms +coordinator: + id: coordinator-01 + private_key_hex: abc123 + snapshot_dir: coordinator-snapshots +`) + viper.SetConfigFile(configPath) + if err := viper.ReadInConfig(); err != nil { + t.Fatal(err) + } + + cfg, err := LoadRuntimeConfig() + if err != nil { + t.Fatal(err) + } + if !cfg.GRPCEnabled || cfg.GRPCListenAddr != "127.0.0.1:50051" || cfg.GRPCPollInterval != 250*time.Millisecond { + t.Fatalf("unexpected grpc config: %+v", cfg) + } +} + +func TestLoadRuntimeConfigRejectsInvalidGRPCPollInterval(t *testing.T) { + t.Cleanup(viper.Reset) + configPath := writeCoordinatorConfig(t, ` +nats: + url: nats://127.0.0.1:4222 +grpc: + enabled: true + listen_addr: 127.0.0.1:50051 + poll_interval: nope +coordinator: + id: coordinator-01 + private_key_hex: abc123 + snapshot_dir: coordinator-snapshots +`) + viper.SetConfigFile(configPath) + if err := viper.ReadInConfig(); err != nil { + t.Fatal(err) + } + + if _, err := LoadRuntimeConfig(); err == nil { + t.Fatalf("expected invalid duration error") + } +} + +func TestRuntimeConfigValidateUsesConfigKeyNames(t *testing.T) { + err := RuntimeConfig{}.Validate() + if err == nil || err.Error() != "nats.url is required" { + t.Fatalf("Validate() error = %v", err) + } +} + +func writeCoordinatorConfig(t *testing.T, contents string) string { + t.Helper() + path := filepath.Join(t.TempDir(), "coordinator.config.yaml") + if err := os.WriteFile(path, []byte(contents), 0o600); err != nil { + t.Fatal(err) + } + return path +} diff --git a/internal/coordinator/coordinator.go b/internal/coordinator/coordinator.go index 65df9ff8..f055992e 100644 --- a/internal/coordinator/coordinator.go +++ b/internal/coordinator/coordinator.go @@ -8,6 +8,7 @@ import ( "fmt" "sort" "strings" + "sync" "time" sdkprotocol "github.com/fystack/mpcium-sdk/protocol" @@ -26,6 +27,18 @@ type Coordinator struct { results ResultPublisher defaultSessionTTL time.Duration now func() time.Time + dualKeygenMu sync.Mutex + dualKeygen map[string]*dualKeygenGroup +} + +type dualKeygenGroup struct { + sessionIDs map[sdkprotocol.ProtocolType]string + results map[sdkprotocol.ProtocolType][]byte +} + +type dualKeygenCompletion struct { + result *sdkprotocol.Result + sessionIDs []string } func NewCoordinator(cfg CoordinatorConfig) (*Coordinator, error) { @@ -45,6 +58,7 @@ func NewCoordinator(cfg CoordinatorConfig) (*Coordinator, error) { results: cfg.Results, defaultSessionTTL: cfg.DefaultSessionTTL, now: cfg.Now, + dualKeygen: make(map[string]*dualKeygenGroup), }, nil } @@ -57,7 +71,7 @@ func (c *Coordinator) HandleRequest(ctx context.Context, op Operation, raw []byt return reject(ErrorCodeInvalidJSON, "invalid JSON request"), nil } // Backward compatibility: keygen without protocol means dispatch both ECDSA and EdDSA sessions. - if op == OperationKeygen && req.SessionStart != nil && isProtocolUnspecified(req.SessionStart.Protocol) { + if op == OperationKeygen && req.SessionStart != nil && isBothKeygenProtocol(req.SessionStart.Protocol) { protocols := []sdkprotocol.ProtocolType{sdkprotocol.ProtocolTypeECDSA, sdkprotocol.ProtocolTypeEdDSA} sessionIDs := make([]string, 0, len(protocols)) var firstAccepted *sdkprotocol.RequestAccepted @@ -89,8 +103,9 @@ func (c *Coordinator) HandleRequest(ctx context.Context, op Operation, raw []byt } return reject(ErrorCodeConflict, "no keygen sessions created"), nil } + c.registerDualKeygen(sessionIDs) - logger.Info("coordinator expanded keygen request without protocol", + logger.Info("coordinator expanded keygen request to both protocols", "operation", string(op), "sessions", strings.Join(sessionIDs, ","), ) @@ -338,6 +353,13 @@ func isProtocolUnspecified(protocol sdkprotocol.ProtocolType) bool { return protocol == sdkprotocol.ProtocolTypeUnspecified || string(protocol) == "" } +func isBothKeygenProtocol(protocol sdkprotocol.ProtocolType) bool { + if isProtocolUnspecified(protocol) { + return true + } + return strings.EqualFold(strings.TrimSpace(string(protocol)), "both") +} + func (c *Coordinator) advance(ctx context.Context, session *Session, event *sdkprotocol.SessionEvent) error { switch session.State { case SessionWaitingParticipants: @@ -363,6 +385,15 @@ func (c *Coordinator) advance(ctx context.Context, session *Session, event *sdkp if err := c.persistKeyInfoIfNeeded(session, result); err != nil { return c.failSession(ctx, session, ErrorCodeInternal, err.Error()) } + if session.Op == OperationKeygen { + if completion, ready, isDual := c.recordDualKeygenResult(session, result); isDual { + if !ready { + session.UpdatedAt = c.now() + return c.store.Save(ctx, session) + } + return c.completeDualKeygen(ctx, completion) + } + } now := c.now() session.State = SessionCompleted session.ResultHash = resultHash @@ -409,6 +440,96 @@ func (c *Coordinator) persistKeyInfoIfNeeded(session *Session, result *sdkprotoc return nil } +func (c *Coordinator) registerDualKeygen(sessionIDs []string) { + if c == nil || len(sessionIDs) == 0 { + return + } + group := &dualKeygenGroup{ + sessionIDs: make(map[sdkprotocol.ProtocolType]string, len(sessionIDs)), + results: make(map[sdkprotocol.ProtocolType][]byte, len(sessionIDs)), + } + for _, sessionID := range sessionIDs { + session, ok := c.store.Get(context.Background(), sessionID) + if !ok || session == nil || session.Start == nil { + continue + } + group.sessionIDs[session.Start.Protocol] = sessionID + } + c.dualKeygenMu.Lock() + defer c.dualKeygenMu.Unlock() + for _, sessionID := range sessionIDs { + c.dualKeygen[sessionID] = group + } +} + +func (c *Coordinator) recordDualKeygenResult(session *Session, result *sdkprotocol.Result) (*dualKeygenCompletion, bool, bool) { + if c == nil || session == nil || session.Start == nil || result == nil || result.KeyShare == nil { + return nil, false, false + } + c.dualKeygenMu.Lock() + defer c.dualKeygenMu.Unlock() + group, ok := c.dualKeygen[session.ID] + if !ok { + return nil, false, false + } + group.results[session.Start.Protocol] = keySharePublicKeyForProtocol(result.KeyShare, session.Start.Protocol) + ecdsaPubKey := group.results[sdkprotocol.ProtocolTypeECDSA] + eddsaPubKey := group.results[sdkprotocol.ProtocolTypeEdDSA] + if len(ecdsaPubKey) == 0 || len(eddsaPubKey) == 0 { + return nil, false, true + } + walletID := keygenWalletID(session.Start) + if walletID == "" { + walletID = result.KeyShare.KeyID + } + aggregate := &sdkprotocol.Result{ + KeyShare: &sdkprotocol.KeyShareResult{ + KeyID: walletID, + ECDSAPubKey: append([]byte(nil), ecdsaPubKey...), + EDDSAPubKey: append([]byte(nil), eddsaPubKey...), + }, + } + for _, sessionID := range group.sessionIDs { + delete(c.dualKeygen, sessionID) + } + sessionIDs := make([]string, 0, len(group.sessionIDs)) + for _, sessionID := range group.sessionIDs { + sessionIDs = append(sessionIDs, sessionID) + } + sort.Strings(sessionIDs) + return &dualKeygenCompletion{result: aggregate, sessionIDs: sessionIDs}, true, true +} + +func (c *Coordinator) completeDualKeygen(ctx context.Context, completion *dualKeygenCompletion) error { + if completion == nil { + return fmt.Errorf("missing dual keygen completion") + } + result := completion.result + if result == nil || result.KeyShare == nil { + return fmt.Errorf("missing dual keygen result") + } + now := c.now() + resultHash := canonicalResultHash(result) + for _, sessionID := range completion.sessionIDs { + session, ok := c.store.Get(ctx, sessionID) + if !ok { + return newCoordinatorError(ErrorCodeValidation, "unknown session") + } + session.State = SessionCompleted + session.ResultHash = resultHash + session.Result = cloneResult(result) + session.CompletedAt = &now + session.UpdatedAt = now + if err := c.store.Save(ctx, session); err != nil { + return err + } + if err := c.results.PublishResult(ctx, session.ID, result); err != nil { + return err + } + } + return nil +} + func (c *Coordinator) fanOutSessionStart(ctx context.Context, session *Session) error { msg := &sdkprotocol.ControlMessage{ SessionID: session.ID, @@ -548,8 +669,10 @@ func (c *Coordinator) buildCompletedResult(session *Session, event *sdkprotocol. } result = &sdkprotocol.Result{ KeyShare: &sdkprotocol.KeyShareResult{ - KeyID: in.KeyShare.KeyID, - PublicKey: append([]byte(nil), in.KeyShare.PublicKey...), + KeyID: in.KeyShare.KeyID, + PublicKey: append([]byte(nil), in.KeyShare.PublicKey...), + ECDSAPubKey: keyShareProtocolPubKey(in.KeyShare, sdkprotocol.ProtocolTypeECDSA, session.Start.Protocol), + EDDSAPubKey: keyShareProtocolPubKey(in.KeyShare, sdkprotocol.ProtocolTypeEdDSA, session.Start.Protocol), }, } case OperationSign: @@ -570,6 +693,21 @@ func (c *Coordinator) nextControlSequence(session *Session) uint64 { return session.ControlSeq } +func (c *Coordinator) GetSession(ctx context.Context, sessionID string) (*Session, bool) { + if c == nil || c.store == nil { + return nil, false + } + return c.store.Get(ctx, sessionID) +} + +func (c *Coordinator) GetSessionResult(ctx context.Context, sessionID string) (*sdkprotocol.Result, bool) { + session, ok := c.GetSession(ctx, sessionID) + if !ok { + return nil, false + } + return cloneResult(session.Result), true +} + func allParticipants(session *Session, predicate func(*ParticipantState) bool) bool { for _, participant := range session.ParticipantState { if !predicate(participant) { @@ -590,8 +728,10 @@ func canonicalOperationResultHash(op Operation, result *sdkprotocol.Result) stri } normalized := &sdkprotocol.Result{ KeyShare: &sdkprotocol.KeyShareResult{ - KeyID: result.KeyShare.KeyID, - PublicKey: append([]byte(nil), result.KeyShare.PublicKey...), + KeyID: result.KeyShare.KeyID, + PublicKey: append([]byte(nil), result.KeyShare.PublicKey...), + ECDSAPubKey: append([]byte(nil), result.KeyShare.ECDSAPubKey...), + EDDSAPubKey: append([]byte(nil), result.KeyShare.EDDSAPubKey...), }, } return canonicalResultHash(normalized) @@ -616,6 +756,30 @@ func keygenWalletID(start *sdkprotocol.SessionStart) string { return start.Keygen.KeyID } +func keySharePublicKeyForProtocol(keyShare *sdkprotocol.KeyShareResult, protocol sdkprotocol.ProtocolType) []byte { + if keyShare == nil { + return nil + } + switch protocol { + case sdkprotocol.ProtocolTypeECDSA: + if len(keyShare.ECDSAPubKey) > 0 { + return append([]byte(nil), keyShare.ECDSAPubKey...) + } + case sdkprotocol.ProtocolTypeEdDSA: + if len(keyShare.EDDSAPubKey) > 0 { + return append([]byte(nil), keyShare.EDDSAPubKey...) + } + } + return append([]byte(nil), keyShare.PublicKey...) +} + +func keyShareProtocolPubKey(keyShare *sdkprotocol.KeyShareResult, target, sessionProtocol sdkprotocol.ProtocolType) []byte { + if keyShare == nil || target != sessionProtocol { + return nil + } + return keySharePublicKeyForProtocol(keyShare, target) +} + func firstNonEmpty(values ...string) string { for _, value := range values { if value != "" { diff --git a/internal/coordinator/coordinator_test.go b/internal/coordinator/coordinator_test.go index 7e843584..8d6fd06b 100644 --- a/internal/coordinator/coordinator_test.go +++ b/internal/coordinator/coordinator_test.go @@ -313,6 +313,94 @@ func TestHandleRequestKeygenWithoutProtocolCreatesBothSessions(t *testing.T) { } } +func TestHandleRequestKeygenBothCreatesBothSessions(t *testing.T) { + ctx := context.Background() + coord, _, _, fixtures := newTestCoordinator(t) + markOnline(t, coord.presence, fixtures["p1"].pub, "p1") + markOnline(t, coord.presence, fixtures["p2"].pub, "p2") + + req := &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "client-supplied", + Protocol: sdkprotocol.ProtocolType("both"), + Operation: sdkprotocol.OperationTypeKeygen, + Threshold: 1, + Participants: []*sdkprotocol.SessionParticipant{ + {ParticipantID: "p1", PartyKey: []byte("p1"), IdentityPublicKey: fixtures["p1"].pub}, + {ParticipantID: "p2", PartyKey: []byte("p2"), IdentityPublicKey: fixtures["p2"].pub}, + }, + Keygen: &sdkprotocol.KeygenPayload{KeyID: "wallet_explicit_both"}, + }, + } + + rawReply, err := coord.HandleRequest(ctx, OperationKeygen, mustJSON(t, req)) + if err != nil { + t.Fatal(err) + } + var accepted sdkprotocol.RequestAccepted + if err := json.Unmarshal(rawReply, &accepted); err != nil { + t.Fatal(err) + } + if !accepted.Accepted { + t.Fatalf("expected request accepted") + } + active := coord.store.ListActive(ctx) + if len(active) != 2 { + t.Fatalf("expected 2 active sessions, got %d", len(active)) + } +} + +func TestKeygenBothPublishesAggregatedPubKeys(t *testing.T) { + ctx := context.Background() + coord, _, results, fixtures := newTestCoordinator(t) + markOnline(t, coord.presence, fixtures["p1"].pub, "p1") + markOnline(t, coord.presence, fixtures["p2"].pub, "p2") + + req := &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "client-supplied", + Protocol: sdkprotocol.ProtocolType("both"), + Operation: sdkprotocol.OperationTypeKeygen, + Threshold: 1, + Participants: []*sdkprotocol.SessionParticipant{ + {ParticipantID: "p1", PartyKey: []byte("p1"), IdentityPublicKey: fixtures["p1"].pub}, + {ParticipantID: "p2", PartyKey: []byte("p2"), IdentityPublicKey: fixtures["p2"].pub}, + }, + Keygen: &sdkprotocol.KeygenPayload{KeyID: "wallet_dual_result"}, + }, + } + + rawReply, err := coord.HandleRequest(ctx, OperationKeygen, mustJSON(t, req)) + if err != nil { + t.Fatal(err) + } + var accepted sdkprotocol.RequestAccepted + if err := json.Unmarshal(rawReply, &accepted); err != nil { + t.Fatal(err) + } + + sessionsByProtocol := map[sdkprotocol.ProtocolType]string{} + for _, session := range coord.store.ListActive(ctx) { + sessionsByProtocol[session.Start.Protocol] = session.ID + } + completeKeygenSession(t, coord, fixtures, sessionsByProtocol[sdkprotocol.ProtocolTypeECDSA], "wallet_dual_result", []byte("ecdsa-pub")) + if results.results[accepted.SessionID] != nil { + t.Fatalf("dual keygen result should wait for both protocols") + } + completeKeygenSession(t, coord, fixtures, sessionsByProtocol[sdkprotocol.ProtocolTypeEdDSA], "wallet_dual_result", []byte("eddsa-pub")) + + published := results.results[accepted.SessionID] + if published == nil || published.KeyShare == nil { + t.Fatalf("missing published dual keygen result") + } + if string(published.KeyShare.ECDSAPubKey) != "ecdsa-pub" { + t.Fatalf("ecdsa_pubkey = %q", string(published.KeyShare.ECDSAPubKey)) + } + if string(published.KeyShare.EDDSAPubKey) != "eddsa-pub" { + t.Fatalf("eddsa_pubkey = %q", string(published.KeyShare.EDDSAPubKey)) + } +} + func TestHandleRequestSignWithoutProtocolRejected(t *testing.T) { ctx := context.Background() coord, _, _, fixtures := newTestCoordinator(t) @@ -473,6 +561,26 @@ func emitSignedEvent(t *testing.T, coord *Coordinator, sessionID string, keys ma } } +func completeKeygenSession(t *testing.T, coord *Coordinator, keys map[string]participantKey, sessionID, walletID string, publicKey []byte) { + t.Helper() + result := &sdkprotocol.Result{ + KeyShare: &sdkprotocol.KeyShareResult{ + KeyID: walletID, + PublicKey: publicKey, + }, + } + for _, participant := range []string{"p1", "p2"} { + emitSignedEvent(t, coord, sessionID, keys, participant, &sdkprotocol.SessionEvent{PeerJoined: &sdkprotocol.PeerJoined{ParticipantID: participant}}) + emitSignedEvent(t, coord, sessionID, keys, participant, &sdkprotocol.SessionEvent{PeerReady: &sdkprotocol.PeerReady{ParticipantID: participant}}) + } + for _, participant := range []string{"p1", "p2"} { + emitSignedEvent(t, coord, sessionID, keys, participant, &sdkprotocol.SessionEvent{PeerKeyExchangeDone: &sdkprotocol.PeerKeyExchangeDone{ParticipantID: participant}}) + } + for _, participant := range []string{"p1", "p2"} { + emitSignedEvent(t, coord, sessionID, keys, participant, &sdkprotocol.SessionEvent{SessionCompleted: &sdkprotocol.SessionCompleted{Result: result}}) + } +} + func markOnline(t *testing.T, presence PresenceView, _ ed25519.PublicKey, participantID string) { t.Helper() err := presence.ApplyPresence(sdkprotocol.PresenceEvent{ diff --git a/internal/coordinator/grpc_runtime.go b/internal/coordinator/grpc_runtime.go new file mode 100644 index 00000000..8015c0a9 --- /dev/null +++ b/internal/coordinator/grpc_runtime.go @@ -0,0 +1,74 @@ +package coordinator + +import ( + "context" + "fmt" + "net" + "strings" + "sync" + "time" + + coordinatorv1 "github.com/fystack/mpcium-sdk/integrations/coordinator-grpc/proto/coordinator/v1" + "github.com/fystack/mpcium/pkg/logger" + "google.golang.org/grpc" +) + +type GRPCRuntime struct { + addr string + server *grpc.Server + listener net.Listener + mu sync.Mutex + started bool + stopOnce sync.Once + orchestration *OrchestrationGRPCServer +} + +func NewGRPCRuntime(addr string, coordination *Coordinator, pollInterval time.Duration) *GRPCRuntime { + return &GRPCRuntime{ + addr: strings.TrimSpace(addr), + server: grpc.NewServer(), + orchestration: NewOrchestrationGRPCServer(coordination, pollInterval), + } +} + +func (r *GRPCRuntime) Start(_ context.Context) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.started { + return nil + } + + listener, err := net.Listen("tcp", r.addr) + if err != nil { + return fmt.Errorf("listen grpc: %w", err) + } + + coordinatorv1.RegisterCoordinatorOrchestrationServer(r.server, r.orchestration) + r.listener = listener + r.started = true + + go func() { + logger.Info("starting grpc orchestration runtime", "addr", r.addr) + if serveErr := r.server.Serve(listener); serveErr != nil && !strings.Contains(strings.ToLower(serveErr.Error()), "closed network connection") { + logger.Error("grpc runtime stopped with error", serveErr, "addr", r.addr) + } + }() + + return nil +} + +func (r *GRPCRuntime) Stop() error { + r.stopOnce.Do(func() { + r.mu.Lock() + defer r.mu.Unlock() + if !r.started { + return + } + r.server.GracefulStop() + if r.listener != nil { + _ = r.listener.Close() + } + r.started = false + }) + return nil +} diff --git a/internal/coordinator/runtime.go b/internal/coordinator/nats_runtime.go similarity index 100% rename from internal/coordinator/runtime.go rename to internal/coordinator/nats_runtime.go diff --git a/internal/coordinator/orchestration_grpc.go b/internal/coordinator/orchestration_grpc.go new file mode 100644 index 00000000..487e5b21 --- /dev/null +++ b/internal/coordinator/orchestration_grpc.go @@ -0,0 +1,203 @@ +package coordinator + +import ( + "context" + "encoding/hex" + "encoding/json" + "fmt" + "strings" + "time" + + coordinatorv1 "github.com/fystack/mpcium-sdk/integrations/coordinator-grpc/proto/coordinator/v1" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type OrchestrationGRPCServer struct { + coordinatorv1.UnimplementedCoordinatorOrchestrationServer + coord *Coordinator + pollInterval time.Duration +} + +func NewOrchestrationGRPCServer(coord *Coordinator, pollInterval time.Duration) *OrchestrationGRPCServer { + if pollInterval <= 0 { + pollInterval = 200 * time.Millisecond + } + return &OrchestrationGRPCServer{coord: coord, pollInterval: pollInterval} +} + +func (s *OrchestrationGRPCServer) Keygen(ctx context.Context, req *coordinatorv1.KeygenRequest) (*coordinatorv1.RequestAccepted, error) { + control := &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "tmp", + Protocol: sdkprotocol.ProtocolType(strings.TrimSpace(req.GetProtocol())), + Operation: sdkprotocol.OperationTypeKeygen, + Threshold: req.GetThreshold(), + Keygen: &sdkprotocol.KeygenPayload{ + KeyID: req.GetWalletId(), + }, + }, + } + participants, err := mapParticipantsToSDK(req.GetParticipants()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid participants: %v", err) + } + control.SessionStart.Participants = participants + return s.handleOperation(ctx, OperationKeygen, control) +} + +func (s *OrchestrationGRPCServer) Sign(ctx context.Context, req *coordinatorv1.SignRequest) (*coordinatorv1.RequestAccepted, error) { + signingInput, err := decodeOptionalHex(req.GetSigningInputHex()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid signing_input_hex: %v", err) + } + derivationDelta, err := decodeOptionalHex(req.GetDerivationDeltaHex()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid derivation_delta_hex: %v", err) + } + + control := &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "tmp", + Protocol: sdkprotocol.ProtocolType(strings.TrimSpace(req.GetProtocol())), + Operation: sdkprotocol.OperationTypeSign, + Threshold: req.GetThreshold(), + Sign: &sdkprotocol.SignPayload{ + KeyID: req.GetWalletId(), + SigningInput: signingInput, + Derivation: &sdkprotocol.NonHardenedDerivation{ + Path: append([]uint32(nil), req.GetDerivationPath()...), + Delta: derivationDelta, + }, + }, + }, + } + participants, err := mapParticipantsToSDK(req.GetParticipants()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid participants: %v", err) + } + control.SessionStart.Participants = participants + if len(req.GetDerivationPath()) == 0 && len(derivationDelta) == 0 { + control.SessionStart.Sign.Derivation = nil + } + + return s.handleOperation(ctx, OperationSign, control) +} + +func (s *OrchestrationGRPCServer) WaitSessionResult(ctx context.Context, req *coordinatorv1.SessionLookup) (*coordinatorv1.SessionResult, error) { + sessionID := strings.TrimSpace(req.GetSessionId()) + if sessionID == "" { + return nil, status.Error(codes.InvalidArgument, "session_id is required") + } + + if _, ok := s.coord.GetSession(ctx, sessionID); !ok { + return nil, status.Error(codes.NotFound, "session not found") + } + + ticker := time.NewTicker(s.pollInterval) + defer ticker.Stop() + + for { + session, ok := s.coord.GetSession(ctx, sessionID) + if !ok { + return nil, status.Error(codes.NotFound, "session not found") + } + if session.State.Terminal() { + return sessionToProtoResult(session), nil + } + + select { + case <-ctx.Done(): + return nil, status.Error(codes.DeadlineExceeded, "wait session result timeout") + case <-ticker.C: + } + } +} + +func (s *OrchestrationGRPCServer) handleOperation(ctx context.Context, op Operation, msg *sdkprotocol.ControlMessage) (*coordinatorv1.RequestAccepted, error) { + raw, err := json.Marshal(msg) + if err != nil { + return nil, status.Errorf(codes.Internal, "marshal request: %v", err) + } + + replyRaw, err := s.coord.HandleRequest(ctx, op, raw) + if err != nil { + return nil, status.Errorf(codes.Internal, "handle request: %v", err) + } + + var accepted sdkprotocol.RequestAccepted + if err := json.Unmarshal(replyRaw, &accepted); err == nil && accepted.Accepted { + return &coordinatorv1.RequestAccepted{ + Accepted: true, + SessionId: accepted.SessionID, + ExpiresAt: accepted.ExpiresAt, + }, nil + } + + var rejected sdkprotocol.RequestRejected + if err := json.Unmarshal(replyRaw, &rejected); err == nil && !rejected.Accepted { + return &coordinatorv1.RequestAccepted{ + Accepted: false, + ErrorCode: rejected.ErrorCode, + ErrorMessage: rejected.ErrorMessage, + }, nil + } + + return nil, status.Error(codes.Internal, "unexpected coordinator response") +} + +func mapParticipantsToSDK(participants []*coordinatorv1.Participant) ([]*sdkprotocol.SessionParticipant, error) { + mapped := make([]*sdkprotocol.SessionParticipant, 0, len(participants)) + for _, participant := range participants { + if participant == nil { + continue + } + pubKey, err := decodeOptionalHex(participant.GetIdentityPublicKeyHex()) + if err != nil { + return nil, fmt.Errorf("participant %q identity_public_key_hex: %w", participant.GetId(), err) + } + id := strings.TrimSpace(participant.GetId()) + mapped = append(mapped, &sdkprotocol.SessionParticipant{ + ParticipantID: id, + PartyKey: []byte(id), + IdentityPublicKey: pubKey, + }) + } + return mapped, nil +} + +func sessionToProtoResult(session *Session) *coordinatorv1.SessionResult { + result := &coordinatorv1.SessionResult{ + Completed: session.State == SessionCompleted, + SessionId: session.ID, + ErrorCode: session.ErrorCode, + ErrorMessage: session.ErrorMessage, + } + if session.Result == nil { + return result + } + if session.Result.KeyShare != nil { + result.KeyId = session.Result.KeyShare.KeyID + result.PublicKeyHex = hex.EncodeToString(session.Result.KeyShare.PublicKey) + } + if session.Result.Signature != nil { + sig := session.Result.Signature + result.KeyId = sig.KeyID + result.PublicKeyHex = hex.EncodeToString(sig.PublicKey) + result.SignatureHex = hex.EncodeToString(sig.Signature) + result.SignatureRecoveryHex = hex.EncodeToString(sig.SignatureRecovery) + result.RHex = hex.EncodeToString(sig.R) + result.SHex = hex.EncodeToString(sig.S) + result.SignedInputHex = hex.EncodeToString(sig.SignedInput) + } + return result +} + +func decodeOptionalHex(value string) ([]byte, error) { + value = strings.TrimSpace(value) + if value == "" { + return nil, nil + } + return hex.DecodeString(value) +} diff --git a/internal/coordinator/store.go b/internal/coordinator/store.go index 6678f4d5..e94d8a07 100644 --- a/internal/coordinator/store.go +++ b/internal/coordinator/store.go @@ -299,7 +299,10 @@ func cloneResult(result *sdkprotocol.Result) *sdkprotocol.Result { cloned := *result if result.KeyShare != nil { keyShare := *result.KeyShare + keyShare.ShareBlob = append([]byte(nil), result.KeyShare.ShareBlob...) keyShare.PublicKey = append([]byte(nil), result.KeyShare.PublicKey...) + keyShare.ECDSAPubKey = append([]byte(nil), result.KeyShare.ECDSAPubKey...) + keyShare.EDDSAPubKey = append([]byte(nil), result.KeyShare.EDDSAPubKey...) cloned.KeyShare = &keyShare } if result.Signature != nil { From 18af50310e1eb4076c5eb9a6a1d7c966782c0e9f Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 22 Apr 2026 15:18:26 +0700 Subject: [PATCH 17/23] Add gRPC transport support to coordinator client. Allow the coordinator client and example flows to submit keygen/sign requests and await results over gRPC while keeping NATS transport compatibility and adding dedicated gRPC client tests. Made-with: Cursor --- examples/coordinatorclient-keygen/main.go | 4 +- examples/coordinatorclient-sign/main.go | 22 +-- pkg/coordinatorclient/client.go | 214 ++++++++++++++++++++-- pkg/coordinatorclient/client_test.go | 173 +++++++++++++++++ 4 files changed, 378 insertions(+), 35 deletions(-) create mode 100644 pkg/coordinatorclient/client_test.go diff --git a/examples/coordinatorclient-keygen/main.go b/examples/coordinatorclient-keygen/main.go index 9e5604d4..151a0062 100644 --- a/examples/coordinatorclient-keygen/main.go +++ b/examples/coordinatorclient-keygen/main.go @@ -14,8 +14,8 @@ import ( func main() { client, err := coordinatorclient.New(coordinatorclient.Config{ - NATSURL: "nats://127.0.0.1:4222", - Timeout: 5 * time.Second, + GRPCAddress: "127.0.0.1:50051", + Timeout: 5 * time.Second, }) if err != nil { log.Fatalf("create coordinator client: %v", err) diff --git a/examples/coordinatorclient-sign/main.go b/examples/coordinatorclient-sign/main.go index 9e5d7c62..355195f2 100644 --- a/examples/coordinatorclient-sign/main.go +++ b/examples/coordinatorclient-sign/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "crypto/ed25519" "encoding/hex" "fmt" "log" @@ -14,8 +13,8 @@ import ( func main() { client, err := coordinatorclient.New(coordinatorclient.Config{ - NATSURL: "nats://127.0.0.1:4222", - Timeout: 5 * time.Second, + GRPCAddress: "127.0.0.1:50051", + Timeout: 5 * time.Second, }) if err != nil { log.Fatalf("create coordinator client: %v", err) @@ -33,7 +32,7 @@ func main() { }, } - walletID := "wallet_f8029c22-a222-4828-b135-8aacc021d716" + walletID := "wallet_eb791062-d9b4-4ed0-87a0-793f8f7370d3" message := []byte("deadbeef") protocol := sdkprotocol.ProtocolTypeEdDSA @@ -78,18 +77,3 @@ func mustDecodeHex(value string) []byte { } return decoded } - -func mustPublicKeyFromPrivateHex(privateKeyHex string) []byte { - privateRaw := mustDecodeHex(privateKeyHex) - var private ed25519.PrivateKey - switch len(privateRaw) { - case ed25519.PrivateKeySize: - private = ed25519.PrivateKey(privateRaw) - case ed25519.SeedSize: - private = ed25519.NewKeyFromSeed(privateRaw) - default: - panic(fmt.Sprintf("invalid ed25519 private key length: %d", len(privateRaw))) - } - public := private.Public().(ed25519.PublicKey) - return append([]byte(nil), public...) -} diff --git a/pkg/coordinatorclient/client.go b/pkg/coordinatorclient/client.go index 3bfd233a..7a9b9f6c 100644 --- a/pkg/coordinatorclient/client.go +++ b/pkg/coordinatorclient/client.go @@ -2,12 +2,16 @@ package coordinatorclient import ( "context" + "encoding/hex" "encoding/json" "fmt" "time" + coordinatorv1 "github.com/fystack/mpcium-sdk/integrations/coordinator-grpc/proto/coordinator/v1" sdkprotocol "github.com/fystack/mpcium-sdk/protocol" "github.com/nats-io/nats.go" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" ) const ( @@ -17,15 +21,26 @@ const ( ) type Client struct { - nc *nats.Conn - timeout time.Duration + nc *nats.Conn + grpcConn *grpc.ClientConn + grpcClient coordinatorv1.CoordinatorOrchestrationClient + timeout time.Duration + transport transportType } type Config struct { - NATSURL string - Timeout time.Duration + NATSURL string + GRPCAddress string + Timeout time.Duration } +type transportType string + +const ( + transportNATS transportType = "nats" + transportGRPC transportType = "grpc" +) + type KeygenParticipant struct { ID string IdentityPublicKey []byte @@ -50,12 +65,27 @@ type SignRequest struct { } func New(cfg Config) (*Client, error) { - if cfg.NATSURL == "" { - cfg.NATSURL = nats.DefaultURL - } if cfg.Timeout <= 0 { cfg.Timeout = 5 * time.Second } + if cfg.GRPCAddress != "" { + conn, err := grpc.Dial( + cfg.GRPCAddress, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return nil, fmt.Errorf("connect to gRPC coordinator: %w", err) + } + return &Client{ + grpcConn: conn, + grpcClient: coordinatorv1.NewCoordinatorOrchestrationClient(conn), + timeout: cfg.Timeout, + transport: transportGRPC, + }, nil + } + if cfg.NATSURL == "" { + cfg.NATSURL = nats.DefaultURL + } nc, err := nats.Connect(cfg.NATSURL) if err != nil { @@ -63,19 +93,28 @@ func New(cfg Config) (*Client, error) { } return &Client{ - nc: nc, - timeout: cfg.Timeout, + nc: nc, + timeout: cfg.Timeout, + transport: transportNATS, }, nil } func (c *Client) Close() { - if c == nil || c.nc == nil { + if c == nil { return } - c.nc.Close() + if c.nc != nil { + c.nc.Close() + } + if c.grpcConn != nil { + _ = c.grpcConn.Close() + } } func (c *Client) PublishPresence(ctx context.Context, peerID string) error { + if c.transport != transportNATS { + return fmt.Errorf("presence publishing is supported only in NATS mode") + } if peerID == "" { return fmt.Errorf("peerID is required") } @@ -110,6 +149,10 @@ func (c *Client) RequestKeygen(ctx context.Context, req KeygenRequest) (*sdkprot defer cancel() } + if c.transport == transportGRPC { + return c.requestKeygenGRPC(ctx, req) + } + msg := &sdkprotocol.ControlMessage{ SessionStart: &sdkprotocol.SessionStart{ SessionID: "tmp", // coordinator replaces this value when accepting request @@ -123,7 +166,7 @@ func (c *Client) RequestKeygen(ctx context.Context, req KeygenRequest) (*sdkprot }, } - return c.requestSession(ctx, requestKeygenSubject, msg, "keygen") + return c.requestSessionNATS(ctx, requestKeygenSubject, msg, "keygen") } func (c *Client) RequestSign(ctx context.Context, req SignRequest) (*sdkprotocol.RequestAccepted, error) { @@ -136,6 +179,10 @@ func (c *Client) RequestSign(ctx context.Context, req SignRequest) (*sdkprotocol defer cancel() } + if c.transport == transportGRPC { + return c.requestSignGRPC(ctx, req) + } + msg := &sdkprotocol.ControlMessage{ SessionStart: &sdkprotocol.SessionStart{ SessionID: "tmp", // coordinator replaces this value when accepting request @@ -151,10 +198,10 @@ func (c *Client) RequestSign(ctx context.Context, req SignRequest) (*sdkprotocol }, } - return c.requestSession(ctx, requestSignSubject, msg, "sign") + return c.requestSessionNATS(ctx, requestSignSubject, msg, "sign") } -func (c *Client) requestSession(ctx context.Context, subject string, msg *sdkprotocol.ControlMessage, action string) (*sdkprotocol.RequestAccepted, error) { +func (c *Client) requestSessionNATS(ctx context.Context, subject string, msg *sdkprotocol.ControlMessage, action string) (*sdkprotocol.RequestAccepted, error) { payload, err := json.Marshal(msg) if err != nil { return nil, fmt.Errorf("marshal %s request: %w", action, err) @@ -195,6 +242,10 @@ func (c *Client) WaitSessionResult(ctx context.Context, sessionID string) (*sdkp defer cancel() } + if c.transport == transportGRPC { + return c.waitSessionResultGRPC(ctx, sessionID) + } + subject := fmt.Sprintf("%s.session.%s.result", topicPrefix, sessionID) sub, err := c.nc.SubscribeSync(subject) if err != nil { @@ -218,6 +269,141 @@ func (c *Client) WaitSessionResult(ctx context.Context, sessionID string) (*sdkp return result, nil } +func (c *Client) requestKeygenGRPC(ctx context.Context, req KeygenRequest) (*sdkprotocol.RequestAccepted, error) { + grpcReq := &coordinatorv1.KeygenRequest{ + Protocol: string(req.Protocol), + Threshold: req.Threshold, + WalletId: req.WalletID, + Participants: mapParticipantsToProto(req.Participants), + } + resp, err := c.grpcClient.Keygen(ctx, grpcReq) + if err != nil { + return nil, fmt.Errorf("request keygen: %w", err) + } + if !resp.GetAccepted() { + return nil, fmt.Errorf("coordinator rejected request (%s): %s", resp.GetErrorCode(), resp.GetErrorMessage()) + } + return &sdkprotocol.RequestAccepted{ + Accepted: true, + SessionID: resp.GetSessionId(), + ExpiresAt: resp.GetExpiresAt(), + }, nil +} + +func (c *Client) requestSignGRPC(ctx context.Context, req SignRequest) (*sdkprotocol.RequestAccepted, error) { + grpcReq := &coordinatorv1.SignRequest{ + Protocol: string(req.Protocol), + Threshold: req.Threshold, + WalletId: req.WalletID, + SigningInputHex: hex.EncodeToString(req.SigningInput), + Participants: mapParticipantsToProto(req.Participants), + } + if req.Derivation != nil { + grpcReq.DerivationPath = append([]uint32(nil), req.Derivation.Path...) + grpcReq.DerivationDeltaHex = hex.EncodeToString(req.Derivation.Delta) + } + + resp, err := c.grpcClient.Sign(ctx, grpcReq) + if err != nil { + return nil, fmt.Errorf("request sign: %w", err) + } + if !resp.GetAccepted() { + return nil, fmt.Errorf("coordinator rejected request (%s): %s", resp.GetErrorCode(), resp.GetErrorMessage()) + } + return &sdkprotocol.RequestAccepted{ + Accepted: true, + SessionID: resp.GetSessionId(), + ExpiresAt: resp.GetExpiresAt(), + }, nil +} + +func (c *Client) waitSessionResultGRPC(ctx context.Context, sessionID string) (*sdkprotocol.Result, error) { + resp, err := c.grpcClient.WaitSessionResult(ctx, &coordinatorv1.SessionLookup{SessionId: sessionID}) + if err != nil { + return nil, fmt.Errorf("wait session result: %w", err) + } + if !resp.GetCompleted() { + return nil, fmt.Errorf("session failed (%s): %s", resp.GetErrorCode(), resp.GetErrorMessage()) + } + + if resp.GetSignatureHex() != "" || resp.GetSignatureRecoveryHex() != "" || resp.GetRHex() != "" || resp.GetSHex() != "" { + signature, err := mapProtoSignature(resp) + if err != nil { + return nil, err + } + return &sdkprotocol.Result{Signature: signature}, nil + } + + publicKey, err := decodeHexField("public_key_hex", resp.GetPublicKeyHex()) + if err != nil { + return nil, err + } + return &sdkprotocol.Result{ + KeyShare: &sdkprotocol.KeyShareResult{ + KeyID: resp.GetKeyId(), + PublicKey: publicKey, + }, + }, nil +} + +func mapParticipantsToProto(participants []KeygenParticipant) []*coordinatorv1.Participant { + mapped := make([]*coordinatorv1.Participant, 0, len(participants)) + for _, participant := range participants { + mapped = append(mapped, &coordinatorv1.Participant{ + Id: participant.ID, + IdentityPublicKeyHex: hex.EncodeToString(participant.IdentityPublicKey), + }) + } + return mapped +} + +func mapProtoSignature(resp *coordinatorv1.SessionResult) (*sdkprotocol.SignatureResult, error) { + signature, err := decodeHexField("signature_hex", resp.GetSignatureHex()) + if err != nil { + return nil, err + } + recovery, err := decodeHexField("signature_recovery_hex", resp.GetSignatureRecoveryHex()) + if err != nil { + return nil, err + } + r, err := decodeHexField("r_hex", resp.GetRHex()) + if err != nil { + return nil, err + } + s, err := decodeHexField("s_hex", resp.GetSHex()) + if err != nil { + return nil, err + } + signedInput, err := decodeHexField("signed_input_hex", resp.GetSignedInputHex()) + if err != nil { + return nil, err + } + publicKey, err := decodeHexField("public_key_hex", resp.GetPublicKeyHex()) + if err != nil { + return nil, err + } + return &sdkprotocol.SignatureResult{ + KeyID: resp.GetKeyId(), + Signature: signature, + SignatureRecovery: recovery, + R: r, + S: s, + SignedInput: signedInput, + PublicKey: publicKey, + }, nil +} + +func decodeHexField(name, value string) ([]byte, error) { + if value == "" { + return nil, nil + } + decoded, err := hex.DecodeString(value) + if err != nil { + return nil, fmt.Errorf("decode %s: %w", name, err) + } + return decoded, nil +} + func validateKeygenRequest(req KeygenRequest) error { if req.WalletID == "" { return fmt.Errorf("walletID is required") diff --git a/pkg/coordinatorclient/client_test.go b/pkg/coordinatorclient/client_test.go new file mode 100644 index 00000000..86b3016a --- /dev/null +++ b/pkg/coordinatorclient/client_test.go @@ -0,0 +1,173 @@ +package coordinatorclient + +import ( + "context" + "encoding/hex" + "net" + "strings" + "testing" + "time" + + coordinatorv1 "github.com/fystack/mpcium-sdk/integrations/coordinator-grpc/proto/coordinator/v1" + sdkprotocol "github.com/fystack/mpcium-sdk/protocol" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" +) + +const bufSize = 1024 * 1024 + +type fakeCoordinatorServer struct { + coordinatorv1.UnimplementedCoordinatorOrchestrationServer + keygenResp *coordinatorv1.RequestAccepted + signResp *coordinatorv1.RequestAccepted + results map[string]*coordinatorv1.SessionResult +} + +func (s *fakeCoordinatorServer) Keygen(context.Context, *coordinatorv1.KeygenRequest) (*coordinatorv1.RequestAccepted, error) { + if s.keygenResp != nil { + return s.keygenResp, nil + } + return &coordinatorv1.RequestAccepted{ + Accepted: true, + SessionId: "sess_keygen", + ExpiresAt: "2026-04-22T10:00:00Z", + }, nil +} + +func (s *fakeCoordinatorServer) Sign(context.Context, *coordinatorv1.SignRequest) (*coordinatorv1.RequestAccepted, error) { + if s.signResp != nil { + return s.signResp, nil + } + return &coordinatorv1.RequestAccepted{ + Accepted: true, + SessionId: "sess_sign", + ExpiresAt: "2026-04-22T10:00:00Z", + }, nil +} + +func (s *fakeCoordinatorServer) WaitSessionResult(_ context.Context, req *coordinatorv1.SessionLookup) (*coordinatorv1.SessionResult, error) { + return s.results[req.GetSessionId()], nil +} + +func newTestGRPCClient(t *testing.T, fake *fakeCoordinatorServer) (*Client, func()) { + t.Helper() + listener := bufconn.Listen(bufSize) + server := grpc.NewServer() + coordinatorv1.RegisterCoordinatorOrchestrationServer(server, fake) + go func() { + _ = server.Serve(listener) + }() + + conn, err := grpc.DialContext( + context.Background(), + "bufnet", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatal(err) + } + client := &Client{ + grpcConn: conn, + grpcClient: coordinatorv1.NewCoordinatorOrchestrationClient(conn), + timeout: time.Second, + transport: transportGRPC, + } + cleanup := func() { + client.Close() + server.Stop() + _ = listener.Close() + } + return client, cleanup +} + +func TestGRPCClientRequestKeygenAndSignResponses(t *testing.T) { + client, cleanup := newTestGRPCClient(t, &fakeCoordinatorServer{ + signResp: &coordinatorv1.RequestAccepted{ + Accepted: false, + ErrorCode: "validation", + ErrorMessage: "protocol is required", + }, + }) + defer cleanup() + + accepted, err := client.RequestKeygen(context.Background(), KeygenRequest{ + Protocol: sdkprotocol.ProtocolTypeECDSA, + Threshold: 1, + WalletID: "wallet-1", + Participants: []KeygenParticipant{ + {ID: "p1", IdentityPublicKey: []byte("pub-1")}, + {ID: "p2", IdentityPublicKey: []byte("pub-2")}, + }, + }) + if err != nil { + t.Fatal(err) + } + if !accepted.Accepted || accepted.SessionID != "sess_keygen" || accepted.ExpiresAt == "" { + t.Fatalf("unexpected accepted response: %+v", accepted) + } + + _, err = client.RequestSign(context.Background(), SignRequest{ + Protocol: sdkprotocol.ProtocolTypeECDSA, + Threshold: 1, + WalletID: "wallet-1", + SigningInput: []byte("message"), + Participants: []SignParticipant{ + {ID: "p1", IdentityPublicKey: []byte("pub-1")}, + {ID: "p2", IdentityPublicKey: []byte("pub-2")}, + }, + }) + if err == nil || !strings.Contains(err.Error(), "coordinator rejected request (validation): protocol is required") { + t.Fatalf("unexpected sign error: %v", err) + } +} + +func TestGRPCClientWaitSessionResultMapsKeygenAndSignature(t *testing.T) { + signature := []byte("signature") + recovery := []byte("recovery") + r := []byte("r") + s := []byte("s") + signedInput := []byte("message") + publicKey := []byte("public-key") + client, cleanup := newTestGRPCClient(t, &fakeCoordinatorServer{ + results: map[string]*coordinatorv1.SessionResult{ + "sess_keygen": { + Completed: true, + SessionId: "sess_keygen", + KeyId: "wallet-1", + PublicKeyHex: hex.EncodeToString(publicKey), + }, + "sess_sign": { + Completed: true, + SessionId: "sess_sign", + KeyId: "wallet-1", + PublicKeyHex: hex.EncodeToString(publicKey), + SignatureHex: hex.EncodeToString(signature), + SignatureRecoveryHex: hex.EncodeToString(recovery), + RHex: hex.EncodeToString(r), + SHex: hex.EncodeToString(s), + SignedInputHex: hex.EncodeToString(signedInput), + }, + }, + }) + defer cleanup() + + keygenResult, err := client.WaitSessionResult(context.Background(), "sess_keygen") + if err != nil { + t.Fatal(err) + } + if keygenResult.KeyShare == nil || keygenResult.KeyShare.KeyID != "wallet-1" || string(keygenResult.KeyShare.PublicKey) != string(publicKey) { + t.Fatalf("unexpected keygen result: %+v", keygenResult) + } + + signResult, err := client.WaitSessionResult(context.Background(), "sess_sign") + if err != nil { + t.Fatal(err) + } + if signResult.Signature == nil || string(signResult.Signature.Signature) != string(signature) || string(signResult.Signature.PublicKey) != string(publicKey) { + t.Fatalf("unexpected sign result: %+v", signResult) + } +} From e630eeb7b5fea75a852187fd112ac54395176bff Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 22 Apr 2026 15:18:33 +0700 Subject: [PATCH 18/23] Update coordinator docs for mixed gRPC and NATS flow. Clarify that gRPC is used for client orchestration while participant transport, control fan-out, and session/result messaging remain on NATS and relay paths. Made-with: Cursor --- cmd/mpcium-coordinator/README.md | 23 +++++++++-------------- docs/local-coordinator-relay-cosigners.md | 4 +++- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/cmd/mpcium-coordinator/README.md b/cmd/mpcium-coordinator/README.md index fd3b3e73..a036e220 100644 --- a/cmd/mpcium-coordinator/README.md +++ b/cmd/mpcium-coordinator/README.md @@ -5,33 +5,28 @@ This runtime implements the v1 control-plane coordinator from `docs/architecture It owns: - NATS request-reply intake on `mpc.v1.request.keygen`, `mpc.v1.request.sign`, and `mpc.v1.request.reshare` +- optional plaintext gRPC client orchestration API for `Keygen`, `Sign`, and `WaitSessionResult` - pinned participant validation - session lifecycle state - signed control fan-out to `mpc.v1.peer..control` - participant event intake from `mpc.v1.session..event` - terminal result publishing to `mpc.v1.session..result` -It does not implement relay, MQTT mailboxing, p2p MPC packet routing, or legacy `mpc.*` subjects. +It does not implement relay, MQTT mailboxing, p2p MPC packet routing, gRPC participant transport, or legacy `mpc.*` subjects. NATS is still required for cosigner presence, control fan-out, participant session events, and result publishing. ## Run ```sh -go run ./cmd/mpcium-coordinator \ - --nats-url nats://127.0.0.1:4222 \ - --coordinator-id coordinator-01 \ - --coordinator-private-key-hex \ - --snapshot-dir ./coordinator-snapshots \ - --relay-available=true +go run ./cmd/mpcium-coordinator/main.go -c coordinator.config.yaml ``` -The same settings can be provided through environment variables: +The runtime config includes: -- `NATS_URL` -- `COORDINATOR_ID` -- `COORDINATOR_PRIVATE_KEY_HEX` -- `COORDINATOR_SNAPSHOT_DIR` -- `COORDINATOR_RELAY_AVAILABLE` -- `COORDINATOR_TICK_INTERVAL` +- `nats.url`: NATS server used for participant transport. +- `grpc.enabled`: enables the client orchestration API. +- `grpc.listen_addr`: plaintext gRPC listen address. +- `grpc.poll_interval`: result wait polling interval. +- `coordinator.id`, `coordinator.private_key_hex`, and `coordinator.snapshot_dir`. Each operation has its own request shape. The operation comes from the NATS subject, so a sign request to `mpc.v1.request.sign` looks like: diff --git a/docs/local-coordinator-relay-cosigners.md b/docs/local-coordinator-relay-cosigners.md index 69ac1ab4..e7c1bbf6 100644 --- a/docs/local-coordinator-relay-cosigners.md +++ b/docs/local-coordinator-relay-cosigners.md @@ -133,7 +133,7 @@ Open one terminal per process. go run ./cmd/mpcium-coordinator/main.go -c coordinator.config.yaml ``` -Expected logs include coordinator request, presence, and session event subscriptions. +Expected logs include coordinator request, presence, and session event subscriptions. If `grpc.enabled` is true, the coordinator also listens on `grpc.listen_addr` for the plaintext client orchestration API. NATS is still required for cosigner control messages, presence, session events, and result publishing. ### 2. Relay @@ -197,6 +197,8 @@ After both cosigners are online, run: go run ./examples/coordinatorclient-keygen ``` +The example submits `Keygen` over gRPC and waits for the terminal result over gRPC. The actual participant session still flows through NATS/relay transport. + Expected output: ```txt From 2373530151a8bba8873e6cfcc2228e146e569aba Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 22 Apr 2026 15:23:10 +0700 Subject: [PATCH 19/23] Add ECDSA and EdDSA public key handling in gRPC session results Enhance the sessionToProtoResult function to include ECDSA and EdDSA public keys in the gRPC response. Update the coordinator client to decode and return these keys in the session result structure, improving the key sharing capabilities of the client. --- internal/coordinator/orchestration_grpc.go | 2 ++ pkg/coordinatorclient/client.go | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/internal/coordinator/orchestration_grpc.go b/internal/coordinator/orchestration_grpc.go index 487e5b21..a5c1278f 100644 --- a/internal/coordinator/orchestration_grpc.go +++ b/internal/coordinator/orchestration_grpc.go @@ -180,6 +180,8 @@ func sessionToProtoResult(session *Session) *coordinatorv1.SessionResult { if session.Result.KeyShare != nil { result.KeyId = session.Result.KeyShare.KeyID result.PublicKeyHex = hex.EncodeToString(session.Result.KeyShare.PublicKey) + result.EcdsaPubkey = hex.EncodeToString(session.Result.KeyShare.ECDSAPubKey) + result.EddsaPubkey = hex.EncodeToString(session.Result.KeyShare.EDDSAPubKey) } if session.Result.Signature != nil { sig := session.Result.Signature diff --git a/pkg/coordinatorclient/client.go b/pkg/coordinatorclient/client.go index 7a9b9f6c..21122b54 100644 --- a/pkg/coordinatorclient/client.go +++ b/pkg/coordinatorclient/client.go @@ -338,10 +338,20 @@ func (c *Client) waitSessionResultGRPC(ctx context.Context, sessionID string) (* if err != nil { return nil, err } + ecdsaPubKey, err := decodeHexField("ecdsa_pubkey", resp.GetEcdsaPubkey()) + if err != nil { + return nil, err + } + eddsaPubKey, err := decodeHexField("eddsa_pubkey", resp.GetEddsaPubkey()) + if err != nil { + return nil, err + } return &sdkprotocol.Result{ KeyShare: &sdkprotocol.KeyShareResult{ - KeyID: resp.GetKeyId(), - PublicKey: publicKey, + KeyID: resp.GetKeyId(), + PublicKey: publicKey, + ECDSAPubKey: ecdsaPubKey, + EDDSAPubKey: eddsaPubKey, }, }, nil } From 636db7167bc93fef3d0fd9bccdf5662823bfab2d Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 22 Apr 2026 16:55:09 +0700 Subject: [PATCH 20/23] Improve dual-keygen orchestration with seeded key reuse. Enhance coordinator dual-protocol keygen flow to reuse existing wallet protocol keys, publish a stable aggregate result, and expand integration-style tests for gRPC and NATS request/result behavior. Made-with: Cursor --- go.mod | 4 + go.sum | 9 + internal/coordinator/coordinator.go | 84 +++++- internal/coordinator/coordinator_test.go | 359 +++++++++++++++++++++++ 4 files changed, 450 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 547ab5e8..3a03dbba 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/hashicorp/consul/api v1.33.2 github.com/mitchellh/mapstructure v1.5.0 github.com/mochi-mqtt/server/v2 v2.7.9 + github.com/nats-io/nats-server/v2 v2.10.29 github.com/nats-io/nats.go v1.48.0 github.com/rs/zerolog v1.34.0 github.com/samber/lo v1.52.0 @@ -32,8 +33,11 @@ require ( require ( github.com/gorilla/websocket v1.5.3 // indirect + github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 // indirect + github.com/nats-io/jwt/v2 v2.8.1 // indirect github.com/rs/xid v1.6.0 // indirect golang.org/x/sync v0.20.0 // indirect + golang.org/x/time v0.15.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260120221211-b8f7ae30c516 // indirect ) diff --git a/go.sum b/go.sum index 30dfee1b..d543a2e5 100644 --- a/go.sum +++ b/go.sum @@ -257,6 +257,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/miekg/dns v1.1.56 h1:5imZaSeoRNvpM9SzWNhEcP9QliKiz20/dA2QabIGVnE= github.com/miekg/dns v1.1.56/go.mod h1:cRm6Oo2C8TY9ZS/TqsSrseAcncm74lfK5G+ikN2SWWY= +github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 h1:KGuD/pM2JpL9FAYvBrnBBeENKZNh6eNtjqytV6TYjnk= +github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76/go.mod h1:GGYsuwP/fPD6Y9hMiXuapVvlIUEhFhMTh0rxU3ik1LQ= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= @@ -269,6 +271,10 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/nats-io/jwt/v2 v2.8.1 h1:V0xpGuD/N8Mi+fQNDynXohVvp7ZztevW5io8CUWlPmU= +github.com/nats-io/jwt/v2 v2.8.1/go.mod h1:nWnOEEiVMiKHQpnAy4eXlizVEtSfzacZ1Q43LIRavZg= +github.com/nats-io/nats-server/v2 v2.10.29 h1:IJ8TrZaiMZUrPGavMvP7hNAE9lYnHTThuthpwlsdlbc= +github.com/nats-io/nats-server/v2 v2.10.29/go.mod h1:VhRCs7C6pF/6FanJcOdr1R6jDb7yMBK3I630WN62FDw= github.com/nats-io/nats.go v1.48.0 h1:pSFyXApG+yWU/TgbKCjmm5K4wrHu86231/w84qRVR+U= github.com/nats-io/nats.go v1.48.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g= github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4= @@ -487,6 +493,7 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -505,6 +512,8 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= diff --git a/internal/coordinator/coordinator.go b/internal/coordinator/coordinator.go index f055992e..ccb39eee 100644 --- a/internal/coordinator/coordinator.go +++ b/internal/coordinator/coordinator.go @@ -41,6 +41,11 @@ type dualKeygenCompletion struct { sessionIDs []string } +type dualKeygenPlan struct { + protocols []sdkprotocol.ProtocolType + seeded map[sdkprotocol.ProtocolType][]byte +} + func NewCoordinator(cfg CoordinatorConfig) (*Coordinator, error) { cfg = applyDefaults(cfg) if err := cfg.Validate(); err != nil { @@ -72,19 +77,25 @@ func (c *Coordinator) HandleRequest(ctx context.Context, op Operation, raw []byt } // Backward compatibility: keygen without protocol means dispatch both ECDSA and EdDSA sessions. if op == OperationKeygen && req.SessionStart != nil && isBothKeygenProtocol(req.SessionStart.Protocol) { - protocols := []sdkprotocol.ProtocolType{sdkprotocol.ProtocolTypeECDSA, sdkprotocol.ProtocolTypeEdDSA} - sessionIDs := make([]string, 0, len(protocols)) + plan, err := c.planDualKeygen(req.SessionStart) + if err != nil { + return rejectFromError(err), nil + } + sessionIDs := make([]string, 0, len(plan.protocols)) var firstAccepted *sdkprotocol.RequestAccepted var firstErr error - for _, protocol := range protocols { + for _, protocol := range plan.protocols { cloned := cloneSessionStart(req.SessionStart) cloned.Protocol = protocol accepted, err := c.acceptRequest(ctx, op, &sdkprotocol.ControlMessage{SessionStart: cloned}) if err != nil { var coordErr *CoordinatorError if AsCoordinatorError(err, &coordErr) && coordErr.Code == ErrorCodeConflict { - // Allow fanout to continue: one protocol might already exist while the other doesn't. + if seed, seedErr := c.existingDualKeygenSeed(req.SessionStart, protocol); seedErr == nil && len(seed) > 0 { + plan.seeded[protocol] = seed + continue + } if firstErr == nil { firstErr = err } @@ -103,7 +114,7 @@ func (c *Coordinator) HandleRequest(ctx context.Context, op Operation, raw []byt } return reject(ErrorCodeConflict, "no keygen sessions created"), nil } - c.registerDualKeygen(sessionIDs) + c.registerDualKeygen(sessionIDs, plan.seeded) logger.Info("coordinator expanded keygen request to both protocols", "operation", string(op), @@ -119,6 +130,52 @@ func (c *Coordinator) HandleRequest(ctx context.Context, op Operation, raw []byt return json.Marshal(accepted) } +func (c *Coordinator) planDualKeygen(start *sdkprotocol.SessionStart) (*dualKeygenPlan, error) { + protocols := []sdkprotocol.ProtocolType{sdkprotocol.ProtocolTypeECDSA, sdkprotocol.ProtocolTypeEdDSA} + plan := &dualKeygenPlan{ + protocols: make([]sdkprotocol.ProtocolType, 0, len(protocols)), + seeded: make(map[sdkprotocol.ProtocolType][]byte, len(protocols)), + } + if c.keyInfoStore == nil { + plan.protocols = append(plan.protocols, protocols...) + return plan, nil + } + + for _, protocol := range protocols { + seed, err := c.existingDualKeygenSeed(start, protocol) + if err != nil { + return nil, err + } + if len(seed) > 0 { + plan.seeded[protocol] = seed + continue + } + plan.protocols = append(plan.protocols, protocol) + } + if len(plan.protocols) == 0 { + return nil, newCoordinatorError(ErrorCodeConflict, "wallet keys already exist") + } + return plan, nil +} + +func (c *Coordinator) existingDualKeygenSeed(start *sdkprotocol.SessionStart, protocol sdkprotocol.ProtocolType) ([]byte, error) { + if c.keyInfoStore == nil { + return nil, nil + } + walletID := keygenWalletID(start) + if walletID == "" { + return nil, newCoordinatorError(ErrorCodeValidation, "wallet_id is required") + } + info, exists := c.keyInfoStore.Get(walletID, string(protocol)) + if !exists { + return nil, nil + } + if len(info.PublicKey) == 0 { + return nil, newCoordinatorError(ErrorCodeConflict, "wallet key already exists without public key") + } + return append([]byte(nil), info.PublicKey...), nil +} + func (c *Coordinator) acceptRequest(ctx context.Context, op Operation, req *sdkprotocol.ControlMessage) (*sdkprotocol.RequestAccepted, error) { if err := c.validateRequest(ctx, op, req); err != nil { return nil, err @@ -247,6 +304,14 @@ func (c *Coordinator) HandleSessionEvent(ctx context.Context, raw []byte) error if err := c.advance(ctx, session, &event); err != nil { return err } + if latest, ok := c.store.Get(ctx, session.ID); ok && latest.State.Terminal() { + logger.Debug("coordinator processed terminal session event", + "session_id", latest.ID, + "participant_id", event.ParticipantID, + "state", string(latest.State), + ) + return nil + } logger.Debug("coordinator processed session event", "session_id", session.ID, "participant_id", event.ParticipantID, @@ -440,7 +505,7 @@ func (c *Coordinator) persistKeyInfoIfNeeded(session *Session, result *sdkprotoc return nil } -func (c *Coordinator) registerDualKeygen(sessionIDs []string) { +func (c *Coordinator) registerDualKeygen(sessionIDs []string, seeded map[sdkprotocol.ProtocolType][]byte) { if c == nil || len(sessionIDs) == 0 { return } @@ -448,6 +513,12 @@ func (c *Coordinator) registerDualKeygen(sessionIDs []string) { sessionIDs: make(map[sdkprotocol.ProtocolType]string, len(sessionIDs)), results: make(map[sdkprotocol.ProtocolType][]byte, len(sessionIDs)), } + for protocol, publicKey := range seeded { + if len(publicKey) == 0 { + continue + } + group.results[protocol] = append([]byte(nil), publicKey...) + } for _, sessionID := range sessionIDs { session, ok := c.store.Get(context.Background(), sessionID) if !ok || session == nil || session.Start == nil { @@ -485,6 +556,7 @@ func (c *Coordinator) recordDualKeygenResult(session *Session, result *sdkprotoc aggregate := &sdkprotocol.Result{ KeyShare: &sdkprotocol.KeyShareResult{ KeyID: walletID, + PublicKey: append([]byte(nil), ecdsaPubKey...), ECDSAPubKey: append([]byte(nil), ecdsaPubKey...), EDDSAPubKey: append([]byte(nil), eddsaPubKey...), }, diff --git a/internal/coordinator/coordinator_test.go b/internal/coordinator/coordinator_test.go index 8d6fd06b..5a046563 100644 --- a/internal/coordinator/coordinator_test.go +++ b/internal/coordinator/coordinator_test.go @@ -4,11 +4,15 @@ import ( "context" "crypto/ed25519" "crypto/rand" + "encoding/hex" "encoding/json" "testing" "time" + coordinatorv1 "github.com/fystack/mpcium-sdk/integrations/coordinator-grpc/proto/coordinator/v1" sdkprotocol "github.com/fystack/mpcium-sdk/protocol" + natsserver "github.com/nats-io/nats-server/v2/server" + "github.com/nats-io/nats.go" ) type fakeSigner struct{} @@ -311,6 +315,15 @@ func TestHandleRequestKeygenWithoutProtocolCreatesBothSessions(t *testing.T) { if !seenProtocols[sdkprotocol.ProtocolTypeECDSA] || !seenProtocols[sdkprotocol.ProtocolTypeEdDSA] { t.Fatalf("expected both ECDSA and EdDSA sessions, got %+v", seenProtocols) } + + sessionsByProtocol := map[sdkprotocol.ProtocolType]string{} + for _, session := range active { + sessionsByProtocol[session.Start.Protocol] = session.ID + } + completeKeygenSession(t, coord, fixtures, sessionsByProtocol[sdkprotocol.ProtocolTypeECDSA], "wallet_dual_protocol", []byte("ecdsa-pub")) + completeKeygenSession(t, coord, fixtures, sessionsByProtocol[sdkprotocol.ProtocolTypeEdDSA], "wallet_dual_protocol", []byte("eddsa-pub")) + + assertDualKeygenResult(t, coord, accepted.SessionID, "wallet_dual_protocol", []byte("ecdsa-pub"), []byte("eddsa-pub")) } func TestHandleRequestKeygenBothCreatesBothSessions(t *testing.T) { @@ -348,6 +361,15 @@ func TestHandleRequestKeygenBothCreatesBothSessions(t *testing.T) { if len(active) != 2 { t.Fatalf("expected 2 active sessions, got %d", len(active)) } + + sessionsByProtocol := map[sdkprotocol.ProtocolType]string{} + for _, session := range active { + sessionsByProtocol[session.Start.Protocol] = session.ID + } + completeKeygenSession(t, coord, fixtures, sessionsByProtocol[sdkprotocol.ProtocolTypeECDSA], "wallet_explicit_both", []byte("ecdsa-pub")) + completeKeygenSession(t, coord, fixtures, sessionsByProtocol[sdkprotocol.ProtocolTypeEdDSA], "wallet_explicit_both", []byte("eddsa-pub")) + + assertDualKeygenResult(t, coord, accepted.SessionID, "wallet_explicit_both", []byte("ecdsa-pub"), []byte("eddsa-pub")) } func TestKeygenBothPublishesAggregatedPubKeys(t *testing.T) { @@ -393,12 +415,296 @@ func TestKeygenBothPublishesAggregatedPubKeys(t *testing.T) { if published == nil || published.KeyShare == nil { t.Fatalf("missing published dual keygen result") } + if string(published.KeyShare.PublicKey) != "ecdsa-pub" { + t.Fatalf("public_key = %q", string(published.KeyShare.PublicKey)) + } if string(published.KeyShare.ECDSAPubKey) != "ecdsa-pub" { t.Fatalf("ecdsa_pubkey = %q", string(published.KeyShare.ECDSAPubKey)) } if string(published.KeyShare.EDDSAPubKey) != "eddsa-pub" { t.Fatalf("eddsa_pubkey = %q", string(published.KeyShare.EDDSAPubKey)) } + session, ok := coord.store.Get(ctx, accepted.SessionID) + if !ok { + t.Fatalf("missing accepted session") + } + grpcResult := sessionToProtoResult(session) + if grpcResult.GetPublicKeyHex() != hex.EncodeToString([]byte("ecdsa-pub")) { + t.Fatalf("grpc public_key_hex = %q", grpcResult.GetPublicKeyHex()) + } +} + +func TestHandleRequestKeygenBothWithExistingProtocolSeedsAggregatedResult(t *testing.T) { + ctx := context.Background() + coord, _, _, fixtures := newTestCoordinator(t) + markOnline(t, coord.presence, fixtures["p1"].pub, "p1") + markOnline(t, coord.presence, fixtures["p2"].pub, "p2") + coord.keyInfoStore.Save(KeyInfo{ + WalletID: "wallet_seeded_both", + KeyType: string(sdkprotocol.ProtocolTypeECDSA), + PublicKey: []byte("existing-ecdsa-pub"), + }) + + req := &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "client-supplied", + Protocol: sdkprotocol.ProtocolType("both"), + Operation: sdkprotocol.OperationTypeKeygen, + Threshold: 1, + Participants: []*sdkprotocol.SessionParticipant{ + {ParticipantID: "p1", PartyKey: []byte("p1"), IdentityPublicKey: fixtures["p1"].pub}, + {ParticipantID: "p2", PartyKey: []byte("p2"), IdentityPublicKey: fixtures["p2"].pub}, + }, + Keygen: &sdkprotocol.KeygenPayload{KeyID: "wallet_seeded_both"}, + }, + } + + rawReply, err := coord.HandleRequest(ctx, OperationKeygen, mustJSON(t, req)) + if err != nil { + t.Fatal(err) + } + var accepted sdkprotocol.RequestAccepted + if err := json.Unmarshal(rawReply, &accepted); err != nil { + t.Fatal(err) + } + if !accepted.Accepted { + t.Fatalf("expected request accepted") + } + + active := coord.store.ListActive(ctx) + if len(active) != 1 { + t.Fatalf("expected 1 active session, got %d", len(active)) + } + if active[0].Start.Protocol != sdkprotocol.ProtocolTypeEdDSA { + t.Fatalf("expected missing EdDSA session, got %s", active[0].Start.Protocol) + } + completeKeygenSession(t, coord, fixtures, active[0].ID, "wallet_seeded_both", []byte("eddsa-pub")) + + assertDualKeygenResult(t, coord, accepted.SessionID, "wallet_seeded_both", []byte("existing-ecdsa-pub"), []byte("eddsa-pub")) +} + +func TestHandleRequestKeygenBothRejectsWhenBothProtocolsExist(t *testing.T) { + ctx := context.Background() + coord, _, _, fixtures := newTestCoordinator(t) + markOnline(t, coord.presence, fixtures["p1"].pub, "p1") + markOnline(t, coord.presence, fixtures["p2"].pub, "p2") + coord.keyInfoStore.Save(KeyInfo{WalletID: "wallet_existing_both", KeyType: string(sdkprotocol.ProtocolTypeECDSA), PublicKey: []byte("ecdsa-pub")}) + coord.keyInfoStore.Save(KeyInfo{WalletID: "wallet_existing_both", KeyType: string(sdkprotocol.ProtocolTypeEdDSA), PublicKey: []byte("eddsa-pub")}) + + req := &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "client-supplied", + Protocol: sdkprotocol.ProtocolType("both"), + Operation: sdkprotocol.OperationTypeKeygen, + Threshold: 1, + Participants: []*sdkprotocol.SessionParticipant{ + {ParticipantID: "p1", PartyKey: []byte("p1"), IdentityPublicKey: fixtures["p1"].pub}, + {ParticipantID: "p2", PartyKey: []byte("p2"), IdentityPublicKey: fixtures["p2"].pub}, + }, + Keygen: &sdkprotocol.KeygenPayload{KeyID: "wallet_existing_both"}, + }, + } + + rawReply, err := coord.HandleRequest(ctx, OperationKeygen, mustJSON(t, req)) + if err != nil { + t.Fatal(err) + } + var rejected sdkprotocol.RequestRejected + if err := json.Unmarshal(rawReply, &rejected); err != nil { + t.Fatal(err) + } + if rejected.Accepted { + t.Fatalf("expected request rejected") + } + if rejected.ErrorCode != ErrorCodeConflict { + t.Fatalf("error code = %s, want %s", rejected.ErrorCode, ErrorCodeConflict) + } +} + +func TestHandleRequestKeygenBothRejectsExistingProtocolWithoutPublicKey(t *testing.T) { + ctx := context.Background() + coord, _, _, fixtures := newTestCoordinator(t) + markOnline(t, coord.presence, fixtures["p1"].pub, "p1") + markOnline(t, coord.presence, fixtures["p2"].pub, "p2") + coord.keyInfoStore.Save(KeyInfo{WalletID: "wallet_empty_existing", KeyType: string(sdkprotocol.ProtocolTypeECDSA)}) + + req := &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "client-supplied", + Protocol: sdkprotocol.ProtocolType("both"), + Operation: sdkprotocol.OperationTypeKeygen, + Threshold: 1, + Participants: []*sdkprotocol.SessionParticipant{ + {ParticipantID: "p1", PartyKey: []byte("p1"), IdentityPublicKey: fixtures["p1"].pub}, + {ParticipantID: "p2", PartyKey: []byte("p2"), IdentityPublicKey: fixtures["p2"].pub}, + }, + Keygen: &sdkprotocol.KeygenPayload{KeyID: "wallet_empty_existing"}, + }, + } + + rawReply, err := coord.HandleRequest(ctx, OperationKeygen, mustJSON(t, req)) + if err != nil { + t.Fatal(err) + } + var rejected sdkprotocol.RequestRejected + if err := json.Unmarshal(rawReply, &rejected); err != nil { + t.Fatal(err) + } + if rejected.Accepted { + t.Fatalf("expected request rejected") + } + if rejected.ErrorCode != ErrorCodeConflict { + t.Fatalf("error code = %s, want %s", rejected.ErrorCode, ErrorCodeConflict) + } + if len(coord.store.ListActive(ctx)) != 0 { + t.Fatalf("expected no sessions to be created") + } +} + +func TestGRPCKeygenEmptyProtocolReturnsAggregatedResult(t *testing.T) { + testGRPCDualKeygenReturnsAggregatedResult(t, "") +} + +func TestGRPCKeygenBothProtocolReturnsAggregatedResult(t *testing.T) { + testGRPCDualKeygenReturnsAggregatedResult(t, "both") +} + +func testGRPCDualKeygenReturnsAggregatedResult(t *testing.T, protocol string) { + t.Helper() + ctx := context.Background() + coord, _, _, fixtures := newTestCoordinator(t) + markOnline(t, coord.presence, fixtures["p1"].pub, "p1") + markOnline(t, coord.presence, fixtures["p2"].pub, "p2") + server := NewOrchestrationGRPCServer(coord, time.Millisecond) + + accepted, err := server.Keygen(ctx, &coordinatorv1.KeygenRequest{ + Protocol: protocol, + Threshold: 1, + WalletId: "wallet_grpc_dual_" + protocol, + Participants: grpcParticipants(fixtures), + }) + if err != nil { + t.Fatal(err) + } + if !accepted.GetAccepted() || accepted.GetSessionId() == "" { + t.Fatalf("unexpected accepted response: %+v", accepted) + } + + sessionsByProtocol := map[sdkprotocol.ProtocolType]string{} + for _, session := range coord.store.ListActive(ctx) { + sessionsByProtocol[session.Start.Protocol] = session.ID + } + completeKeygenSession(t, coord, fixtures, sessionsByProtocol[sdkprotocol.ProtocolTypeECDSA], "wallet_grpc_dual_"+protocol, []byte("grpc-ecdsa-pub")) + completeKeygenSession(t, coord, fixtures, sessionsByProtocol[sdkprotocol.ProtocolTypeEdDSA], "wallet_grpc_dual_"+protocol, []byte("grpc-eddsa-pub")) + + resultCtx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + result, err := server.WaitSessionResult(resultCtx, &coordinatorv1.SessionLookup{SessionId: accepted.GetSessionId()}) + if err != nil { + t.Fatal(err) + } + if !result.GetCompleted() { + t.Fatalf("expected completed result: %+v", result) + } + if result.GetPublicKeyHex() != hex.EncodeToString([]byte("grpc-ecdsa-pub")) { + t.Fatalf("public_key_hex = %q", result.GetPublicKeyHex()) + } + if result.GetEcdsaPubkey() != hex.EncodeToString([]byte("grpc-ecdsa-pub")) { + t.Fatalf("ecdsa_pubkey = %q", result.GetEcdsaPubkey()) + } + if result.GetEddsaPubkey() != hex.EncodeToString([]byte("grpc-eddsa-pub")) { + t.Fatalf("eddsa_pubkey = %q", result.GetEddsaPubkey()) + } +} + +func TestNATSRuntimeKeygenEmptyAndBothProtocolsPublishAggregatedResult(t *testing.T) { + for _, protocol := range []sdkprotocol.ProtocolType{"", sdkprotocol.ProtocolType("both")} { + t.Run("protocol_"+string(protocol), func(t *testing.T) { + ctx := context.Background() + coord, _, _, fixtures := newTestCoordinator(t) + markOnline(t, coord.presence, fixtures["p1"].pub, "p1") + markOnline(t, coord.presence, fixtures["p2"].pub, "p2") + + natsServer := startTestNATSServer(t) + nc, err := nats.Connect(natsServer.ClientURL()) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + defer natsServer.Shutdown() + + coord.controls = NewNATSControlPublisher(nc) + coord.results = NewNATSResultPublisher(nc) + runtime := NewNATSRuntime(nc, coord, coord.presence) + if err := runtime.Start(ctx); err != nil { + t.Fatal(err) + } + defer runtime.Stop() + + req := &sdkprotocol.ControlMessage{ + SessionStart: &sdkprotocol.SessionStart{ + SessionID: "client-supplied", + Protocol: protocol, + Operation: sdkprotocol.OperationTypeKeygen, + Threshold: 1, + Participants: []*sdkprotocol.SessionParticipant{ + {ParticipantID: "p1", PartyKey: []byte("p1"), IdentityPublicKey: fixtures["p1"].pub}, + {ParticipantID: "p2", PartyKey: []byte("p2"), IdentityPublicKey: fixtures["p2"].pub}, + }, + Keygen: &sdkprotocol.KeygenPayload{KeyID: "wallet_nats_dual_" + string(protocol)}, + }, + } + + replyMsg, err := nc.RequestWithContext(ctx, RequestSubject(OperationKeygen), mustJSON(t, req)) + if err != nil { + t.Fatal(err) + } + var accepted sdkprotocol.RequestAccepted + if err := json.Unmarshal(replyMsg.Data, &accepted); err != nil { + t.Fatal(err) + } + if !accepted.Accepted || accepted.SessionID == "" { + t.Fatalf("unexpected accepted response: %+v", accepted) + } + + resultSub, err := nc.SubscribeSync(SessionResultSubject(accepted.SessionID)) + if err != nil { + t.Fatal(err) + } + defer resultSub.Unsubscribe() + if err := nc.Flush(); err != nil { + t.Fatal(err) + } + + sessionsByProtocol := map[sdkprotocol.ProtocolType]string{} + for _, session := range coord.store.ListActive(ctx) { + sessionsByProtocol[session.Start.Protocol] = session.ID + } + completeKeygenSession(t, coord, fixtures, sessionsByProtocol[sdkprotocol.ProtocolTypeECDSA], "wallet_nats_dual_"+string(protocol), []byte("nats-ecdsa-pub")) + completeKeygenSession(t, coord, fixtures, sessionsByProtocol[sdkprotocol.ProtocolTypeEdDSA], "wallet_nats_dual_"+string(protocol), []byte("nats-eddsa-pub")) + + resultMsg, err := resultSub.NextMsg(time.Second) + if err != nil { + t.Fatal(err) + } + var result sdkprotocol.Result + if err := json.Unmarshal(resultMsg.Data, &result); err != nil { + t.Fatal(err) + } + if result.KeyShare == nil { + t.Fatalf("missing key share result") + } + if string(result.KeyShare.PublicKey) != "nats-ecdsa-pub" { + t.Fatalf("public_key = %q", string(result.KeyShare.PublicKey)) + } + if string(result.KeyShare.ECDSAPubKey) != "nats-ecdsa-pub" { + t.Fatalf("ecdsa_pubkey = %q", string(result.KeyShare.ECDSAPubKey)) + } + if string(result.KeyShare.EDDSAPubKey) != "nats-eddsa-pub" { + t.Fatalf("eddsa_pubkey = %q", string(result.KeyShare.EDDSAPubKey)) + } + }) + } } func TestHandleRequestSignWithoutProtocolRejected(t *testing.T) { @@ -581,6 +887,59 @@ func completeKeygenSession(t *testing.T, coord *Coordinator, keys map[string]par } } +func assertDualKeygenResult(t *testing.T, coord *Coordinator, sessionID, walletID string, ecdsaPubKey, eddsaPubKey []byte) { + t.Helper() + session, ok := coord.store.Get(context.Background(), sessionID) + if !ok { + t.Fatalf("missing session %s", sessionID) + } + if session.State != SessionCompleted { + t.Fatalf("session state = %s, want %s", session.State, SessionCompleted) + } + if session.Result == nil || session.Result.KeyShare == nil { + t.Fatalf("missing key share result") + } + keyShare := session.Result.KeyShare + if keyShare.KeyID != walletID { + t.Fatalf("key_id = %q, want %q", keyShare.KeyID, walletID) + } + if string(keyShare.PublicKey) != string(ecdsaPubKey) { + t.Fatalf("public_key = %q", string(keyShare.PublicKey)) + } + if string(keyShare.ECDSAPubKey) != string(ecdsaPubKey) { + t.Fatalf("ecdsa_pubkey = %q", string(keyShare.ECDSAPubKey)) + } + if string(keyShare.EDDSAPubKey) != string(eddsaPubKey) { + t.Fatalf("eddsa_pubkey = %q", string(keyShare.EDDSAPubKey)) + } +} + +func grpcParticipants(keys map[string]participantKey) []*coordinatorv1.Participant { + return []*coordinatorv1.Participant{ + {Id: "p1", IdentityPublicKeyHex: hex.EncodeToString(keys["p1"].pub)}, + {Id: "p2", IdentityPublicKeyHex: hex.EncodeToString(keys["p2"].pub)}, + } +} + +func startTestNATSServer(t *testing.T) *natsserver.Server { + t.Helper() + server, err := natsserver.NewServer(&natsserver.Options{ + Host: "127.0.0.1", + Port: -1, + NoLog: true, + NoSigs: true, + }) + if err != nil { + t.Fatal(err) + } + go server.Start() + if !server.ReadyForConnections(5 * time.Second) { + server.Shutdown() + t.Fatal("nats server did not become ready") + } + return server +} + func markOnline(t *testing.T, presence PresenceView, _ ed25519.PublicKey, participantID string) { t.Helper() err := presence.ApplyPresence(sdkprotocol.PresenceEvent{ From 20b1e571a6961e30cc9c2d636d960ad51863e11d Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 22 Apr 2026 16:55:16 +0700 Subject: [PATCH 21/23] Serialize cosigner session operations and queue early peer packets. Protect session transitions with operation-level locking and buffer peer MPC packets that arrive before local party startup, then flush them safely when MPC begins. Made-with: Cursor --- internal/cosigner/runtime.go | 127 ++++++++++++++++++++++++++++++++--- 1 file changed, 117 insertions(+), 10 deletions(-) diff --git a/internal/cosigner/runtime.go b/internal/cosigner/runtime.go index f2e88f36..08d942c9 100644 --- a/internal/cosigner/runtime.go +++ b/internal/cosigner/runtime.go @@ -20,15 +20,17 @@ import ( ) type Runtime struct { - cfg Config - relay Relay - stores Stores - identity *localIdentity - coordLookup *coordinatorLookup - sessionsMu sync.RWMutex - sessions map[string]*participant.ParticipantSession - sessionMeta map[string]sessionMeta - subs []Subscription + cfg Config + relay Relay + stores Stores + identity *localIdentity + coordLookup *coordinatorLookup + sessionsMu sync.RWMutex + sessionOpsMu sync.Mutex + sessions map[string]*participant.ParticipantSession + sessionMeta map[string]sessionMeta + pendingPeer map[string][]*sdkprotocol.PeerMessage + subs []Subscription } type sessionMeta struct { @@ -37,6 +39,7 @@ type sessionMeta struct { } const bootstrapPreparamsSlot = "bootstrap" +const maxPendingPeerMessagesPerSession = 256 func NewRuntime(cfg Config) (*Runtime, error) { relay, err := NewRelayFromConfig(cfg) @@ -68,6 +71,7 @@ func NewRuntime(cfg Config) (*Runtime, error) { coordLookup: coordLookup, sessions: map[string]*participant.ParticipantSession{}, sessionMeta: map[string]sessionMeta{}, + pendingPeer: map[string][]*sdkprotocol.PeerMessage{}, }, nil } @@ -220,6 +224,8 @@ func (r *Runtime) handleControl(raw []byte) error { "session_id", msg.SessionID, "action", meta.action, ) + r.sessionOpsMu.Lock() + defer r.sessionOpsMu.Unlock() return r.startSession(&msg, meta) } meta := r.getSessionMeta(msg.SessionID) @@ -231,6 +237,8 @@ func (r *Runtime) handleControl(raw []byte) error { "protocol", meta.protocol, "action", meta.action, ) + r.sessionOpsMu.Lock() + defer r.sessionOpsMu.Unlock() session := r.getSession(msg.SessionID) if session == nil { logger.Warn("ignoring control for unknown session", "session_id", msg.SessionID) @@ -267,7 +275,13 @@ func (r *Runtime) handleControl(raw []byte) error { ) return err } - return r.dispatchActions(actions) + if err := r.dispatchActions(actions); err != nil { + return err + } + if msg.MPCBegin != nil { + return r.flushPendingPeerMessages(msg.SessionID) + } + return nil } func (r *Runtime) startSession(msg *sdkprotocol.ControlMessage, meta sessionMeta) error { @@ -321,6 +335,8 @@ func (r *Runtime) handlePeer(raw []byte) error { "from_participant", msg.FromParticipantID, "phase", string(msg.Phase), ) + r.sessionOpsMu.Lock() + defer r.sessionOpsMu.Unlock() session := r.getSession(msg.SessionID) if session == nil { logger.Warn("ignoring peer message for unknown session", "session_id", msg.SessionID) @@ -328,12 +344,102 @@ func (r *Runtime) handlePeer(raw []byte) error { } actions, err := session.HandlePeer(&msg) if err != nil { + if errors.Is(err, participant.ErrPartyNotRunning) && msg.MPCPacket != nil { + if r.enqueuePendingPeerMessage(&msg) { + logger.Warn("queued peer mpc message until local party starts", + "node_id", r.cfg.NodeID, + "session_id", msg.SessionID, + "from_participant", msg.FromParticipantID, + "phase", string(msg.Phase), + ) + return nil + } + } + logger.Error("session handle peer failed", err, + "node_id", r.cfg.NodeID, + "session_id", msg.SessionID, + "from_participant", msg.FromParticipantID, + "phase", string(msg.Phase), + ) return err } return r.dispatchActions(actions) } +func (r *Runtime) enqueuePendingPeerMessage(msg *sdkprotocol.PeerMessage) bool { + r.sessionsMu.Lock() + defer r.sessionsMu.Unlock() + if _, ok := r.sessions[msg.SessionID]; !ok { + return false + } + queue := r.pendingPeer[msg.SessionID] + if len(queue) >= maxPendingPeerMessagesPerSession { + logger.Error("dropping peer mpc message because pending queue is full", + fmt.Errorf("pending peer queue full"), + "node_id", r.cfg.NodeID, + "session_id", msg.SessionID, + "from_participant", msg.FromParticipantID, + "limit", maxPendingPeerMessagesPerSession, + ) + return true + } + clone := *msg + if msg.Signature != nil { + clone.Signature = append([]byte(nil), msg.Signature...) + } + if msg.MPCPacket != nil { + packet := *msg.MPCPacket + packet.Payload = append([]byte(nil), msg.MPCPacket.Payload...) + packet.Nonce = append([]byte(nil), msg.MPCPacket.Nonce...) + clone.MPCPacket = &packet + } + queue = append(queue, &clone) + r.pendingPeer[msg.SessionID] = queue + return true +} + +func (r *Runtime) takePendingPeerMessages(sessionID string) []*sdkprotocol.PeerMessage { + r.sessionsMu.Lock() + defer r.sessionsMu.Unlock() + pending := r.pendingPeer[sessionID] + delete(r.pendingPeer, sessionID) + return pending +} + +func (r *Runtime) flushPendingPeerMessages(sessionID string) error { + pending := r.takePendingPeerMessages(sessionID) + if len(pending) == 0 { + return nil + } + logger.Info("flushing queued peer mpc messages", + "node_id", r.cfg.NodeID, + "session_id", sessionID, + "count", len(pending), + ) + for _, msg := range pending { + session := r.getSession(sessionID) + if session == nil { + logger.Warn("dropping queued peer message for unknown session", "session_id", sessionID) + return nil + } + actions, err := session.HandlePeer(msg) + if err != nil { + if errors.Is(err, participant.ErrPartyNotRunning) { + r.enqueuePendingPeerMessage(msg) + return nil + } + return err + } + if err := r.dispatchActions(actions); err != nil { + return err + } + } + return nil +} + func (r *Runtime) tickSessions() error { + r.sessionOpsMu.Lock() + defer r.sessionOpsMu.Unlock() r.sessionsMu.RLock() ids := make([]string, 0, len(r.sessions)) for id := range r.sessions { @@ -414,6 +520,7 @@ func (r *Runtime) dropSessionMeta(sessionID string) { defer r.sessionsMu.Unlock() delete(r.sessionMeta, sessionID) delete(r.sessions, sessionID) + delete(r.pendingPeer, sessionID) } func controlType(msg *sdkprotocol.ControlMessage) string { From 194f6f24eb3b36348b4e848ac711ea1890a2cb0d Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 22 Apr 2026 16:55:22 +0700 Subject: [PATCH 22/23] Normalize keygen protocol handling across coordinator clients. Trim and normalize protocol values in NATS/gRPC keygen requests, add transport-level coverage for empty and "both" protocol inputs, and update the keygen example to use NATS aggregated key output. Made-with: Cursor --- examples/coordinatorclient-keygen/main.go | 34 +++--- pkg/coordinatorclient/client.go | 8 +- pkg/coordinatorclient/client_test.go | 127 +++++++++++++++++++++- 3 files changed, 147 insertions(+), 22 deletions(-) diff --git a/examples/coordinatorclient-keygen/main.go b/examples/coordinatorclient-keygen/main.go index 151a0062..cb55419c 100644 --- a/examples/coordinatorclient-keygen/main.go +++ b/examples/coordinatorclient-keygen/main.go @@ -7,15 +7,14 @@ import ( "log" "time" - sdkprotocol "github.com/fystack/mpcium-sdk/protocol" "github.com/fystack/mpcium/pkg/coordinatorclient" "github.com/google/uuid" ) func main() { client, err := coordinatorclient.New(coordinatorclient.Config{ - GRPCAddress: "127.0.0.1:50051", - Timeout: 5 * time.Second, + NATSURL: "nats://127.0.0.1:4222", + Timeout: 5 * time.Second, }) if err != nil { log.Fatalf("create coordinator client: %v", err) @@ -31,15 +30,10 @@ func main() { ID: "peer-node-02", IdentityPublicKey: mustDecodeHex("d9034dd84e0dd10a57d6a09a8267b217051d5f121ff52fca66c2b485be16ae02"), }, - { - ID: "mobile-sample-01", - IdentityPublicKey: mustDecodeHex("0c67697e3142c1c87dd8fa034fdfece14fc8ba00145bc0f123d6cd8bd33640e2"), - }, } walletID := "wallet_" + uuid.New().String() - runKeygenForProtocol(client, participants, walletID, sdkprotocol.ProtocolTypeECDSA) - runKeygenForProtocol(client, participants, walletID, sdkprotocol.ProtocolTypeEdDSA) + runKeygen(client, participants, walletID) } func mustDecodeHex(value string) []byte { @@ -50,17 +44,16 @@ func mustDecodeHex(value string) []byte { return decoded } -func runKeygenForProtocol(client *coordinatorclient.Client, participants []coordinatorclient.KeygenParticipant, walletID string, protocol sdkprotocol.ProtocolType) { +func runKeygen(client *coordinatorclient.Client, participants []coordinatorclient.KeygenParticipant, walletID string) { requestCtx, cancelRequest := context.WithTimeout(context.Background(), 10*time.Second) resp, err := client.RequestKeygen(requestCtx, coordinatorclient.KeygenRequest{ - Protocol: protocol, Threshold: 1, WalletID: walletID, Participants: participants, }) cancelRequest() if err != nil { - log.Fatalf("request keygen (%s): %v (verify both cosigners are online and publishing real presence)", protocol, err) + log.Fatalf("request keygen: %v (verify both cosigners are online and publishing real presence)", err) } acceptedAt := time.Now() @@ -68,14 +61,19 @@ func runKeygenForProtocol(client *coordinatorclient.Client, participants []coord result, err := client.WaitSessionResult(resultCtx, resp.SessionID) cancelResult() if err != nil { - log.Fatalf("wait session result (%s): %v (check both cosigners are running and session events are flowing)", protocol, err) + log.Fatalf("wait session result: %v (check both cosigners are running and session events are flowing)", err) } - if result == nil { - fmt.Printf("protocol=%s session_id=%s result=empty wait_seconds=%.3f\n", protocol, resp.SessionID, time.Since(acceptedAt).Seconds()) + if result == nil || result.KeyShare == nil { + fmt.Printf("session_id=%s result=empty wait_seconds=%.3f\n", resp.SessionID, time.Since(acceptedAt).Seconds()) return } - fmt.Printf("protocol=%s key_id=%s session_id=%s wait_seconds=%.3f\n", protocol, result.KeyShare.KeyID, resp.SessionID, time.Since(acceptedAt).Seconds()) - if result.KeyShare != nil { - fmt.Printf("public_key_hex=%s\n", hex.EncodeToString(result.KeyShare.PublicKey)) + + fmt.Printf("key_id=%s session_id=%s wait_seconds=%.3f\n", result.KeyShare.KeyID, resp.SessionID, time.Since(acceptedAt).Seconds()) + fmt.Printf("public_key_hex=%s\n", hex.EncodeToString(result.KeyShare.PublicKey)) + if len(result.KeyShare.ECDSAPubKey) > 0 { + fmt.Printf("ecdsa_pubkey_hex=%s\n", hex.EncodeToString(result.KeyShare.ECDSAPubKey)) + } + if len(result.KeyShare.EDDSAPubKey) > 0 { + fmt.Printf("eddsa_pubkey_hex=%s\n", hex.EncodeToString(result.KeyShare.EDDSAPubKey)) } } diff --git a/pkg/coordinatorclient/client.go b/pkg/coordinatorclient/client.go index 21122b54..fdf5677a 100644 --- a/pkg/coordinatorclient/client.go +++ b/pkg/coordinatorclient/client.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "strings" "time" coordinatorv1 "github.com/fystack/mpcium-sdk/integrations/coordinator-grpc/proto/coordinator/v1" @@ -226,10 +227,11 @@ func (c *Client) requestSessionNATS(ctx context.Context, subject string, msg *sd } func normalizeProtocol(protocol sdkprotocol.ProtocolType) sdkprotocol.ProtocolType { - if string(protocol) == "" { + value := strings.TrimSpace(string(protocol)) + if value == "" { return sdkprotocol.ProtocolTypeUnspecified } - return protocol + return sdkprotocol.ProtocolType(value) } func (c *Client) WaitSessionResult(ctx context.Context, sessionID string) (*sdkprotocol.Result, error) { @@ -271,7 +273,7 @@ func (c *Client) WaitSessionResult(ctx context.Context, sessionID string) (*sdkp func (c *Client) requestKeygenGRPC(ctx context.Context, req KeygenRequest) (*sdkprotocol.RequestAccepted, error) { grpcReq := &coordinatorv1.KeygenRequest{ - Protocol: string(req.Protocol), + Protocol: string(normalizeProtocol(req.Protocol)), Threshold: req.Threshold, WalletId: req.WalletID, Participants: mapParticipantsToProto(req.Participants), diff --git a/pkg/coordinatorclient/client_test.go b/pkg/coordinatorclient/client_test.go index 86b3016a..91c36676 100644 --- a/pkg/coordinatorclient/client_test.go +++ b/pkg/coordinatorclient/client_test.go @@ -3,6 +3,7 @@ package coordinatorclient import ( "context" "encoding/hex" + "encoding/json" "net" "strings" "testing" @@ -10,6 +11,8 @@ import ( coordinatorv1 "github.com/fystack/mpcium-sdk/integrations/coordinator-grpc/proto/coordinator/v1" sdkprotocol "github.com/fystack/mpcium-sdk/protocol" + natsserver "github.com/nats-io/nats-server/v2/server" + "github.com/nats-io/nats.go" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" @@ -22,9 +25,13 @@ type fakeCoordinatorServer struct { keygenResp *coordinatorv1.RequestAccepted signResp *coordinatorv1.RequestAccepted results map[string]*coordinatorv1.SessionResult + keygenReqs []*coordinatorv1.KeygenRequest } -func (s *fakeCoordinatorServer) Keygen(context.Context, *coordinatorv1.KeygenRequest) (*coordinatorv1.RequestAccepted, error) { +func (s *fakeCoordinatorServer) Keygen(_ context.Context, req *coordinatorv1.KeygenRequest) (*coordinatorv1.RequestAccepted, error) { + cloned := *req + cloned.Participants = append([]*coordinatorv1.Participant(nil), req.GetParticipants()...) + s.keygenReqs = append(s.keygenReqs, &cloned) if s.keygenResp != nil { return s.keygenResp, nil } @@ -125,6 +132,96 @@ func TestGRPCClientRequestKeygenAndSignResponses(t *testing.T) { } } +func TestGRPCClientRequestKeygenNormalizesProtocol(t *testing.T) { + fake := &fakeCoordinatorServer{} + client, cleanup := newTestGRPCClient(t, fake) + defer cleanup() + + for _, protocol := range []sdkprotocol.ProtocolType{"", sdkprotocol.ProtocolType(" both ")} { + _, err := client.RequestKeygen(context.Background(), KeygenRequest{ + Protocol: protocol, + Threshold: 1, + WalletID: "wallet-1", + Participants: []KeygenParticipant{ + {ID: "p1", IdentityPublicKey: []byte("pub-1")}, + {ID: "p2", IdentityPublicKey: []byte("pub-2")}, + }, + }) + if err != nil { + t.Fatal(err) + } + } + + if len(fake.keygenReqs) != 2 { + t.Fatalf("keygen request count = %d, want 2", len(fake.keygenReqs)) + } + if fake.keygenReqs[0].GetProtocol() != string(sdkprotocol.ProtocolTypeUnspecified) { + t.Fatalf("empty protocol sent as %q", fake.keygenReqs[0].GetProtocol()) + } + if fake.keygenReqs[1].GetProtocol() != "both" { + t.Fatalf("both protocol sent as %q", fake.keygenReqs[1].GetProtocol()) + } +} + +func TestNATSClientRequestKeygenNormalizesProtocol(t *testing.T) { + server := startTestNATSServer(t) + defer server.Shutdown() + + responder, err := nats.Connect(server.ClientURL()) + if err != nil { + t.Fatal(err) + } + defer responder.Close() + + seenProtocols := make(chan sdkprotocol.ProtocolType, 2) + sub, err := responder.Subscribe(requestKeygenSubject, func(msg *nats.Msg) { + var control sdkprotocol.ControlMessage + if err := json.Unmarshal(msg.Data, &control); err != nil { + t.Errorf("unmarshal keygen request: %v", err) + _ = msg.Respond(mustJSON(t, &sdkprotocol.RequestRejected{Accepted: false, ErrorCode: "decode", ErrorMessage: err.Error()})) + return + } + seenProtocols <- control.SessionStart.Protocol + _ = msg.Respond(mustJSON(t, &sdkprotocol.RequestAccepted{Accepted: true, SessionID: "sess_keygen", ExpiresAt: "2026-04-22T10:00:00Z"})) + }) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + if err := responder.Flush(); err != nil { + t.Fatal(err) + } + + clientConn, err := nats.Connect(server.ClientURL()) + if err != nil { + t.Fatal(err) + } + client := &Client{nc: clientConn, timeout: time.Second, transport: transportNATS} + defer client.Close() + + for _, protocol := range []sdkprotocol.ProtocolType{"", sdkprotocol.ProtocolType(" both ")} { + _, err := client.RequestKeygen(context.Background(), KeygenRequest{ + Protocol: protocol, + Threshold: 1, + WalletID: "wallet-1", + Participants: []KeygenParticipant{ + {ID: "p1", IdentityPublicKey: []byte("pub-1")}, + {ID: "p2", IdentityPublicKey: []byte("pub-2")}, + }, + }) + if err != nil { + t.Fatal(err) + } + } + + if got := <-seenProtocols; got != sdkprotocol.ProtocolTypeUnspecified { + t.Fatalf("empty protocol sent as %q", got) + } + if got := <-seenProtocols; got != sdkprotocol.ProtocolType("both") { + t.Fatalf("both protocol sent as %q", got) + } +} + func TestGRPCClientWaitSessionResultMapsKeygenAndSignature(t *testing.T) { signature := []byte("signature") recovery := []byte("recovery") @@ -171,3 +268,31 @@ func TestGRPCClientWaitSessionResultMapsKeygenAndSignature(t *testing.T) { t.Fatalf("unexpected sign result: %+v", signResult) } } + +func startTestNATSServer(t *testing.T) *natsserver.Server { + t.Helper() + server, err := natsserver.NewServer(&natsserver.Options{ + Host: "127.0.0.1", + Port: -1, + NoLog: true, + NoSigs: true, + }) + if err != nil { + t.Fatal(err) + } + go server.Start() + if !server.ReadyForConnections(5 * time.Second) { + server.Shutdown() + t.Fatal("nats server did not become ready") + } + return server +} + +func mustJSON(t *testing.T, v any) []byte { + t.Helper() + raw, err := json.Marshal(v) + if err != nil { + t.Fatal(err) + } + return raw +} From b7b12bed2ead0f25cde4ae0defb505db4d642694 Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 22 Apr 2026 17:44:22 +0700 Subject: [PATCH 23/23] Refactor keygen result handling in coordinator to use unified Result structure. Update related functions and tests to accommodate changes from KeyShare to Keygen, ensuring consistency across the codebase. --- examples/coordinatorclient-keygen/main.go | 13 +- internal/coordinator/coordinator.go | 138 +++++++++++++-------- internal/coordinator/coordinator_test.go | 76 +++++------- internal/coordinator/orchestration_grpc.go | 8 +- internal/coordinator/publisher.go | 4 +- internal/coordinator/store.go | 14 +-- internal/coordinator/types.go | 23 +++- pkg/coordinatorclient/client.go | 59 +++++---- pkg/coordinatorclient/client_test.go | 13 +- 9 files changed, 188 insertions(+), 160 deletions(-) diff --git a/examples/coordinatorclient-keygen/main.go b/examples/coordinatorclient-keygen/main.go index cb55419c..4ce9e948 100644 --- a/examples/coordinatorclient-keygen/main.go +++ b/examples/coordinatorclient-keygen/main.go @@ -63,17 +63,16 @@ func runKeygen(client *coordinatorclient.Client, participants []coordinatorclien if err != nil { log.Fatalf("wait session result: %v (check both cosigners are running and session events are flowing)", err) } - if result == nil || result.KeyShare == nil { + if result == nil || result.Keygen == nil { fmt.Printf("session_id=%s result=empty wait_seconds=%.3f\n", resp.SessionID, time.Since(acceptedAt).Seconds()) return } - fmt.Printf("key_id=%s session_id=%s wait_seconds=%.3f\n", result.KeyShare.KeyID, resp.SessionID, time.Since(acceptedAt).Seconds()) - fmt.Printf("public_key_hex=%s\n", hex.EncodeToString(result.KeyShare.PublicKey)) - if len(result.KeyShare.ECDSAPubKey) > 0 { - fmt.Printf("ecdsa_pubkey_hex=%s\n", hex.EncodeToString(result.KeyShare.ECDSAPubKey)) + fmt.Printf("key_id=%s session_id=%s wait_seconds=%.3f\n", result.Keygen.KeyID, resp.SessionID, time.Since(acceptedAt).Seconds()) + if len(result.Keygen.ECDSAPubKey) > 0 { + fmt.Printf("ecdsa_pubkey_hex=%s\n", hex.EncodeToString(result.Keygen.ECDSAPubKey)) } - if len(result.KeyShare.EDDSAPubKey) > 0 { - fmt.Printf("eddsa_pubkey_hex=%s\n", hex.EncodeToString(result.KeyShare.EDDSAPubKey)) + if len(result.Keygen.EDDSAPubKey) > 0 { + fmt.Printf("eddsa_pubkey_hex=%s\n", hex.EncodeToString(result.Keygen.EDDSAPubKey)) } } diff --git a/internal/coordinator/coordinator.go b/internal/coordinator/coordinator.go index ccb39eee..c2f0e82a 100644 --- a/internal/coordinator/coordinator.go +++ b/internal/coordinator/coordinator.go @@ -37,7 +37,7 @@ type dualKeygenGroup struct { } type dualKeygenCompletion struct { - result *sdkprotocol.Result + result *Result sessionIDs []string } @@ -285,7 +285,7 @@ func (c *Coordinator) HandleSessionEvent(ctx context.Context, raw []byte) error if event.SessionCompleted.Result == nil { return c.failSession(ctx, session, ErrorCodeValidation, "missing result payload") } - state.ResultHash = canonicalOperationResultHash(session.Op, event.SessionCompleted.Result) + state.ResultHash = canonicalSessionResultHash(session, event.SessionCompleted.Result) case event.PeerFailed != nil: state.Failed = true state.ErrorCode = ErrorCodeParticipantFailed @@ -474,11 +474,11 @@ func (c *Coordinator) advance(ctx context.Context, session *Session, event *sdkp return nil } -func (c *Coordinator) persistKeyInfoIfNeeded(session *Session, result *sdkprotocol.Result) error { - if c.keyInfoStore == nil || session == nil || result == nil || session.Op != OperationKeygen || result.KeyShare == nil { +func (c *Coordinator) persistKeyInfoIfNeeded(session *Session, result *Result) error { + if c.keyInfoStore == nil || session == nil || session.Start == nil || result == nil || session.Op != OperationKeygen || result.Keygen == nil { return nil } - walletID := result.KeyShare.KeyID + walletID := result.Keygen.KeyID if walletID == "" { walletID = keygenWalletID(session.Start) } @@ -498,7 +498,7 @@ func (c *Coordinator) persistKeyInfoIfNeeded(session *Session, result *sdkprotoc KeyType: string(session.Start.Protocol), Threshold: int(session.Start.Threshold), Participants: participantIDs, - PublicKey: append([]byte(nil), result.KeyShare.PublicKey...), + PublicKey: keygenPublicKeyForProtocol(result.Keygen, session.Start.Protocol), CreatedAt: c.now().UTC().Format(time.RFC3339Nano), } c.keyInfoStore.Save(info) @@ -533,8 +533,8 @@ func (c *Coordinator) registerDualKeygen(sessionIDs []string, seeded map[sdkprot } } -func (c *Coordinator) recordDualKeygenResult(session *Session, result *sdkprotocol.Result) (*dualKeygenCompletion, bool, bool) { - if c == nil || session == nil || session.Start == nil || result == nil || result.KeyShare == nil { +func (c *Coordinator) recordDualKeygenResult(session *Session, result *Result) (*dualKeygenCompletion, bool, bool) { + if c == nil || session == nil || session.Start == nil || result == nil || result.Keygen == nil { return nil, false, false } c.dualKeygenMu.Lock() @@ -543,7 +543,7 @@ func (c *Coordinator) recordDualKeygenResult(session *Session, result *sdkprotoc if !ok { return nil, false, false } - group.results[session.Start.Protocol] = keySharePublicKeyForProtocol(result.KeyShare, session.Start.Protocol) + group.results[session.Start.Protocol] = keygenPublicKeyForProtocol(result.Keygen, session.Start.Protocol) ecdsaPubKey := group.results[sdkprotocol.ProtocolTypeECDSA] eddsaPubKey := group.results[sdkprotocol.ProtocolTypeEdDSA] if len(ecdsaPubKey) == 0 || len(eddsaPubKey) == 0 { @@ -551,12 +551,11 @@ func (c *Coordinator) recordDualKeygenResult(session *Session, result *sdkprotoc } walletID := keygenWalletID(session.Start) if walletID == "" { - walletID = result.KeyShare.KeyID + walletID = result.Keygen.KeyID } - aggregate := &sdkprotocol.Result{ - KeyShare: &sdkprotocol.KeyShareResult{ + aggregate := &Result{ + Keygen: &KeygenResult{ KeyID: walletID, - PublicKey: append([]byte(nil), ecdsaPubKey...), ECDSAPubKey: append([]byte(nil), ecdsaPubKey...), EDDSAPubKey: append([]byte(nil), eddsaPubKey...), }, @@ -577,7 +576,7 @@ func (c *Coordinator) completeDualKeygen(ctx context.Context, completion *dualKe return fmt.Errorf("missing dual keygen completion") } result := completion.result - if result == nil || result.KeyShare == nil { + if result == nil || result.Keygen == nil || result.Keygen.KeyID == "" { return fmt.Errorf("missing dual keygen result") } now := c.now() @@ -715,9 +714,9 @@ func (c *Coordinator) expireSession(ctx context.Context, session *Session) error return c.results.PublishResult(ctx, session.ID, nil) } -func (c *Coordinator) buildCompletedResult(session *Session, event *sdkprotocol.SessionEvent) (*sdkprotocol.Result, string, error) { +func (c *Coordinator) buildCompletedResult(session *Session, event *sdkprotocol.SessionEvent) (*Result, string, error) { var resultHash string - var result *sdkprotocol.Result + var result *Result for _, state := range session.ParticipantState { if state.ResultHash == "" { return nil, "", fmt.Errorf("participant completed without result hash") @@ -739,20 +738,31 @@ func (c *Coordinator) buildCompletedResult(session *Session, event *sdkprotocol. if in.KeyShare == nil { return nil, "", fmt.Errorf("missing key share result") } - result = &sdkprotocol.Result{ - KeyShare: &sdkprotocol.KeyShareResult{ - KeyID: in.KeyShare.KeyID, - PublicKey: append([]byte(nil), in.KeyShare.PublicKey...), - ECDSAPubKey: keyShareProtocolPubKey(in.KeyShare, sdkprotocol.ProtocolTypeECDSA, session.Start.Protocol), - EDDSAPubKey: keyShareProtocolPubKey(in.KeyShare, sdkprotocol.ProtocolTypeEdDSA, session.Start.Protocol), - }, + keygen := &KeygenResult{KeyID: in.KeyShare.KeyID} + switch session.Start.Protocol { + case sdkprotocol.ProtocolTypeECDSA: + keygen.ECDSAPubKey = append([]byte(nil), in.KeyShare.PublicKey...) + case sdkprotocol.ProtocolTypeEdDSA: + keygen.EDDSAPubKey = append([]byte(nil), in.KeyShare.PublicKey...) + default: + return nil, "", fmt.Errorf("unsupported keygen protocol %q", session.Start.Protocol) } + result = &Result{Keygen: keygen} case OperationSign: if in.Signature == nil { return nil, "", fmt.Errorf("missing signature result") } - result = &sdkprotocol.Result{ - Signature: cloneResult(in).Signature, + sig := in.Signature + result = &Result{ + Signature: &SignResult{ + KeyID: sig.KeyID, + Signature: append([]byte(nil), sig.Signature...), + SignatureRecovery: append([]byte(nil), sig.SignatureRecovery...), + R: append([]byte(nil), sig.R...), + S: append([]byte(nil), sig.S...), + SignedInput: append([]byte(nil), sig.SignedInput...), + PublicKey: append([]byte(nil), sig.PublicKey...), + }, } default: return nil, "", fmt.Errorf("unsupported operation") @@ -772,7 +782,7 @@ func (c *Coordinator) GetSession(ctx context.Context, sessionID string) (*Sessio return c.store.Get(ctx, sessionID) } -func (c *Coordinator) GetSessionResult(ctx context.Context, sessionID string) (*sdkprotocol.Result, bool) { +func (c *Coordinator) GetSessionResult(ctx context.Context, sessionID string) (*Result, bool) { session, ok := c.GetSession(ctx, sessionID) if !ok { return nil, false @@ -789,30 +799,25 @@ func allParticipants(session *Session, predicate func(*ParticipantState) bool) b return true } -func canonicalOperationResultHash(op Operation, result *sdkprotocol.Result) string { +func canonicalSessionResultHash(session *Session, result *sdkprotocol.Result) string { if result == nil { return "" } - switch op { + if session == nil { + return canonicalResultHash(sdkResultToCoordinatorResult(result, sdkprotocol.ProtocolTypeUnspecified)) + } + switch session.Op { case OperationKeygen: if result.KeyShare == nil { return "" } - normalized := &sdkprotocol.Result{ - KeyShare: &sdkprotocol.KeyShareResult{ - KeyID: result.KeyShare.KeyID, - PublicKey: append([]byte(nil), result.KeyShare.PublicKey...), - ECDSAPubKey: append([]byte(nil), result.KeyShare.ECDSAPubKey...), - EDDSAPubKey: append([]byte(nil), result.KeyShare.EDDSAPubKey...), - }, - } - return canonicalResultHash(normalized) + return canonicalResultHash(&Result{Keygen: sdkKeyShareToCoordinatorKeygen(result.KeyShare, session.Start.Protocol)}) default: - return canonicalResultHash(result) + return canonicalResultHash(sdkResultToCoordinatorResult(result, sdkprotocol.ProtocolTypeUnspecified)) } } -func canonicalResultHash(result *sdkprotocol.Result) string { +func canonicalResultHash(result any) string { if result == nil { return "" } @@ -828,28 +833,61 @@ func keygenWalletID(start *sdkprotocol.SessionStart) string { return start.Keygen.KeyID } -func keySharePublicKeyForProtocol(keyShare *sdkprotocol.KeyShareResult, protocol sdkprotocol.ProtocolType) []byte { +func sdkResultToCoordinatorResult(result *sdkprotocol.Result, protocol sdkprotocol.ProtocolType) *Result { + if result == nil { + return nil + } + if result.KeyShare != nil { + return &Result{Keygen: sdkKeyShareToCoordinatorKeygen(result.KeyShare, protocol)} + } + if result.Signature != nil { + sig := result.Signature + return &Result{ + Signature: &SignResult{ + KeyID: sig.KeyID, + Signature: append([]byte(nil), sig.Signature...), + SignatureRecovery: append([]byte(nil), sig.SignatureRecovery...), + R: append([]byte(nil), sig.R...), + S: append([]byte(nil), sig.S...), + SignedInput: append([]byte(nil), sig.SignedInput...), + PublicKey: append([]byte(nil), sig.PublicKey...), + }, + } + } + return &Result{} +} + +func sdkKeyShareToCoordinatorKeygen(keyShare *sdkprotocol.KeyShareResult, protocol sdkprotocol.ProtocolType) *KeygenResult { if keyShare == nil { return nil } + keygen := &KeygenResult{KeyID: keyShare.KeyID} switch protocol { case sdkprotocol.ProtocolTypeECDSA: - if len(keyShare.ECDSAPubKey) > 0 { - return append([]byte(nil), keyShare.ECDSAPubKey...) - } + keygen.ECDSAPubKey = append([]byte(nil), keyShare.PublicKey...) case sdkprotocol.ProtocolTypeEdDSA: - if len(keyShare.EDDSAPubKey) > 0 { - return append([]byte(nil), keyShare.EDDSAPubKey...) - } + keygen.EDDSAPubKey = append([]byte(nil), keyShare.PublicKey...) + default: + keygen.ECDSAPubKey = append([]byte(nil), keyShare.PublicKey...) } - return append([]byte(nil), keyShare.PublicKey...) + return keygen } -func keyShareProtocolPubKey(keyShare *sdkprotocol.KeyShareResult, target, sessionProtocol sdkprotocol.ProtocolType) []byte { - if keyShare == nil || target != sessionProtocol { +func keygenPublicKeyForProtocol(keygen *KeygenResult, protocol sdkprotocol.ProtocolType) []byte { + if keygen == nil { return nil } - return keySharePublicKeyForProtocol(keyShare, target) + switch protocol { + case sdkprotocol.ProtocolTypeECDSA: + return append([]byte(nil), keygen.ECDSAPubKey...) + case sdkprotocol.ProtocolTypeEdDSA: + return append([]byte(nil), keygen.EDDSAPubKey...) + default: + if len(keygen.ECDSAPubKey) > 0 { + return append([]byte(nil), keygen.ECDSAPubKey...) + } + return append([]byte(nil), keygen.EDDSAPubKey...) + } } func firstNonEmpty(values ...string) string { diff --git a/internal/coordinator/coordinator_test.go b/internal/coordinator/coordinator_test.go index 5a046563..5846391a 100644 --- a/internal/coordinator/coordinator_test.go +++ b/internal/coordinator/coordinator_test.go @@ -33,12 +33,12 @@ func (p *fakeControlPublisher) PublishControl(_ context.Context, participantID s } type fakeResultPublisher struct { - results map[string]*sdkprotocol.Result + results map[string]*Result } -func (p *fakeResultPublisher) PublishResult(_ context.Context, sessionID string, result *sdkprotocol.Result) error { +func (p *fakeResultPublisher) PublishResult(_ context.Context, sessionID string, result *Result) error { if p.results == nil { - p.results = map[string]*sdkprotocol.Result{} + p.results = map[string]*Result{} } p.results[sessionID] = result return nil @@ -186,12 +186,9 @@ func TestLifecycleCompletesKeygenWithoutShareBlob(t *testing.T) { } published := results.results[reply.SessionID] - if published == nil || published.KeyShare == nil { + if published == nil || published.Keygen == nil { t.Fatalf("missing published keygen result") } - if len(published.KeyShare.ShareBlob) != 0 { - t.Fatalf("share blob should not be required/published") - } } func TestHandleRequestRejectsDuplicateWalletIDAfterCompletedKeygen(t *testing.T) { @@ -412,25 +409,22 @@ func TestKeygenBothPublishesAggregatedPubKeys(t *testing.T) { completeKeygenSession(t, coord, fixtures, sessionsByProtocol[sdkprotocol.ProtocolTypeEdDSA], "wallet_dual_result", []byte("eddsa-pub")) published := results.results[accepted.SessionID] - if published == nil || published.KeyShare == nil { + if published == nil || published.Keygen == nil { t.Fatalf("missing published dual keygen result") } - if string(published.KeyShare.PublicKey) != "ecdsa-pub" { - t.Fatalf("public_key = %q", string(published.KeyShare.PublicKey)) - } - if string(published.KeyShare.ECDSAPubKey) != "ecdsa-pub" { - t.Fatalf("ecdsa_pubkey = %q", string(published.KeyShare.ECDSAPubKey)) + if string(published.Keygen.ECDSAPubKey) != "ecdsa-pub" { + t.Fatalf("ecdsa_pub_key = %q", string(published.Keygen.ECDSAPubKey)) } - if string(published.KeyShare.EDDSAPubKey) != "eddsa-pub" { - t.Fatalf("eddsa_pubkey = %q", string(published.KeyShare.EDDSAPubKey)) + if string(published.Keygen.EDDSAPubKey) != "eddsa-pub" { + t.Fatalf("eddsa_pub_key = %q", string(published.Keygen.EDDSAPubKey)) } session, ok := coord.store.Get(ctx, accepted.SessionID) if !ok { t.Fatalf("missing accepted session") } grpcResult := sessionToProtoResult(session) - if grpcResult.GetPublicKeyHex() != hex.EncodeToString([]byte("ecdsa-pub")) { - t.Fatalf("grpc public_key_hex = %q", grpcResult.GetPublicKeyHex()) + if grpcResult.GetKeyId() != "wallet_dual_result" { + t.Fatalf("grpc key_id = %q", grpcResult.GetKeyId()) } } @@ -606,14 +600,8 @@ func testGRPCDualKeygenReturnsAggregatedResult(t *testing.T, protocol string) { if !result.GetCompleted() { t.Fatalf("expected completed result: %+v", result) } - if result.GetPublicKeyHex() != hex.EncodeToString([]byte("grpc-ecdsa-pub")) { - t.Fatalf("public_key_hex = %q", result.GetPublicKeyHex()) - } - if result.GetEcdsaPubkey() != hex.EncodeToString([]byte("grpc-ecdsa-pub")) { - t.Fatalf("ecdsa_pubkey = %q", result.GetEcdsaPubkey()) - } - if result.GetEddsaPubkey() != hex.EncodeToString([]byte("grpc-eddsa-pub")) { - t.Fatalf("eddsa_pubkey = %q", result.GetEddsaPubkey()) + if result.GetKeyId() != "wallet_grpc_dual_"+protocol { + t.Fatalf("key_id = %q", result.GetKeyId()) } } @@ -687,21 +675,18 @@ func TestNATSRuntimeKeygenEmptyAndBothProtocolsPublishAggregatedResult(t *testin if err != nil { t.Fatal(err) } - var result sdkprotocol.Result + var result Result if err := json.Unmarshal(resultMsg.Data, &result); err != nil { t.Fatal(err) } - if result.KeyShare == nil { - t.Fatalf("missing key share result") - } - if string(result.KeyShare.PublicKey) != "nats-ecdsa-pub" { - t.Fatalf("public_key = %q", string(result.KeyShare.PublicKey)) + if result.Keygen == nil { + t.Fatalf("missing keygen result") } - if string(result.KeyShare.ECDSAPubKey) != "nats-ecdsa-pub" { - t.Fatalf("ecdsa_pubkey = %q", string(result.KeyShare.ECDSAPubKey)) + if string(result.Keygen.ECDSAPubKey) != "nats-ecdsa-pub" { + t.Fatalf("ecdsa_pub_key = %q", string(result.Keygen.ECDSAPubKey)) } - if string(result.KeyShare.EDDSAPubKey) != "nats-eddsa-pub" { - t.Fatalf("eddsa_pubkey = %q", string(result.KeyShare.EDDSAPubKey)) + if string(result.Keygen.EDDSAPubKey) != "nats-eddsa-pub" { + t.Fatalf("eddsa_pub_key = %q", string(result.Keygen.EDDSAPubKey)) } }) } @@ -896,21 +881,18 @@ func assertDualKeygenResult(t *testing.T, coord *Coordinator, sessionID, walletI if session.State != SessionCompleted { t.Fatalf("session state = %s, want %s", session.State, SessionCompleted) } - if session.Result == nil || session.Result.KeyShare == nil { - t.Fatalf("missing key share result") - } - keyShare := session.Result.KeyShare - if keyShare.KeyID != walletID { - t.Fatalf("key_id = %q, want %q", keyShare.KeyID, walletID) + if session.Result == nil || session.Result.Keygen == nil { + t.Fatalf("missing keygen result") } - if string(keyShare.PublicKey) != string(ecdsaPubKey) { - t.Fatalf("public_key = %q", string(keyShare.PublicKey)) + keygen := session.Result.Keygen + if keygen.KeyID != walletID { + t.Fatalf("key_id = %q, want %q", keygen.KeyID, walletID) } - if string(keyShare.ECDSAPubKey) != string(ecdsaPubKey) { - t.Fatalf("ecdsa_pubkey = %q", string(keyShare.ECDSAPubKey)) + if string(keygen.ECDSAPubKey) != string(ecdsaPubKey) { + t.Fatalf("ecdsa_pub_key = %q", string(keygen.ECDSAPubKey)) } - if string(keyShare.EDDSAPubKey) != string(eddsaPubKey) { - t.Fatalf("eddsa_pubkey = %q", string(keyShare.EDDSAPubKey)) + if string(keygen.EDDSAPubKey) != string(eddsaPubKey) { + t.Fatalf("eddsa_pub_key = %q", string(keygen.EDDSAPubKey)) } } diff --git a/internal/coordinator/orchestration_grpc.go b/internal/coordinator/orchestration_grpc.go index a5c1278f..6be930ca 100644 --- a/internal/coordinator/orchestration_grpc.go +++ b/internal/coordinator/orchestration_grpc.go @@ -177,16 +177,12 @@ func sessionToProtoResult(session *Session) *coordinatorv1.SessionResult { if session.Result == nil { return result } - if session.Result.KeyShare != nil { - result.KeyId = session.Result.KeyShare.KeyID - result.PublicKeyHex = hex.EncodeToString(session.Result.KeyShare.PublicKey) - result.EcdsaPubkey = hex.EncodeToString(session.Result.KeyShare.ECDSAPubKey) - result.EddsaPubkey = hex.EncodeToString(session.Result.KeyShare.EDDSAPubKey) + if session.Result.Keygen != nil { + result.KeyId = session.Result.Keygen.KeyID } if session.Result.Signature != nil { sig := session.Result.Signature result.KeyId = sig.KeyID - result.PublicKeyHex = hex.EncodeToString(sig.PublicKey) result.SignatureHex = hex.EncodeToString(sig.Signature) result.SignatureRecoveryHex = hex.EncodeToString(sig.SignatureRecovery) result.RHex = hex.EncodeToString(sig.R) diff --git a/internal/coordinator/publisher.go b/internal/coordinator/publisher.go index 18bf720c..d84f3bce 100644 --- a/internal/coordinator/publisher.go +++ b/internal/coordinator/publisher.go @@ -14,7 +14,7 @@ type ControlPublisher interface { } type ResultPublisher interface { - PublishResult(ctx context.Context, sessionID string, result *sdkprotocol.Result) error + PublishResult(ctx context.Context, sessionID string, result *Result) error } type NATSControlPublisher struct { @@ -44,7 +44,7 @@ func NewNATSResultPublisher(nc *nats.Conn) *NATSResultPublisher { return &NATSResultPublisher{nc: nc} } -func (p *NATSResultPublisher) PublishResult(ctx context.Context, sessionID string, result *sdkprotocol.Result) error { +func (p *NATSResultPublisher) PublishResult(ctx context.Context, sessionID string, result *Result) error { if err := ctx.Err(); err != nil { return err } diff --git a/internal/coordinator/store.go b/internal/coordinator/store.go index e94d8a07..5cec8e07 100644 --- a/internal/coordinator/store.go +++ b/internal/coordinator/store.go @@ -292,18 +292,16 @@ func cloneParticipants(participants []*sdkprotocol.SessionParticipant) []*sdkpro return out } -func cloneResult(result *sdkprotocol.Result) *sdkprotocol.Result { +func cloneResult(result *Result) *Result { if result == nil { return nil } cloned := *result - if result.KeyShare != nil { - keyShare := *result.KeyShare - keyShare.ShareBlob = append([]byte(nil), result.KeyShare.ShareBlob...) - keyShare.PublicKey = append([]byte(nil), result.KeyShare.PublicKey...) - keyShare.ECDSAPubKey = append([]byte(nil), result.KeyShare.ECDSAPubKey...) - keyShare.EDDSAPubKey = append([]byte(nil), result.KeyShare.EDDSAPubKey...) - cloned.KeyShare = &keyShare + if result.Keygen != nil { + keygen := *result.Keygen + keygen.ECDSAPubKey = append([]byte(nil), result.Keygen.ECDSAPubKey...) + keygen.EDDSAPubKey = append([]byte(nil), result.Keygen.EDDSAPubKey...) + cloned.Keygen = &keygen } if result.Signature != nil { signature := *result.Signature diff --git a/internal/coordinator/types.go b/internal/coordinator/types.go index 74969d33..92815ece 100644 --- a/internal/coordinator/types.go +++ b/internal/coordinator/types.go @@ -69,7 +69,7 @@ type Session struct { ParticipantState map[string]*ParticipantState `json:"participant_state"` ExchangeID string `json:"exchange_id,omitempty"` ResultHash string `json:"result_hash,omitempty"` - Result *sdkprotocol.Result `json:"result,omitempty"` + Result *Result `json:"result,omitempty"` ErrorCode string `json:"error_code,omitempty"` ErrorMessage string `json:"error_message,omitempty"` CreatedAt time.Time `json:"created_at"` @@ -79,3 +79,24 @@ type Session struct { ControlSeq uint64 `json:"control_seq"` ParticipantKeys map[string][]byte `json:"participant_keys"` } + +type KeygenResult struct { + KeyID string `json:"key_id"` + ECDSAPubKey []byte `json:"ecdsa_pub_key"` + EDDSAPubKey []byte `json:"eddsa_pub_key"` +} + +type SignResult struct { + KeyID string `json:"key_id"` + Signature []byte `json:"signature"` + SignatureRecovery []byte `json:"signature_recovery"` + R []byte `json:"r"` + S []byte `json:"s"` + SignedInput []byte `json:"signed_input"` + PublicKey []byte `json:"public_key"` +} + +type Result struct { + Keygen *KeygenResult `json:"keygen,omitempty"` + Signature *SignResult `json:"signature,omitempty"` +} diff --git a/pkg/coordinatorclient/client.go b/pkg/coordinatorclient/client.go index fdf5677a..2899b957 100644 --- a/pkg/coordinatorclient/client.go +++ b/pkg/coordinatorclient/client.go @@ -65,6 +65,27 @@ type SignRequest struct { Participants []SignParticipant } +type KeygenResult struct { + KeyID string `json:"key_id"` + ECDSAPubKey []byte `json:"ecdsa_pub_key,omitempty"` + EDDSAPubKey []byte `json:"eddsa_pub_key,omitempty"` +} + +type SignResult struct { + KeyID string `json:"key_id"` + Signature []byte `json:"signature,omitempty"` + SignatureRecovery []byte `json:"signature_recovery,omitempty"` + R []byte `json:"r,omitempty"` + S []byte `json:"s,omitempty"` + SignedInput []byte `json:"signed_input,omitempty"` + PublicKey []byte `json:"public_key,omitempty"` +} + +type Result struct { + Keygen *KeygenResult `json:"keygen,omitempty"` + Signature *SignResult `json:"signature,omitempty"` +} + func New(cfg Config) (*Client, error) { if cfg.Timeout <= 0 { cfg.Timeout = 5 * time.Second @@ -234,7 +255,7 @@ func normalizeProtocol(protocol sdkprotocol.ProtocolType) sdkprotocol.ProtocolTy return sdkprotocol.ProtocolType(value) } -func (c *Client) WaitSessionResult(ctx context.Context, sessionID string) (*sdkprotocol.Result, error) { +func (c *Client) WaitSessionResult(ctx context.Context, sessionID string) (*Result, error) { if sessionID == "" { return nil, fmt.Errorf("sessionID is required") } @@ -264,7 +285,7 @@ func (c *Client) WaitSessionResult(ctx context.Context, sessionID string) (*sdkp return nil, fmt.Errorf("wait session result: %w", err) } - var result *sdkprotocol.Result + var result *Result if err := json.Unmarshal(msg.Data, &result); err != nil { return nil, fmt.Errorf("decode session result: %w", err) } @@ -319,7 +340,7 @@ func (c *Client) requestSignGRPC(ctx context.Context, req SignRequest) (*sdkprot }, nil } -func (c *Client) waitSessionResultGRPC(ctx context.Context, sessionID string) (*sdkprotocol.Result, error) { +func (c *Client) waitSessionResultGRPC(ctx context.Context, sessionID string) (*Result, error) { resp, err := c.grpcClient.WaitSessionResult(ctx, &coordinatorv1.SessionLookup{SessionId: sessionID}) if err != nil { return nil, fmt.Errorf("wait session result: %w", err) @@ -333,29 +354,10 @@ func (c *Client) waitSessionResultGRPC(ctx context.Context, sessionID string) (* if err != nil { return nil, err } - return &sdkprotocol.Result{Signature: signature}, nil + return &Result{Signature: signature}, nil } - publicKey, err := decodeHexField("public_key_hex", resp.GetPublicKeyHex()) - if err != nil { - return nil, err - } - ecdsaPubKey, err := decodeHexField("ecdsa_pubkey", resp.GetEcdsaPubkey()) - if err != nil { - return nil, err - } - eddsaPubKey, err := decodeHexField("eddsa_pubkey", resp.GetEddsaPubkey()) - if err != nil { - return nil, err - } - return &sdkprotocol.Result{ - KeyShare: &sdkprotocol.KeyShareResult{ - KeyID: resp.GetKeyId(), - PublicKey: publicKey, - ECDSAPubKey: ecdsaPubKey, - EDDSAPubKey: eddsaPubKey, - }, - }, nil + return &Result{Keygen: &KeygenResult{KeyID: resp.GetKeyId()}}, nil } func mapParticipantsToProto(participants []KeygenParticipant) []*coordinatorv1.Participant { @@ -369,7 +371,7 @@ func mapParticipantsToProto(participants []KeygenParticipant) []*coordinatorv1.P return mapped } -func mapProtoSignature(resp *coordinatorv1.SessionResult) (*sdkprotocol.SignatureResult, error) { +func mapProtoSignature(resp *coordinatorv1.SessionResult) (*SignResult, error) { signature, err := decodeHexField("signature_hex", resp.GetSignatureHex()) if err != nil { return nil, err @@ -390,18 +392,13 @@ func mapProtoSignature(resp *coordinatorv1.SessionResult) (*sdkprotocol.Signatur if err != nil { return nil, err } - publicKey, err := decodeHexField("public_key_hex", resp.GetPublicKeyHex()) - if err != nil { - return nil, err - } - return &sdkprotocol.SignatureResult{ + return &SignResult{ KeyID: resp.GetKeyId(), Signature: signature, SignatureRecovery: recovery, R: r, S: s, SignedInput: signedInput, - PublicKey: publicKey, }, nil } diff --git a/pkg/coordinatorclient/client_test.go b/pkg/coordinatorclient/client_test.go index 91c36676..38c56efc 100644 --- a/pkg/coordinatorclient/client_test.go +++ b/pkg/coordinatorclient/client_test.go @@ -228,20 +228,17 @@ func TestGRPCClientWaitSessionResultMapsKeygenAndSignature(t *testing.T) { r := []byte("r") s := []byte("s") signedInput := []byte("message") - publicKey := []byte("public-key") client, cleanup := newTestGRPCClient(t, &fakeCoordinatorServer{ results: map[string]*coordinatorv1.SessionResult{ "sess_keygen": { - Completed: true, - SessionId: "sess_keygen", - KeyId: "wallet-1", - PublicKeyHex: hex.EncodeToString(publicKey), + Completed: true, + SessionId: "sess_keygen", + KeyId: "wallet-1", }, "sess_sign": { Completed: true, SessionId: "sess_sign", KeyId: "wallet-1", - PublicKeyHex: hex.EncodeToString(publicKey), SignatureHex: hex.EncodeToString(signature), SignatureRecoveryHex: hex.EncodeToString(recovery), RHex: hex.EncodeToString(r), @@ -256,7 +253,7 @@ func TestGRPCClientWaitSessionResultMapsKeygenAndSignature(t *testing.T) { if err != nil { t.Fatal(err) } - if keygenResult.KeyShare == nil || keygenResult.KeyShare.KeyID != "wallet-1" || string(keygenResult.KeyShare.PublicKey) != string(publicKey) { + if keygenResult.Keygen == nil || keygenResult.Keygen.KeyID != "wallet-1" { t.Fatalf("unexpected keygen result: %+v", keygenResult) } @@ -264,7 +261,7 @@ func TestGRPCClientWaitSessionResultMapsKeygenAndSignature(t *testing.T) { if err != nil { t.Fatal(err) } - if signResult.Signature == nil || string(signResult.Signature.Signature) != string(signature) || string(signResult.Signature.PublicKey) != string(publicKey) { + if signResult.Signature == nil || string(signResult.Signature.Signature) != string(signature) { t.Fatalf("unexpected sign result: %+v", signResult) } }