diff --git a/discojs/src/client/client.ts b/discojs/src/client/client.ts index 9e1298b93..e0fb26372 100644 --- a/discojs/src/client/client.ts +++ b/discojs/src/client/client.ts @@ -14,6 +14,7 @@ import type { EventConnection } from './event_connection.js' import type { Aggregator } from '../aggregator/index.js' import { EventEmitter } from '../utils/event_emitter.js' import { type } from "./messages.js"; +import type { ModelWeightAccess } from "../training/disco.js"; const debug = createDebug("discojs:client"); @@ -38,6 +39,9 @@ export abstract class Client extends EventEmitter<{ */ protected promiseForMoreParticipants: Promise | undefined = undefined; + // Accessor for the trainer's model weights; stays undefined until setModelWeightAccess is called + protected modelWeightAccess?: ModelWeightAccess; + /** * When the server notifies the client that they can resume training * after waiting for more participants, we want to be able to display what @@ -56,6 +60,15 @@ export abstract class Client extends EventEmitter<{ ) { super() } + + /** + * Used for decentralized learning. + * Set the interface used by the client to access the trainer's model weights. + * The Disco object provides this access. + */ + setModelWeightAccess(modelWeightAccess: ModelWeightAccess): void { + this.modelWeightAccess = modelWeightAccess + } /** * Communication callback called at the beginning of every training round. 
diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index 6f9da6e77..93aaa3c78 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -26,6 +26,15 @@ export class DecentralizedClient extends Client<"decentralized"> { #pool?: PeerPool #connections?: Map + // Flag if this model requires model synchronization + #modelSyncNeeded?: boolean + + // Check if the training round is in progress + // Used to get the latest model for model synchronization + #isRoundInTraining = false + #roundFinishedPromise?: Promise + #resolveRoundFinished?: () => void // contains resolver + // Used to handle timeouts and promise resolving after calling disconnect private get isDisconnected() : boolean { return this._server === undefined @@ -69,6 +78,24 @@ export class DecentralizedClient extends Client<"decentralized"> { this.#pool.signal(event.peer, event.signal) }) + // Listen if the client is selected as a model provider node for a newly joining client. + // Upon receiving the signal, this client establishes a connection with the newcomer + // and sends the latest model weights. + this.server.on(type.SignalNewPeer, async (event) => { + if (this.#pool === undefined) throw new Error('received signal about new peer but peer pool is undefined') + const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, ()=>{}) + + const newcomerConn = syncConnection.get(event.newNode) + + if (newcomerConn === undefined){ + // if connection with newly joining client fails, print debug message + // and return + debug(`Cannot connect to newly joined client [${event.newNode}]`) + return + } + await this.sendModel(newcomerConn) + }) + // c.f. 
setupServerCallbacks doc for explanation let receivedEnoughParticipants = false this.setupServerCallbacks(() => receivedEnoughParticipants = true) @@ -79,8 +106,9 @@ export class DecentralizedClient extends Client<"decentralized"> { this.server.send(msg) const { id, waitForMoreParticipants, - nbOfParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) - + nbOfParticipants, joinedMidTraining } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) + + this.#modelSyncNeeded = joinedMidTraining this.nbOfParticipants = nbOfParticipants @@ -129,8 +157,38 @@ export class DecentralizedClient extends Client<"decentralized"> { * When connected, one peer creates a promise for every other peer's weight update * and waits for it to resolve. * + * If a client joined the training after the first round, + * model syncing happens first to get the latest model. */ override async onRoundBeginCommunication(): Promise { + if (this.#modelSyncNeeded) { + // 1. If model sync is needed, send server a request + this.server.send({ type: type.ModelSyncRequest }) + + // 2. Get the provider information from the server + const providerInfo = await waitMessageWithTimeout(this.server, type.SignalModelProvider, 30_000, "Timeout while waiting for the latest model provider") + + if (this.#pool === undefined) { + throw new Error('peer pool is undefined, make sure to call `client.connect()` first') + } + + // 3. 
Connect with model provider client and get the latest model + const syncConnection = await this.#pool.getPeers( + Set([providerInfo.providerNode]), + this.server, + ()=>{} + ) + const providerConn = syncConnection.get(providerInfo.providerNode) + + if (providerConn === undefined){ + throw new Error("The latest model provider is not connected") + } + + const latestModel = await this.receiveModel(providerConn) + this.modelWeightAccess?.setModelWeight(latestModel) + this.#modelSyncNeeded = false + } + // Notify the server we want to join the next round so that the server // waits for us to be ready before sending the list of peers for the round this.server.send({ type: type.JoinRound }) @@ -149,11 +207,56 @@ export class DecentralizedClient extends Client<"decentralized"> { // Once enough new participants join we can display the previous status again this.saveAndEmit("connecting to peers") // First we check if we are waiting for more participants before sending our weight update - await this.waitForParticipantsIfNeeded() - // Create peer-to-peer connections with all peers for the round - await this.establishPeerConnections() + + while(true){ + // Wait until enough participants are available before continuing the round + await this.waitForParticipantsIfNeeded() + + // Create peer-to-peer connections with all peers for the round + await this.establishPeerConnections() + + // Wait for connection related messages from the server before exchanging weight updates + // (1) If the client receives a StartWeightSharing message, it proceeds to weight update exchange + // (2) If it receives a RetryPeerConnections message, it retries peer connection establishment + // (3) After multiple retires, if the connection is still unsuccessful, the server starts excluding nodes from the round + // and sends a ConnectionFail message to those nodes + // (4) Upon receiving ConnectionFail, the client disconnects from the server + const msg = await Promise.race([ + waitMessage(this.server, 
type.StartWeightSharing), + waitMessage(this.server, type.RetryPeerConnections), + waitMessage(this.server, type.ConnectionFail), + ]) + + if (msg.type === type.StartWeightSharing){ + break + } else if (msg.type === type.RetryPeerConnections){ + debug(`[${shortenId(this.ownId)}] retrying peer connection establishment`) + // clear the communication round peer pool + await this.#pool?.shutdown() + this.#pool = new PeerPool(this.ownId) + // clear the connections + this.#connections = Map() + this.setAggregatorNodes(Set(this.ownId)) + continue + } else if (msg.type === type.ConnectionFail){ + debug(`[${shortenId(this.ownId)}] disconnect from the server`) + await this.disconnect() + throw new Error("Client disconnected after connection failure") + } + } // Exchange weight updates with peers and return aggregated weights - return await this.exchangeWeightUpdates(weights) + let aggregatedWeight: WeightsContainer + try{ + aggregatedWeight = await this.exchangeWeightUpdates(weights) + } finally { + // Mark the round as finished so that model synchronization can proceed + this.#isRoundInTraining = false + this.#resolveRoundFinished?.() + this.#roundFinishedPromise = undefined + this.#resolveRoundFinished = undefined + } + + return aggregatedWeight } /** @@ -178,8 +281,15 @@ export class DecentralizedClient extends Client<"decentralized"> { try { debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`); const receivedMessage = await waitMessage(this.server, type.PeersForRound) - + + this.#isRoundInTraining = true + // Generate a promise that resolves when round training finishes + this.#roundFinishedPromise = new Promise((resolve) => { + this.#resolveRoundFinished = resolve + }) + const peers = Set(receivedMessage.peers) + debug(`[${shortenId(this.ownId)}] received peer list: %o`, peers.toArray()); if (this.ownId !== undefined && peers.has(this.ownId)) { throw new Error('received peer list contains our own id') @@ -198,7 +308,9 @@ export 
class DecentralizedClient extends Client<"decentralized"> { (conn) => this.receivePayloads(conn) ) - debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS()); + // Signal server that all connections with other peers in the round are established + this.server.send({ type: type.ConnectionsReady }); + debug(`[${shortenId(this.ownId)}] peer connections ready: %o`, connections.keySeq().toJS()); this.#connections = connections } catch (e) { debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e); @@ -303,4 +415,39 @@ export class DecentralizedClient extends Client<"decentralized"> { } return await this.aggregationResult } + + /** + * Receive model from the model provider. + */ + private async receiveModel(providerConn: PeerConnection): Promise{ + const message = await waitMessageWithTimeout(providerConn, type.SharedModel, 30_000, "Timeout while waiting for the latest model") + + const decoded = serialization.weights.decode(message.model) + return decoded + } + + /** + * Send the latest available model to a newly joining client. + * If the current training round is in progress, wait until the round finishes + * and receive the latest aggregated model. 
+ */ + private async sendModel(newcomerConn: PeerConnection): Promise { + if (this.#isRoundInTraining){ + await this.#roundFinishedPromise + } + + const model = this.modelWeightAccess?.getModelWeight() + + if (model === undefined){ + debug("Failed to get the latest model from model provider client") + return + } + const encoded = await serialization.weights.encode(model) + + const message: messages.SharedModel = { + type: type.SharedModel, + model: encoded + } + newcomerConn.send(message) + } } diff --git a/discojs/src/client/decentralized/messages.ts b/discojs/src/client/decentralized/messages.ts index 626062ad4..1c56f45f8 100644 --- a/discojs/src/client/decentralized/messages.ts +++ b/discojs/src/client/decentralized/messages.ts @@ -12,6 +12,7 @@ export interface NewDecentralizedNodeInfo { id: NodeID waitForMoreParticipants: boolean nbOfParticipants: number + joinedMidTraining: boolean } // WebRTC signal to forward to other node @@ -38,6 +39,54 @@ export interface PeersForRound { aggregationRound: number } +// peer sends to server to signal all the connections to other peers +// are established +export interface ConnectionsReady { + type: type.ConnectionsReady +} + +// Server signals each peer to start weight update sharing +export interface StartWeightSharing { + type: type.StartWeightSharing +} + +// Server signals peers to reestablish peer connections +export interface RetryPeerConnections { + type: type.RetryPeerConnections + aggregationRound: number +} + +// Server signals a node that the connection with other peers failed +export interface ConnectionFail { + type: type.ConnectionFail +} + +// Nodes joining in the middle of the training send to server +// to request the latest model before starting local training +export interface ModelSyncRequest { + type: type.ModelSyncRequest +} + +// Server signals a node that shares the lastest model with node +// who joined in the middle of the training +export interface SignalNewPeer { + type: type.SignalNewPeer + 
newNode: NodeID +} + +// Server signals new node joining in the middle of the training +// about the model provider node +export interface SignalModelProvider { + type: type.SignalModelProvider + providerNode: NodeID +} + +// Sent by client to another client to share the latest model +export interface SharedModel { + type: type.SharedModel + model: serialization.Encoded +} + /// Phase 1 communication (between peers) export interface Payload { @@ -55,15 +104,24 @@ export type MessageFromServer = SignalForPeer | PeersForRound | WaitingForMoreParticipants | - EnoughParticipants + EnoughParticipants | + StartWeightSharing | + RetryPeerConnections | + ConnectionFail | + SignalModelProvider | + SignalNewPeer export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady | - JoinRound + JoinRound | + ConnectionsReady | + ModelSyncRequest -export type PeerMessage = Payload +export type PeerMessage = + Payload | + SharedModel export function isMessageFromServer (o: unknown): o is MessageFromServer { if (!hasMessageType(o)) return false @@ -75,11 +133,17 @@ export function isMessageFromServer (o: unknown): o is MessageFromServer { typeof o.waitForMoreParticipants === 'boolean' case type.SignalForPeer: return 'peer' in o && isNodeID(o.peer) && - 'signal' in o // TODO check signal content? + 'signal' in o case type.PeersForRound: return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID) + case type.SignalNewPeer: + return 'newNode' in o && isNodeID(o.newNode) + case type.SignalModelProvider: /* the receiving client opens a peer connection to this id, so it must be a valid NodeID */ return 'providerNode' in o && isNodeID(o.providerNode) case type.WaitingForMoreParticipants: case type.EnoughParticipants: + case type.StartWeightSharing: + case type.RetryPeerConnections: + case type.ConnectionFail: return true } @@ -94,9 +158,11 @@ export function isMessageToServer (o: unknown): o is MessageToServer { return true case type.SignalForPeer: return 'peer' in o && isNodeID(o.peer) && - 'signal' in o // TODO check signal content? 
+ 'signal' in o case type.JoinRound: case type.PeerIsReady: + case type.ConnectionsReady: + case type.ModelSyncRequest: return true } @@ -112,6 +178,10 @@ export function isPeerMessage (o: unknown): o is PeerMessage { 'peer' in o && isNodeID(o.peer) && 'payload' in o && serialization.isEncoded(o.payload) ) + case type.SharedModel: + return ( + 'model' in o && serialization.isEncoded(o.model) + ) } return false diff --git a/discojs/src/client/messages.ts b/discojs/src/client/messages.ts index f5b5f9bb4..eb6682694 100644 --- a/discojs/src/client/messages.ts +++ b/discojs/src/client/messages.ts @@ -24,8 +24,24 @@ export enum type { // Message forwarded by the server from a client to another client // to establish a peer-to-peer (WebRTC) connection SignalForPeer, + // Message sent by nodes to server to signal all connections are established + ConnectionsReady, + // Sent by the server to signal nodes proceed to weight update sharing + StartWeightSharing, + // Sent by the server to signal nodes reestablish connections + RetryPeerConnections, + // Sent by the server to signal that the node's connection was not successful + ConnectionFail, // The weight update Payload, + // Sent by nodes to the server to request the latest model + ModelSyncRequest, + // Sent by the server to nodes to share the provider node info + SignalModelProvider, + // Sent by the server to nodes who was selected as a model provider node + SignalNewPeer, + // Sent by node to node to share the latest model weights + SharedModel, /* Federated */ // The server answers the ClientConnected message with the necessary information diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index 7dd019b2d..311ad1d7a 100644 --- a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -16,6 +16,7 @@ import type { Network, Task, } from "../index.js"; +import { WeightsContainer } from "../index.js"; import type { Aggregator } from "../aggregator/index.js"; import { getAggregator } from 
"../aggregator/index.js"; import { enumerate, split } from "../utils/async_iterator.js"; @@ -70,6 +71,15 @@ function buildSummaryLog(roundNum: number, epochNum: number, roundLogs: RoundLog } } +/** + * Interface providing an access to trainer's model weights. + * Used for model synchronization to retrieve and set the latest model. + */ +export interface ModelWeightAccess{ + getModelWeight(): WeightsContainer; + setModelWeight(weight: WeightsContainer): void; +} + /** * Top-level class handling distributed training from a client's perspective. It is meant to be * a convenient object providing a reduced yet complete API that wraps model training and @@ -127,6 +137,15 @@ export class Disco extends EventEmitter<{ this.#client = client; this.#task = task; this.trainer = new Trainer(task, client); + // Set ModelWeightAccess of the client + this.#client.setModelWeightAccess({ + getModelWeight: () => { + return new WeightsContainer(this.trainer.model.weights.weights.map(t => t.clone())); + }, + setModelWeight: (weights) => { + this.trainer.model.weights = weights; + } + }); // Simply propagate the training status events emitted by the client this.#client.on("status", (status) => this.emit("status", status)); this.#client.on("participants", (nbParticipants) => this.emit("participants", nbParticipants)); diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index 37ea428f4..ade4c8381 100644 --- a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -2,7 +2,7 @@ import createDebug from "debug"; import { v4 as randomUUID } from 'uuid' import * as msgpack from "@msgpack/msgpack"; import type WebSocket from 'ws' -import { Map } from 'immutable' +import { Map, Set } from 'immutable' import { client, DataType } from "@epfml/discojs"; @@ -21,7 +21,15 @@ export class DecentralizedController< // the node has already sent a PeerIsReady message) // We wait for all 
peers to be ready to exchange weight updates #roundPeers = Map() + #connectFinishedNodes = Map() #aggregationRound = 0 + #timeout?: NodeJS.Timeout + + // number of connection retries for the training round + #connectionRetry = 0 + // Client selected to provide the latest model to peers + // joining in the middle of training + #providerNode?: client.NodeID handle (ws: WebSocket): void { const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants @@ -46,12 +54,18 @@ export class DecentralizedController< debug(`peer [%s] joined ${this.task.id}`, shortId) this.connections = this.connections.set(peerId, ws) + let joinedMidTraining = false + if (this.#aggregationRound > 0){ + joinedMidTraining = true + } + // Answer with client id in an NewNodeInfo message const msg: messages.NewDecentralizedNodeInfo = { type: MessageTypes.NewDecentralizedNodeInfo, id: peerId, nbOfParticipants: this.connections.size, - waitForMoreParticipants: this.connections.size < minNbOfParticipants + waitForMoreParticipants: this.connections.size < minNbOfParticipants, + joinedMidTraining: joinedMidTraining, } ws.send(msgpack.encode(msg), { binary: true }) // Send an update to participants if we can start/resume training @@ -84,6 +98,41 @@ export class DecentralizedController< this.connections.get(msg.peer)?.send(msgpack.encode(forward)) break } + case MessageTypes.ConnectionsReady: { + // Select the first client that finishes peer connections + // as the model provider for clients joined mid-training + const numconnFinishedNodes = this.#connectFinishedNodes.reduce((acc, val) => acc + (val ? 
1 : 0), 0) + if (!numconnFinishedNodes){ + this.#providerNode = peerId + } + + this.#connectFinishedNodes = this.#connectFinishedNodes.set(peerId, true) + this.signalWeightSharing() + break + } + case MessageTypes.ModelSyncRequest: { + // Upon receiving a model sync request, send relevant client information + // to both the model provider and the newly joined client + if (!this.#providerNode) { + debug("There is no provider node to share the latest model") + break + } + + // Signal the newly joined client with the provider client's information + const providerInfo: messages.SignalModelProvider = { + type: MessageTypes.SignalModelProvider, + providerNode: this.#providerNode + } + this.connections.get(peerId)?.send(msgpack.encode(providerInfo)) + + // Signal the provider client with newly joined client's information + const newNodeInfo: messages.SignalNewPeer = { + type: MessageTypes.SignalNewPeer, + newNode: peerId + } + this.connections.get(this.#providerNode)?.send(msgpack.encode(newNodeInfo)) + break + } default: { const _: never = msg throw new Error('should never happen') @@ -98,13 +147,13 @@ export class DecentralizedController< // Remove the participant when the websocket is closed this.connections = this.connections.delete(peerId) this.#roundPeers = this.#roundPeers.delete(peerId) + this.#connectFinishedNodes = this.#connectFinishedNodes.delete(peerId) debug("client [%s] left", shortId) // Check if we are already waiting for new participants to join if (this.waitingForMoreParticipants) return // If no, check if we are still above the minimum number of participant required if (this.connections.size >= minNbOfParticipants) { - // Check if remaining peers are all ready to exchange weight updates this.sendPeersForRoundIfNeeded() return } @@ -145,9 +194,165 @@ export class DecentralizedController< } return [conn, encoded] as [WebSocket, Buffer] }).forEach(([conn, encoded]) => { conn.send(encoded) }) + + // Initialize connectFinishedNodes with all peers set to 
false + this.#connectFinishedNodes = this.#roundPeers.map(() => false) + // Change the peer states to not ready + this.#roundPeers = this.#roundPeers.map(() => false) + + // Start timeout to check peer connections are successful + this.startTimeout() + } + + /** + * Check if all the participants of the round finished connecting + * with other peers in the round + * If so, send StartWeightSharing message to signal peers to proceed + */ + private signalWeightSharing(): void { + // Return if not all participants are ready + if (!this.#connectFinishedNodes.every((ready) => ready)) + return + + // Stop the timeout + this.clearTimeout() + + // Send round participants StartWeightSharing messages + this.#roundPeers.keySeq() + .map((id) => { + const startSignal = { + type: MessageTypes.StartWeightSharing, + } + debug("Signaling weight sharing to: %o", id.slice(0, 4)) + + const encoded = msgpack.encode(startSignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, encoded]) => {conn.send(encoded)}) + // empty the list of peers for the next round this.#roundPeers = Map() + this.#connectFinishedNodes = Map() this.#aggregationRound++ } + + /** + * Set a timeout to check peer connections establishment + */ + private startTimeout(maxTime: number = 60_000): void { + this.#timeout = setTimeout(() => { + this.handleTimeout() + }, maxTime) + } + + /** + * Clear previously set timeout once all peer connections + * are established before the timeout + */ + private clearTimeout(): void { + if (this.#timeout !== undefined){ + clearTimeout(this.#timeout) + this.#timeout = undefined + } + } + + /** + * Called when a timeout occurs during peer connection + * Signals peers to discard existing connections and + * reestablish connections with the 
current set of peers + */ + private handleTimeout(): void { + this.clearTimeout() + debug(`Connection setup timeout for round ${this.#aggregationRound}, Retrying with same peers`) + // Increment the connection retry count + this.#connectionRetry += 1; + + // If the number of retries exceeds the threshold, exclude the failed peers from the round + // and retry peer connection only with the remaining peers + if (this.#connectionRetry >= 3){ + // Exclude the failed peers + this.#connectFinishedNodes.forEach((connected, nodeId) => { + if (!connected){ + // If the node failed connection, exclude from #roundPeers + this.#roundPeers = this.#roundPeers.delete(nodeId) + this.#connectFinishedNodes = this.#connectFinishedNodes.delete(nodeId) + // Signal the node that connection is failed for that node + const conn = this.connections.get(nodeId) + if (conn === undefined) { + throw new Error(`peer ${nodeId} marked as ready but not connection to it`) + } + const failSignal : messages.ConnectionFail = { + type: MessageTypes.ConnectionFail + } + const encoded = msgpack.encode(failSignal) + conn.send(encoded) + } + }) + + // Restart the round with remaining clients + this.#roundPeers.keySeq() + .map((id) => { + const retrySignal = { + type: MessageTypes.RetryPeerConnections, + } + debug("Signaling connection retry to: %o", id.slice(0, 4)) + + const encoded = msgpack.encode(retrySignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, encoded]) => {conn.send(encoded)}) + + // Reset the ready and connection status of roundPeers + this.#roundPeers = this.#roundPeers.map(() => false) + this.#connectFinishedNodes = this.#roundPeers.map(() => false) + + // Reset the connectionRetry since we excluded the failed clients + this.#connectionRetry 
= 0 + // Restart the timeout after sending retry messages + this.startTimeout() + return + } + + // Retry peer connection with original roundPeers + // Reset the ready and connection status of roundPeers + this.#roundPeers = this.#roundPeers.map(() => false) + this.#connectFinishedNodes = this.#roundPeers.map(() => false) + + this.#roundPeers.keySeq() + .map((id) => { + const retrySignal = { + type: MessageTypes.RetryPeerConnections, + } + debug("Signaling connection retry to: %o", id.slice(0, 4)) + + const encoded = msgpack.encode(retrySignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, encoded]) => {conn.send(encoded)}) + + // Restart the timeout after sending retry messages + this.startTimeout() + } } diff --git a/webapp/src/components/training/TrainerDashboard.vue b/webapp/src/components/training/TrainerDashboard.vue index c7493060e..3680175b1 100644 --- a/webapp/src/components/training/TrainerDashboard.vue +++ b/webapp/src/components/training/TrainerDashboard.vue @@ -222,6 +222,9 @@ async function startTraining(): Promise { // manually interrupt the training cleanupDisco.value = async () => await disco.close() + // For the training completed message + let trainingCompleted = true + try { trainingGenerator.value = disco.train(dataset); @@ -245,6 +248,7 @@ async function startTraining(): Promise { epochsOfRoundLogs.value = List(); } } catch (e) { + trainingCompleted = false; if (e === stopper) { toaster.info("Training stopped"); return; @@ -262,8 +266,22 @@ async function startTraining(): Promise { toaster.error( "Training is not converging. 
Data potentially needs better preprocessing.", ); + } else if ( + e instanceof Error && + e.message.includes("Client disconnected after connection failure") + ){ + toaster.error( + "Client disconnected after multiple peer connection failure. Please rejoin the training." + ); + } else if ( + e instanceof Error && + e.message.includes("Timeout while waiting for the latest model") + ){ + toaster.error( + "Timeout while waiting for the model syncing. Please rejoin the training." + ); } else { - toaster.error("An error occurred during training"); + toaster.error("An error occurred during training.") } debug("while training: %o", e); } finally { @@ -271,7 +289,10 @@ async function startTraining(): Promise { await cleanupTrainingSession() } - toaster.success("Training successfully completed"); + if (trainingCompleted){ + // printed only when the training is completed successfully + toaster.success("Training successfully completed"); + } } async function cleanupTrainingSession() {