From 2900c652eaf838789ce1eef4d74cbd330ff5e5bb Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Wed, 15 Apr 2026 14:52:47 +0200 Subject: [PATCH 1/5] Add ConnectionsReady, StartWeightSharing messages --- .../decentralized/decentralized_client.ts | 11 +++-- discojs/src/client/decentralized/messages.ts | 19 ++++++++- discojs/src/client/messages.ts | 4 ++ .../controllers/decentralized_controller.ts | 42 ++++++++++++++++++- 4 files changed, 69 insertions(+), 7 deletions(-) diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index 6f9da6e77..2851ce197 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -152,7 +152,9 @@ export class DecentralizedClient extends Client<"decentralized"> { await this.waitForParticipantsIfNeeded() // Create peer-to-peer connections with all peers for the round await this.establishPeerConnections() - // Exchange weight updates with peers and return aggregated weights + // Wait StartWeightSharing message from the server before exchanging weight updates + await waitMessage(this.server, type.StartWeightSharing) + // Exchange weight updates with peers and return aggregated weights // and then send out the contributions return await this.exchangeWeightUpdates(weights) } @@ -178,8 +180,9 @@ export class DecentralizedClient extends Client<"decentralized"> { try { debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`); const receivedMessage = await waitMessage(this.server, type.PeersForRound) - + const peers = Set(receivedMessage.peers) + debug(`[${shortenId(this.ownId)}] received peer list: %o`, peers.toArray()); if (this.ownId !== undefined && peers.has(this.ownId)) { throw new Error('received peer list contains our own id') @@ -198,7 +201,9 @@ export class DecentralizedClient extends Client<"decentralized"> { (conn) => this.receivePayloads(conn) ) - debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS()); + // Signal server that all connections with other peers in the round are established + this.server.send({ type: type.ConnectionsReady }); + debug(`[${shortenId(this.ownId)}] peer connections ready: %o`, connections.keySeq().toJS()); this.#connections = connections } catch (e) { debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e); diff --git a/discojs/src/client/decentralized/messages.ts b/discojs/src/client/decentralized/messages.ts index 626062ad4..ae6b2d891 100644 --- a/discojs/src/client/decentralized/messages.ts +++ b/discojs/src/client/decentralized/messages.ts @@ -38,6 +38,17 @@ export interface PeersForRound { aggregationRound: number } +// peer sends to server to signal all the connections to other peers +// are established +export interface ConnectionsReady { + type: type.ConnectionsReady +} + +// Server signal each peer to start weight update sharing +export interface StartWeightSharing { + type: type.StartWeightSharing; +} + /// Phase 1 communication (between peers) export interface Payload { @@ -55,13 +66,15 @@ export type MessageFromServer = SignalForPeer | PeersForRound | WaitingForMoreParticipants | - EnoughParticipants + EnoughParticipants | + StartWeightSharing export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady | - JoinRound + JoinRound | + ConnectionsReady export type PeerMessage = Payload @@ -80,6 +93,7 @@ export function isMessageFromServer (o: unknown): o is MessageFromServer { return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID) case type.WaitingForMoreParticipants: case type.EnoughParticipants: + case type.StartWeightSharing: return true } @@ -97,6 +111,7 @@ export function isMessageToServer (o: unknown): o is MessageToServer { 'signal' in o // TODO check signal content? case type.JoinRound: case type.PeerIsReady: + case type.ConnectionsReady: return true } diff --git a/discojs/src/client/messages.ts b/discojs/src/client/messages.ts index f5b5f9bb4..b0e9dc4e7 100644 --- a/discojs/src/client/messages.ts +++ b/discojs/src/client/messages.ts @@ -24,6 +24,10 @@ export enum type { // Message forwarded by the server from a client to another client // to establish a peer-to-peer (WebRTC) connection SignalForPeer, + // Message sent by nodes to server to signal all connections are established + ConnectionsReady, + // Sent by the server to signal nodes proceed to weight update sharing + StartWeightSharing, // The weight update Payload, diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index 37ea428f4..fa73ea768 100644 --- a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -21,6 +21,7 @@ export class DecentralizedController< // the node has already sent a PeerIsReady message) // We wait for all peers to be ready to exchange weight updates #roundPeers = Map() + #connectFinishedNodes = Map() #aggregationRound = 0 handle (ws: WebSocket): void { @@ -84,6 +85,11 @@ export class DecentralizedController< this.connections.get(msg.peer)?.send(msgpack.encode(forward)) break } + case MessageTypes.ConnectionsReady: { + this.#connectFinishedNodes = this.#connectFinishedNodes.set(peerId, true) + this.signalWeightSharing() + break + } default: { const _: never = msg throw new Error('should never happen') @@ -145,9 +151,41 @@ export class DecentralizedController< } return [conn, encoded] as [WebSocket, Buffer] }).forEach(([conn, encoded]) => { conn.send(encoded) }) + + // Initialize connectFinishedNodes with all peers set to false + this.#connectFinishedNodes = this.#roundPeers.map(() => false) as Map + this.#aggregationRound++ + } + + /** + * Check if all the participants of the round finished connecting + * with other peers in the round + * If so, send StartWeightSharing message to signal peers to proceed + */ + private signalWeightSharing(): void { + if (!this.#connectFinishedNodes.every((ready) => ready)) + return + this.#roundPeers.keySeq() + .map((id) => { + const startSignal = { + type: MessageTypes.StartWeightSharing, + } + debug("Signaling weight sharing to: %o", id.slice(0, 4)) + + const encoded = msgpack.encode(startSignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, encoded]) => {conn.send(encoded)}) + // empty the list of peers for the next round this.#roundPeers = Map() - this.#aggregationRound++ + this.#connectFinishedNodes = Map() } } - From dd22a7d1ba1c3f0396a26909e22bfeaafe424e6c Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Wed, 15 Apr 2026 16:06:11 +0200 Subject: [PATCH 2/5] Update gitignore --- .gitignore | 3 +++ datasets/.gitignore | 12 +++--------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 4f09db6ff..8e885b6fc 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,6 @@ dist/ .idea/ .vscode/ *.DS_Store + +# venv +.venv/ diff --git a/datasets/.gitignore b/datasets/.gitignore index 1ae84a880..e644c626a 100644 --- a/datasets/.gitignore +++ b/datasets/.gitignore @@ -2,17 +2,11 @@ /2_QAID_1.masked.reshaped.squared.224.png /9-mnist-example.png /CIFAR10/ -/cifar10-agents -/cifar10-example.png -/cifar10-labels.csv +/cifar10* /simple_face /simple_face-example.png -/titanic_test.csv -/titanic_train.csv -/titanic_train_with_nan.csv -/titanic_test_with_nan.csv -/titanic_wrong_number_columns.csv -/titanic_wrong_passengerID.csv +/titanic* +/mnist* # wikitext /wikitext/ From 7d0e6344d6183c481ce90f3430f761a3b8cb0c47 Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Wed, 15 Apr 2026 16:17:39 +0200 Subject: [PATCH 3/5] Fix lint error --- server/src/controllers/decentralized_controller.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index fa73ea768..2e9f94c28 100644 --- a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -153,7 +153,7 @@ export class DecentralizedController< }).forEach(([conn, encoded]) => { conn.send(encoded) }) // Initialize connectFinishedNodes with all peers set to false - this.#connectFinishedNodes = this.#roundPeers.map(() => false) as Map + this.#connectFinishedNodes = this.#roundPeers.map(() => false) this.#aggregationRound++ } From 947d2c6644a7ebdce19eb5ba4c50cb87c2bf38ae Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Thu, 30 Apr 2026 12:32:36 +0200 Subject: [PATCH 4/5] Add server timeout and retry handling for peer connection failures --- .../decentralized/decentralized_client.ts | 40 +++++- discojs/src/client/decentralized/messages.ts | 21 ++- discojs/src/client/messages.ts | 4 + .../controllers/decentralized_controller.ts | 125 +++++++++++++++++- .../components/training/TrainerDashboard.vue | 16 ++- 5 files changed, 196 insertions(+), 10 deletions(-) diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index 2851ce197..0e9611531 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -150,11 +150,41 @@ export class DecentralizedClient extends Client<"decentralized"> { this.saveAndEmit("connecting to peers") // First we check if we are waiting for more participants before sending our weight update await this.waitForParticipantsIfNeeded() - // Create peer-to-peer connections with all peers for the round - await this.establishPeerConnections() - // Wait StartWeightSharing message from the server before exchanging weight updates - await waitMessage(this.server, type.StartWeightSharing) - // Exchange weight updates with peers and return aggregated weights // and then send out the contributions + + while(true){ + // Create peer-to-peer connections with all peers for the round + await this.establishPeerConnections() + + // Wait for connection related messages from the server before exchanging weight updates + // (1) If the client receives a StartWeightSharing message, it proceeds to weight update exchange + // (2) If it receives a RetryPeerConnections message, it retries peer connection establishment + // (3) After multiple retires, if the connection is still unsuccessful, the server starts excluding nodes from the round + // and sends a ConnectionFail message to those nodes + // (4) Upon receiving ConnectionFail, the client disconnects from the server + const msg = await Promise.race([ + waitMessage(this.server, type.StartWeightSharing), + waitMessage(this.server, type.RetryPeerConnections), + waitMessage(this.server, type.ConnectionFail), + ]) + + if (msg.type === type.StartWeightSharing){ + break + } else if (msg.type === type.RetryPeerConnections){ + debug(`[${shortenId(this.ownId)}] retrying peer connection establishment`) + // clear the communication round peer pool + await this.#pool?.shutdown() + this.#pool = new PeerPool(this.ownId) + // clear the connections + this.#connections = Map() + this.setAggregatorNodes(Set(this.ownId)) + continue + } else if (msg.type === type.ConnectionFail){ + debug(`[${shortenId(this.ownId)}] disconnect from the server`) + await this.disconnect() + throw new Error("Client disconnected after connection failure") + } + } + // Exchange weight updates with peers and return aggregated weights return await this.exchangeWeightUpdates(weights) } diff --git a/discojs/src/client/decentralized/messages.ts b/discojs/src/client/decentralized/messages.ts index ae6b2d891..54dc234c4 100644 --- a/discojs/src/client/decentralized/messages.ts +++ b/discojs/src/client/decentralized/messages.ts @@ -44,9 +44,20 @@ export interface ConnectionsReady { type: type.ConnectionsReady } -// Server signal each peer to start weight update sharing +// Server signals each peer to start weight update sharing export interface StartWeightSharing { - type: type.StartWeightSharing; + type: type.StartWeightSharing +} + +// Server signals peers to reestablish peer connections +export interface RetryPeerConnections { + type: type.RetryPeerConnections + aggregationRound: number +} + +// Server signals a node that the connection with other peers failed +export interface ConnectionFail { + type: type.ConnectionFail } /// Phase 1 communication (between peers) @@ -67,7 +78,9 @@ export type MessageFromServer = PeersForRound | WaitingForMoreParticipants | EnoughParticipants | - StartWeightSharing + StartWeightSharing | + RetryPeerConnections | + ConnectionFail export type MessageToServer = ClientConnected | @@ -94,6 +107,8 @@ export function isMessageFromServer (o: unknown): o is MessageFromServer { case type.WaitingForMoreParticipants: case type.EnoughParticipants: case type.StartWeightSharing: + case type.RetryPeerConnections: + case type.ConnectionFail: return true } diff --git a/discojs/src/client/messages.ts b/discojs/src/client/messages.ts index b0e9dc4e7..c28e8b7f5 100644 --- a/discojs/src/client/messages.ts +++ b/discojs/src/client/messages.ts @@ -28,6 +28,10 @@ export enum type { ConnectionsReady, // Sent by the server to signal nodes proceed to weight update sharing StartWeightSharing, + // Sent by the server to signal nodes reestablish connections + RetryPeerConnections, + // Sent by the server to signal that the node's connection was not successful + ConnectionFail, // The weight update Payload, diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index 2e9f94c28..d1f9c1f5b 100644 --- a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -23,6 +23,8 @@ export class DecentralizedController< #roundPeers = Map() #connectFinishedNodes = Map() #aggregationRound = 0 + #timeout?: NodeJS.Timeout + #connectionRetry= 0 // number of connection retrial for specific aggregationRound handle (ws: WebSocket): void { const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants @@ -154,7 +156,11 @@ export class DecentralizedController< // Initialize connectFinishedNodes with all peers set to false this.#connectFinishedNodes = this.#roundPeers.map(() => false) - this.#aggregationRound++ + // Change the peer states to not ready + this.#roundPeers = this.#roundPeers.map(() => false) + + // Start timeout to check peer connections are successful + this.startTimeout() } /** @@ -163,8 +169,14 @@ export class DecentralizedController< * If so, send StartWeightSharing message to signal peers to proceed */ private signalWeightSharing(): void { + // Return if not all participants are ready if (!this.#connectFinishedNodes.every((ready) => ready)) return + + // Stop the timeout + this.clearTimeout() + + // Send round participants StartWeightSharing messages this.#roundPeers.keySeq() .map((id) => { const startSignal = { @@ -187,5 +199,116 @@ export class DecentralizedController< // empty the list of peers for the next round this.#roundPeers = Map() this.#connectFinishedNodes = Map() + this.#aggregationRound++ + } + + /** + * Set a timeout to check peer connections establishment + */ + private startTimeout(maxTime: number = 60_000): void { + this.#timeout = setTimeout(() => { + this.handleTimeout() + }, maxTime) + } + + /** + * Clear previously set timeout once all peer connections + * are established before the timeout + */ + private clearTimeout(): void { + if (this.#timeout !== undefined){ + clearTimeout(this.#timeout) + this.#timeout = undefined + } + } + + /** + * Called when a timeout occurs during peer connection + * Signals peers to discard existing connections and + * reestablish connections with the current set of peers + */ + private handleTimeout(): void { + debug(`Connection setup timeout for round ${this.#aggregationRound}, Retrying with same peers`) + // Increment the connection retry count + this.#connectionRetry += 1; + + // If the number of retries exceeds the threshold, exclude the failed peers from the roundPeers + if (this.#connectionRetry >= 3){ + const numFailedClient = this.#connectFinishedNodes.valueSeq().count((val) => val === false) + const remainingPeers = this.#roundPeers.size - numFailedClient + + // Exclude the failed peers + this.#connectFinishedNodes.forEach((connected, nodeId) => { + if (!connected){ + // If the node failed connection, exclude from #roundPeers + this.#roundPeers = this.#roundPeers.delete(nodeId) + // Signal the node that connection is failed for that node + const conn = this.connections.get(nodeId) + if (conn === undefined) { + throw new Error(`peer ${nodeId} marked as ready but not connection to it`) + } + const failSignal : messages.ConnectionFail = { + type: MessageTypes.ConnectionFail + } + const encoded = msgpack.encode(failSignal) + conn.send(encoded) + } + }) + + // If excluding failed peers would leave too few participants, + // restart the round + // TODO: We need to wait until minNbOfParticipants is satisfied + if (remainingPeers < this.task.trainingInformation.minNbOfParticipants){ + this.#roundPeers.keySeq() + .map((id) => { + const retrySignal = { + type: MessageTypes.RetryPeerConnections, + } + debug("Signaling connection retry to: %o", id.slice(0, 4)) + + const encoded = msgpack.encode(retrySignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, encoded]) => {conn.send(encoded)}) + + // empty the list of peers for the new round + // round number is not increased since this round failed + this.#roundPeers = Map() + this.#connectFinishedNodes = Map() + this.#connectionRetry = 0 + return + } + } + + // Retry peer connection with the currently remaining round peers + this.#connectFinishedNodes = this.#roundPeers.map(() => false) + this.#connectionRetry = 0 + + this.#roundPeers.keySeq() + .map((id) => { + const retrySignal = { + type: MessageTypes.RetryPeerConnections, + } + debug("Signaling connection retry to: %o", id.slice(0, 4)) + + const encoded = msgpack.encode(retrySignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, encoded]) => {conn.send(encoded)}) } } + diff --git a/webapp/src/components/training/TrainerDashboard.vue b/webapp/src/components/training/TrainerDashboard.vue index c7493060e..1ecc775d2 100644 --- a/webapp/src/components/training/TrainerDashboard.vue +++ b/webapp/src/components/training/TrainerDashboard.vue @@ -222,6 +222,9 @@ async function startTraining(): Promise { // manually interrupt the training cleanupDisco.value = async () => await disco.close() + // For the training completed message + let trainingCompleted = true + try { trainingGenerator.value = disco.train(dataset); @@ -245,6 +248,7 @@ async function startTraining(): Promise { epochsOfRoundLogs.value = List(); } } catch (e) { + trainingCompleted = false; if (e === stopper) { toaster.info("Training stopped"); return; @@ -262,6 +266,13 @@ async function startTraining(): Promise { toaster.error( "Training is not converging. Data potentially needs better preprocessing.", ); + } else if ( + e instanceof Error && + e.message.includes("Client disconnected after connection failure") + ){ + toaster.error( + "Client disconnected after multiple peer connection failure. Please rejoin the training." + ); } else { toaster.error("An error occurred during training"); } @@ -271,7 +282,10 @@ async function startTraining(): Promise { await cleanupTrainingSession() } - toaster.success("Training successfully completed"); + if (trainingCompleted){ + // printed only when the training is compeleted successfully + toaster.success("Training successfully completed"); + } } async function cleanupTrainingSession() { From b443f53babcf88f0f98c939f45210029d465225f Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Thu, 7 May 2026 17:51:55 +0200 Subject: [PATCH 5/5] Implement decentralized model synchronization for clients joining during training --- discojs/src/client/client.ts | 13 ++ .../decentralized/decentralized_client.ts | 120 ++++++++++++++++- discojs/src/client/decentralized/messages.ts | 50 ++++++- discojs/src/client/messages.ts | 8 ++ discojs/src/training/disco.ts | 19 +++ .../controllers/decentralized_controller.ts | 122 ++++++++++++------ .../components/training/TrainerDashboard.vue | 9 +- 7 files changed, 292 insertions(+), 49 deletions(-) diff --git a/discojs/src/client/client.ts b/discojs/src/client/client.ts index 9e1298b93..e0fb26372 100644 --- a/discojs/src/client/client.ts +++ b/discojs/src/client/client.ts @@ -14,6 +14,7 @@ import type { EventConnection } from './event_connection.js' import type { Aggregator } from '../aggregator/index.js' import { EventEmitter } from '../utils/event_emitter.js' import { type } from "./messages.js"; +import { ModelWeightAccess } from "../training/disco.js"; const debug = createDebug("discojs:client"); @@ -38,6 +39,9 @@ export abstract class Client extends EventEmitter<{ */ protected promiseForMoreParticipants: Promise | undefined = undefined; + // Interface to access trainer's model weights + protected modelWeightAccess?: ModelWeightAccess; + /** * When the server notifies the client that they can resume training * after waiting for more participants, we want to be able to display what @@ -56,6 +60,15 @@ export abstract class Client extends EventEmitter<{ ) { super() } + + /** + * Used for decentralized learning. + * Set the interface used by client to access to trainer's model weights. + * Disco object provides this access. + */ + setModelWeightAccess(modelWeightAccess: ModelWeightAccess){ + this.modelWeightAccess = modelWeightAccess + } /** * Communication callback called at the beginning of every training round. diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index 0e9611531..93aaa3c78 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -26,6 +26,15 @@ export class DecentralizedClient extends Client<"decentralized"> { #pool?: PeerPool #connections?: Map + // Flag if this model requires model synchronization + #modelSyncNeeded?: boolean + + // Check if the training round is in progress + // Used to get the latest model for model synchronization + #isRoundInTraining = false + #roundFinishedPromise?: Promise + #resolveRoundFinished?: () => void // contains resolver + // Used to handle timeouts and promise resolving after calling disconnect private get isDisconnected() : boolean { return this._server === undefined @@ -69,6 +78,24 @@ export class DecentralizedClient extends Client<"decentralized"> { this.#pool.signal(event.peer, event.signal) }) + // Listen if the client is selected as a model provider node for a newly joining client. + // Upon receiving the signal, this client establishes a connection with the newcomer + // and sends the latest model weights. + this.server.on(type.SignalNewPeer, async (event) => { + if (this.#pool === undefined) throw new Error('received signal about new peer but peer pool is undefined') + const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, ()=>{}) + + const newcomerConn = syncConnection.get(event.newNode) + + if (newcomerConn === undefined){ + // if connection with newly joining client fails, print debug message + // and return + debug(`Cannot connect to newly joined client [${event.newNode}]`) + return + } + await this.sendModel(newcomerConn) + }) + // c.f. setupServerCallbacks doc for explanation let receivedEnoughParticipants = false this.setupServerCallbacks(() => receivedEnoughParticipants = true) @@ -79,8 +106,9 @@ export class DecentralizedClient extends Client<"decentralized"> { this.server.send(msg) const { id, waitForMoreParticipants, - nbOfParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) - + nbOfParticipants, joinedMidTraining } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) + + this.#modelSyncNeeded = joinedMidTraining this.nbOfParticipants = nbOfParticipants @@ -129,8 +157,38 @@ export class DecentralizedClient extends Client<"decentralized"> { * When connected, one peer creates a promise for every other peer's weight update * and waits for it to resolve. * + * If a client joined the training after the first round, + * model syncing happens first to get the latest model. */ override async onRoundBeginCommunication(): Promise { + if (this.#modelSyncNeeded) { + // 1. If model sync is needed, send server a request + this.server.send({ type: type.ModelSyncRequest }) + + // 2. Get the provider information from the server + const providerInfo = await waitMessageWithTimeout(this.server, type.SignalModelProvider, 30_000, "Timeout while waiting for the latest model provider") + + if (this.#pool === undefined) { + throw new Error('peer pool is undefined, make sure to call `client.connect()` first') + } + + // 3. Connect with model provider client and get the latest model + const syncConnection = await this.#pool.getPeers( + Set([providerInfo.providerNode]), + this.server, + ()=>{} + ) + const providerConn = syncConnection.get(providerInfo.providerNode) + + if (providerConn === undefined){ + throw new Error("The latest model provider is not connected") + } + + const latestModel = await this.receiveModel(providerConn) + this.modelWeightAccess?.setModelWeight(latestModel) + this.#modelSyncNeeded = false + } + // Notify the server we want to join the next round so that the server // waits for us to be ready before sending the list of peers for the round this.server.send({ type: type.JoinRound }) @@ -149,9 +207,11 @@ export class DecentralizedClient extends Client<"decentralized"> { // Once enough new participants join we can display the previous status again this.saveAndEmit("connecting to peers") // First we check if we are waiting for more participants before sending our weight update - await this.waitForParticipantsIfNeeded() while(true){ + // Wait until enough participants are available before continuing the round + await this.waitForParticipantsIfNeeded() + // Create peer-to-peer connections with all peers for the round await this.establishPeerConnections() @@ -185,7 +245,18 @@ export class DecentralizedClient extends Client<"decentralized"> { } } // Exchange weight updates with peers and return aggregated weights - return await this.exchangeWeightUpdates(weights) + let aggregatedWeight: WeightsContainer + try{ + aggregatedWeight = await this.exchangeWeightUpdates(weights) + } finally { + // Mark the round as finished so that model synchronization can proceed + this.#isRoundInTraining = false + this.#resolveRoundFinished?.() + this.#roundFinishedPromise = undefined + this.#resolveRoundFinished = undefined + } + + return aggregatedWeight } /** @@ -211,6 +282,12 @@ export class DecentralizedClient extends Client<"decentralized"> { debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`); const receivedMessage = await waitMessage(this.server, type.PeersForRound) + this.#isRoundInTraining = true + // Generate a promise that resolves when round training finishes + this.#roundFinishedPromise = new Promise((resolve) => { + this.#resolveRoundFinished = resolve + }) + const peers = Set(receivedMessage.peers) debug(`[${shortenId(this.ownId)}] received peer list: %o`, peers.toArray()); @@ -338,4 +415,39 @@ export class DecentralizedClient extends Client<"decentralized"> { } return await this.aggregationResult } + + /** + * Receive model from the model provider. + */ + private async receiveModel(providerConn: PeerConnection): Promise{ + const message = await waitMessageWithTimeout(providerConn, type.SharedModel, 30_000, "Timeout while waiting for the latest model") + + const decoded = serialization.weights.decode(message.model) + return decoded + } + + /** + * Send the latest available model to a newly joining client. + * If the current training round is in progress, wait until the round finishes + * and receive the latest aggregated model. + */ + private async sendModel(newcomerConn: PeerConnection): Promise { + if (this.#isRoundInTraining){ + await this.#roundFinishedPromise + } + + const model = this.modelWeightAccess?.getModelWeight() + + if (model === undefined){ + debug("Failed to get the latest model from model provider client") + return + } + const encoded = await serialization.weights.encode(model) + + const message: messages.SharedModel = { + type: type.SharedModel, + model: encoded + } + newcomerConn.send(message) + } } diff --git a/discojs/src/client/decentralized/messages.ts b/discojs/src/client/decentralized/messages.ts index 54dc234c4..1c56f45f8 100644 --- a/discojs/src/client/decentralized/messages.ts +++ b/discojs/src/client/decentralized/messages.ts @@ -12,6 +12,7 @@ export interface NewDecentralizedNodeInfo { id: NodeID waitForMoreParticipants: boolean nbOfParticipants: number + joinedMidTraining: boolean } // WebRTC signal to forward to other node @@ -60,6 +61,32 @@ export interface ConnectionFail { type: type.ConnectionFail } +// Nodes joining in the middle of the training send to server +// to request the latest model before starting local training +export interface ModelSyncRequest { + type: type.ModelSyncRequest +} + +// Server signals a node that shares the lastest model with node +// who joined in the middle of the training +export interface SignalNewPeer { + type: type.SignalNewPeer + newNode: NodeID +} + +// Server signals new node joining in the middle of the training +// about the model provider node +export interface SignalModelProvider { + type: type.SignalModelProvider + providerNode: NodeID +} + +// Sent by client to another client to share the latest model +export interface SharedModel { + type: type.SharedModel + model: serialization.Encoded +} + /// Phase 1 communication (between peers) export interface Payload { @@ -80,16 +107,21 @@ export type MessageFromServer = EnoughParticipants | StartWeightSharing | RetryPeerConnections | - ConnectionFail + ConnectionFail | + SignalModelProvider | + SignalNewPeer export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady | JoinRound | - ConnectionsReady + ConnectionsReady | + ModelSyncRequest -export type PeerMessage = Payload +export type PeerMessage = + Payload | + SharedModel export function isMessageFromServer (o: unknown): o is MessageFromServer { if (!hasMessageType(o)) return false @@ -101,14 +133,17 @@ export function isMessageFromServer (o: unknown): o is MessageFromServer { typeof o.waitForMoreParticipants === 'boolean' case type.SignalForPeer: return 'peer' in o && isNodeID(o.peer) && - 'signal' in o // TODO check signal content? + 'signal' in o case type.PeersForRound: return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID) + case type.SignalNewPeer: + return 'newNode' in o && isNodeID(o.newNode) case type.WaitingForMoreParticipants: case type.EnoughParticipants: case type.StartWeightSharing: case type.RetryPeerConnections: case type.ConnectionFail: + case type.SignalModelProvider: return true } @@ -123,10 +158,11 @@ export function isMessageToServer (o: unknown): o is MessageToServer { return true case type.SignalForPeer: return 'peer' in o && isNodeID(o.peer) && - 'signal' in o // TODO check signal content? + 'signal' in o case type.JoinRound: case type.PeerIsReady: case type.ConnectionsReady: + case type.ModelSyncRequest: return true } @@ -142,6 +178,10 @@ export function isPeerMessage (o: unknown): o is PeerMessage { 'peer' in o && isNodeID(o.peer) && 'payload' in o && serialization.isEncoded(o.payload) ) + case type.SharedModel: + return ( + 'model' in o && serialization.isEncoded(o.model) + ) } return false diff --git a/discojs/src/client/messages.ts b/discojs/src/client/messages.ts index c28e8b7f5..eb6682694 100644 --- a/discojs/src/client/messages.ts +++ b/discojs/src/client/messages.ts @@ -34,6 +34,14 @@ export enum type { ConnectionFail, // The weight update Payload, + // Sent by nodes to the server to request the latest model + ModelSyncRequest, + // Sent by the server to nodes to share the provider node info + SignalModelProvider, + // Sent by the server to nodes who was selected as a model provider node + SignalNewPeer, + // Sent by node to node to share the latest model weights + SharedModel, /* Federated */ // The server answers the ClientConnected message with the necessary information diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index 7dd019b2d..311ad1d7a 100644 --- a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -16,6 +16,7 @@ import type { Network, Task, } from "../index.js"; +import { WeightsContainer } from "../index.js"; import type { Aggregator } from "../aggregator/index.js"; import { getAggregator } from "../aggregator/index.js"; import { enumerate, split } from "../utils/async_iterator.js"; @@ -70,6 +71,15 @@ function buildSummaryLog(roundNum: number, epochNum: number, roundLogs: RoundLog } } +/** + * Interface providing an access to trainer's model weights. + * Used for model synchronization to retrieve and set the latest model. + */ +export interface ModelWeightAccess{ + getModelWeight(): WeightsContainer; + setModelWeight(weight: WeightsContainer): void; +} + /** * Top-level class handling distributed training from a client's perspective. It is meant to be * a convenient object providing a reduced yet complete API that wraps model training and @@ -127,6 +137,15 @@ export class Disco extends EventEmitter<{ this.#client = client; this.#task = task; this.trainer = new Trainer(task, client); + // Set ModelWeightAccess of the client + this.#client.setModelWeightAccess({ + getModelWeight: () => { + return new WeightsContainer(this.trainer.model.weights.weights.map(t => t.clone())); + }, + setModelWeight: (weights) => { + this.trainer.model.weights = weights; + } + }); // Simply propagate the training status events emitted by the client this.#client.on("status", (status) => this.emit("status", status)); this.#client.on("participants", (nbParticipants) => this.emit("participants", nbParticipants)); diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index d1f9c1f5b..ade4c8381 100644 --- a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -2,7 +2,7 @@ import createDebug from "debug"; import { v4 as randomUUID } from 'uuid' import * as msgpack from "@msgpack/msgpack"; import type WebSocket from 'ws' -import { Map } from 'immutable' +import { Map, Set } from 'immutable' import { client, DataType } from "@epfml/discojs"; @@ -24,7 +24,12 @@ export class DecentralizedController< #connectFinishedNodes = Map() #aggregationRound = 0 #timeout?: NodeJS.Timeout - #connectionRetry= 0 // number of connection retrial for specific aggregationRound + + // number of connection retries for the training round + #connectionRetry = 0 + // Client selected to provide the latest model to peers + // joining in the middle of training + #providerNode?: client.NodeID handle (ws: WebSocket): void { const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants @@ -49,12 +54,18 @@ export class DecentralizedController< debug(`peer [%s] joined ${this.task.id}`, shortId) this.connections = this.connections.set(peerId, ws) + let joinedMidTraining = false + if (this.#aggregationRound > 0){ + joinedMidTraining = true + } + // Answer with client id in an NewNodeInfo message const msg: messages.NewDecentralizedNodeInfo = { type: MessageTypes.NewDecentralizedNodeInfo, id: peerId, nbOfParticipants: this.connections.size, - waitForMoreParticipants: this.connections.size < minNbOfParticipants + waitForMoreParticipants: this.connections.size < minNbOfParticipants, + joinedMidTraining: joinedMidTraining, } ws.send(msgpack.encode(msg), { binary: true }) // Send an update to participants if we can start/resume training @@ -88,10 +99,40 @@ export class DecentralizedController< break } case MessageTypes.ConnectionsReady: { + // Select the first client that finishes peer connections + // as the model provider for clients joined mid-training + const numconnFinishedNodes = this.#connectFinishedNodes.reduce((acc, val) => acc + (val ? 1 : 0), 0) + if (!numconnFinishedNodes){ + this.#providerNode = peerId + } + this.#connectFinishedNodes = this.#connectFinishedNodes.set(peerId, true) this.signalWeightSharing() break } + case MessageTypes.ModelSyncRequest: { + // Upon receiving a model sync request, send relevant client information + // to both the model provider and the newly joined client + if (!this.#providerNode) { + debug("There is no provider node to share the latest model") + break + } + + // Signal the newly joined client with the provider client's information + const providerInfo: messages.SignalModelProvider = { + type: MessageTypes.SignalModelProvider, + providerNode: this.#providerNode + } + this.connections.get(peerId)?.send(msgpack.encode(providerInfo)) + + // Signal the provider client with newly joined client's information + const newNodeInfo: messages.SignalNewPeer = { + type: MessageTypes.SignalNewPeer, + newNode: peerId + } + this.connections.get(this.#providerNode)?.send(msgpack.encode(newNodeInfo)) + break + } default: { const _: never = msg throw new Error('should never happen') @@ -106,13 +147,13 @@ export class DecentralizedController< // Remove the participant when the websocket is closed this.connections = this.connections.delete(peerId) this.#roundPeers = this.#roundPeers.delete(peerId) + this.#connectFinishedNodes = this.#connectFinishedNodes.delete(peerId) debug("client [%s] left", shortId) // Check if we are already waiting for new participants to join if (this.waitingForMoreParticipants) return // If no, check if we are still above the minimum number of participant required if (this.connections.size >= minNbOfParticipants) { - // Check if remaining peers are all ready to exchange weight updates this.sendPeersForRoundIfNeeded() return } @@ -228,20 +269,20 @@ export class DecentralizedController< * reestablish connections with the current set of peers */ private handleTimeout(): void { + this.clearTimeout() debug(`Connection setup timeout for round ${this.#aggregationRound}, Retrying with same peers`) // Increment the connection retry count this.#connectionRetry += 1; - // If the number of retries exceeds the threshold, exclude the failed peers from the roundPeers + // If the number of retries exceeds the threshold, exclude the failed peers from the round + // and retry peer connection only with the remaining peers if (this.#connectionRetry >= 3){ - const numFailedClient = this.#connectFinishedNodes.valueSeq().count((val) => val === false) - const remainingPeers = this.#roundPeers.size - numFailedClient - // Exclude the failed peers this.#connectFinishedNodes.forEach((connected, nodeId) => { if (!connected){ // If the node failed connection, exclude from #roundPeers this.#roundPeers = this.#roundPeers.delete(nodeId) + this.#connectFinishedNodes = this.#connectFinishedNodes.delete(nodeId) // Signal the node that connection is failed for that node const conn = this.connections.get(nodeId) if (conn === undefined) { @@ -255,41 +296,41 @@ export class DecentralizedController< } }) - // If excluding failed peers would leave too few participants, - // restart the round - // TODO: We need to wait until minNbOfParticipants is satisfied - if (remainingPeers < this.task.trainingInformation.minNbOfParticipants){ - this.#roundPeers.keySeq() - .map((id) => { - const retrySignal = { - type: MessageTypes.RetryPeerConnections, - } - debug("Signaling connection retry to: %o", id.slice(0, 4)) + // Restart the round with remaining clients + this.#roundPeers.keySeq() + .map((id) => { + const retrySignal = { + type: MessageTypes.RetryPeerConnections, + } + debug("Signaling connection retry to: %o", id.slice(0, 4)) - const encoded = msgpack.encode(retrySignal) - return [id, encoded] as [client.NodeID, Buffer] - }) - .map(([id, encoded]) => { - const conn = this.connections.get(id) - if (conn === undefined) { - throw new Error(`peer ${id} marked as ready but not connection to it`) - } - return [conn, encoded] as [WebSocket, Buffer] - }) - .forEach(([conn, encoded]) => {conn.send(encoded)}) - - // empty the list of peers for the new round - // round number is not increased since this round failed - this.#roundPeers = Map() - this.#connectFinishedNodes = Map() - this.#connectionRetry = 0 - return - } + const encoded = msgpack.encode(retrySignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, encoded]) => {conn.send(encoded)}) + + // Reset the ready and connection status of roundPeers + this.#roundPeers = this.#roundPeers.map(() => false) + this.#connectFinishedNodes = this.#roundPeers.map(() => false) + + // Reset the connectionRetry since we excluded the failed clients + this.#connectionRetry = 0 + // Restart the timeout after sending retry messages + this.startTimeout() + return } - // Retry peer connection with the currently remaining round peers + // Retry peer connection with original roundPeers + // Reset the ready and connection status of roundPeers + this.#roundPeers = this.#roundPeers.map(() => false) this.#connectFinishedNodes = this.#roundPeers.map(() => false) - this.#connectionRetry = 0 this.#roundPeers.keySeq() .map((id) => { @@ -309,6 +350,9 @@ export class DecentralizedController< return [conn, encoded] as [WebSocket, Buffer] }) .forEach(([conn, encoded]) => {conn.send(encoded)}) + + // Restart the timeout after sending retry messages + this.startTimeout() } } diff --git a/webapp/src/components/training/TrainerDashboard.vue b/webapp/src/components/training/TrainerDashboard.vue index 1ecc775d2..3680175b1 100644 --- a/webapp/src/components/training/TrainerDashboard.vue +++ b/webapp/src/components/training/TrainerDashboard.vue @@ -273,8 +273,15 @@ async function startTraining(): Promise { toaster.error( "Client disconnected after multiple peer connection failure. Please rejoin the training." ); + } else if ( + e instanceof Error && + e.message.includes("Timeout while waiting for the latest model") + ){ + toaster.error( + "Timeout while waiting for the model syncing. Please rejoin the training." + ); } else { - toaster.error("An error occurred during training"); + toaster.error("An error occurred during training.") } debug("while training: %o", e); } finally {