Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions discojs/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import type { EventConnection } from './event_connection.js'
import type { Aggregator } from '../aggregator/index.js'
import { EventEmitter } from '../utils/event_emitter.js'
import { type } from "./messages.js";
import { ModelWeightAccess } from "../training/disco.js";

const debug = createDebug("discojs:client");

Expand All @@ -38,6 +39,9 @@ export abstract class Client<N extends Network> extends EventEmitter<{
*/
protected promiseForMoreParticipants: Promise<void> | undefined = undefined;

// Interface to access trainer's model weights
protected modelWeightAccess?: ModelWeightAccess;

/**
* When the server notifies the client that they can resume training
* after waiting for more participants, we want to be able to display what
Expand All @@ -56,6 +60,15 @@ export abstract class Client<N extends Network> extends EventEmitter<{
) {
super()
}

/**
* Used for decentralized learning.
* Set the interface used by client to access to trainer's model weights.
* Disco object provides this access.
*/
setModelWeightAccess(modelWeightAccess: ModelWeightAccess){
this.modelWeightAccess = modelWeightAccess
}

/**
* Communication callback called at the beginning of every training round.
Expand Down
163 changes: 155 additions & 8 deletions discojs/src/client/decentralized/decentralized_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@
#pool?: PeerPool
#connections?: Map<NodeID, PeerConnection>

// Flag if this model requires model synchronization
#modelSyncNeeded?: boolean

// Check if the training round is in progress
// Used to get the latest model for model synchronization
#isRoundInTraining = false
#roundFinishedPromise?: Promise<void>
#resolveRoundFinished?: () => void // contains resolver

// Used to handle timeouts and promise resolving after calling disconnect
private get isDisconnected() : boolean {
return this._server === undefined
Expand Down Expand Up @@ -69,6 +78,24 @@
this.#pool.signal(event.peer, event.signal)
})

// Listen if the client is selected as a model provider node for a newly joining client.
// Upon receiving the signal, this client establishes a connection with the newcomer
// and sends the latest model weights.
this.server.on(type.SignalNewPeer, async (event) => {

Check failure on line 84 in discojs/src/client/decentralized/decentralized_client.ts

View workflow job for this annotation

GitHub Actions / lint-most

Promise returned in function argument where a void return was expected
if (this.#pool === undefined) throw new Error('received signal about new peer but peer pool is undefined')
const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, ()=>{})

const newcomerConn = syncConnection.get(event.newNode)

if (newcomerConn === undefined){
// if connection with newly joining client fails, print debug message
// and return
debug(`Cannot connect to newly joined client [${event.newNode}]`)
return
}
await this.sendModel(newcomerConn)
})

// c.f. setupServerCallbacks doc for explanation
let receivedEnoughParticipants = false
this.setupServerCallbacks(() => receivedEnoughParticipants = true)
Expand All @@ -79,8 +106,9 @@
this.server.send(msg)

const { id, waitForMoreParticipants,
nbOfParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo)

nbOfParticipants, joinedMidTraining } = await waitMessage(this.server, type.NewDecentralizedNodeInfo)

this.#modelSyncNeeded = joinedMidTraining
this.nbOfParticipants = nbOfParticipants


Expand Down Expand Up @@ -129,8 +157,38 @@
* When connected, one peer creates a promise for every other peer's weight update
* and waits for it to resolve.
*
* If a client joined the training after the first round,
* model syncing happens first to get the latest model.
*/
override async onRoundBeginCommunication(): Promise<void> {
if (this.#modelSyncNeeded) {
// 1. If model sync is needed, send server a request
this.server.send({ type: type.ModelSyncRequest })

// 2. Get the provider information from the server
const providerInfo = await waitMessageWithTimeout(this.server, type.SignalModelProvider, 30_000, "Timeout while waiting for the latest model provider")

if (this.#pool === undefined) {
throw new Error('peer pool is undefined, make sure to call `client.connect()` first')
}

// 3. Connect with model provider client and get the latest model
const syncConnection = await this.#pool.getPeers(
Set([providerInfo.providerNode]),
this.server,
()=>{}
)
const providerConn = syncConnection.get(providerInfo.providerNode)

if (providerConn === undefined){
throw new Error("The latest model provider is not connected")
}

const latestModel = await this.receiveModel(providerConn)
this.modelWeightAccess?.setModelWeight(latestModel)
this.#modelSyncNeeded = false
}

// Notify the server we want to join the next round so that the server
// waits for us to be ready before sending the list of peers for the round
this.server.send({ type: type.JoinRound })
Expand All @@ -149,11 +207,56 @@
// Once enough new participants join we can display the previous status again
this.saveAndEmit("connecting to peers")
// First we check if we are waiting for more participants before sending our weight update
await this.waitForParticipantsIfNeeded()
// Create peer-to-peer connections with all peers for the round
await this.establishPeerConnections()

while(true){
// Wait until enough participants are available before continuing the round
await this.waitForParticipantsIfNeeded()

// Create peer-to-peer connections with all peers for the round
await this.establishPeerConnections()

// Wait for connection related messages from the server before exchanging weight updates
// (1) If the client receives a StartWeightSharing message, it proceeds to weight update exchange
// (2) If it receives a RetryPeerConnections message, it retries peer connection establishment
// (3) After multiple retires, if the connection is still unsuccessful, the server starts excluding nodes from the round
// and sends a ConnectionFail message to those nodes
// (4) Upon receiving ConnectionFail, the client disconnects from the server
const msg = await Promise.race([
waitMessage(this.server, type.StartWeightSharing),
waitMessage(this.server, type.RetryPeerConnections),
waitMessage(this.server, type.ConnectionFail),
])

if (msg.type === type.StartWeightSharing){
break
} else if (msg.type === type.RetryPeerConnections){
debug(`[${shortenId(this.ownId)}] retrying peer connection establishment`)
// clear the communication round peer pool
await this.#pool?.shutdown()
this.#pool = new PeerPool(this.ownId)
// clear the connections
this.#connections = Map()
this.setAggregatorNodes(Set(this.ownId))
continue
} else if (msg.type === type.ConnectionFail){
debug(`[${shortenId(this.ownId)}] disconnect from the server`)
await this.disconnect()
throw new Error("Client disconnected after connection failure")
}
}
// Exchange weight updates with peers and return aggregated weights
return await this.exchangeWeightUpdates(weights)
let aggregatedWeight: WeightsContainer
try{
aggregatedWeight = await this.exchangeWeightUpdates(weights)
} finally {
// Mark the round as finished so that model synchronization can proceed
this.#isRoundInTraining = false
this.#resolveRoundFinished?.()
this.#roundFinishedPromise = undefined
this.#resolveRoundFinished = undefined
}

return aggregatedWeight
}

/**
Expand All @@ -178,8 +281,15 @@
try {
debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`);
const receivedMessage = await waitMessage(this.server, type.PeersForRound)


this.#isRoundInTraining = true
// Generate a promise that resolves when round training finishes
this.#roundFinishedPromise = new Promise<void>((resolve) => {
this.#resolveRoundFinished = resolve
})

const peers = Set(receivedMessage.peers)
debug(`[${shortenId(this.ownId)}] received peer list: %o`, peers.toArray());

if (this.ownId !== undefined && peers.has(this.ownId)) {
throw new Error('received peer list contains our own id')
Expand All @@ -198,7 +308,9 @@
(conn) => this.receivePayloads(conn)
)

debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS());
// Signal server that all connections with other peers in the round are established
this.server.send({ type: type.ConnectionsReady });
debug(`[${shortenId(this.ownId)}] peer connections ready: %o`, connections.keySeq().toJS());
this.#connections = connections
} catch (e) {
debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e);
Expand Down Expand Up @@ -303,4 +415,39 @@
}
return await this.aggregationResult
}

/**
* Receive model from the model provider.
*/
private async receiveModel(providerConn: PeerConnection): Promise<WeightsContainer>{
const message = await waitMessageWithTimeout(providerConn, type.SharedModel, 30_000, "Timeout while waiting for the latest model")

const decoded = serialization.weights.decode(message.model)
return decoded
}

/**
* Send the latest available model to a newly joining client.
* If the current training round is in progress, wait until the round finishes
* and receive the latest aggregated model.
*/
private async sendModel(newcomerConn: PeerConnection): Promise<void> {
if (this.#isRoundInTraining){
await this.#roundFinishedPromise
}

const model = this.modelWeightAccess?.getModelWeight()

if (model === undefined){
debug("Failed to get the latest model from model provider client")
return
}
const encoded = await serialization.weights.encode(model)

const message: messages.SharedModel = {
type: type.SharedModel,
model: encoded
}
newcomerConn.send(message)
}
}
80 changes: 75 additions & 5 deletions discojs/src/client/decentralized/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export interface NewDecentralizedNodeInfo {
id: NodeID
waitForMoreParticipants: boolean
nbOfParticipants: number
joinedMidTraining: boolean
}

// WebRTC signal to forward to other node
Expand All @@ -38,6 +39,54 @@ export interface PeersForRound {
aggregationRound: number
}

// peer sends to server to signal all the connections to other peers
// are established
export interface ConnectionsReady {
type: type.ConnectionsReady
}

// Server signals each peer to start weight update sharing
export interface StartWeightSharing {
type: type.StartWeightSharing
}

// Server signals peers to reestablish peer connections
export interface RetryPeerConnections {
type: type.RetryPeerConnections
aggregationRound: number
}

// Server signals a node that the connection with other peers failed
export interface ConnectionFail {
type: type.ConnectionFail
}

// Nodes joining in the middle of the training send to server
// to request the latest model before starting local training
export interface ModelSyncRequest {
type: type.ModelSyncRequest
}

// Server signals a node that shares the lastest model with node
// who joined in the middle of the training
export interface SignalNewPeer {
type: type.SignalNewPeer
newNode: NodeID
}

// Server signals new node joining in the middle of the training
// about the model provider node
export interface SignalModelProvider {
type: type.SignalModelProvider
providerNode: NodeID
}

// Sent by client to another client to share the latest model
export interface SharedModel {
type: type.SharedModel
model: serialization.Encoded
}

/// Phase 1 communication (between peers)

export interface Payload {
Expand All @@ -55,15 +104,24 @@ export type MessageFromServer =
SignalForPeer |
PeersForRound |
WaitingForMoreParticipants |
EnoughParticipants
EnoughParticipants |
StartWeightSharing |
RetryPeerConnections |
ConnectionFail |
SignalModelProvider |
SignalNewPeer

export type MessageToServer =
ClientConnected |
SignalForPeer |
PeerIsReady |
JoinRound
JoinRound |
ConnectionsReady |
ModelSyncRequest

export type PeerMessage = Payload
export type PeerMessage =
Payload |
SharedModel

export function isMessageFromServer (o: unknown): o is MessageFromServer {
if (!hasMessageType(o)) return false
Expand All @@ -75,11 +133,17 @@ export function isMessageFromServer (o: unknown): o is MessageFromServer {
typeof o.waitForMoreParticipants === 'boolean'
case type.SignalForPeer:
return 'peer' in o && isNodeID(o.peer) &&
'signal' in o // TODO check signal content?
'signal' in o
case type.PeersForRound:
return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID)
case type.SignalNewPeer:
return 'newNode' in o && isNodeID(o.newNode)
case type.WaitingForMoreParticipants:
case type.EnoughParticipants:
case type.StartWeightSharing:
case type.RetryPeerConnections:
case type.ConnectionFail:
case type.SignalModelProvider:
return true
}

Expand All @@ -94,9 +158,11 @@ export function isMessageToServer (o: unknown): o is MessageToServer {
return true
case type.SignalForPeer:
return 'peer' in o && isNodeID(o.peer) &&
'signal' in o // TODO check signal content?
'signal' in o
case type.JoinRound:
case type.PeerIsReady:
case type.ConnectionsReady:
case type.ModelSyncRequest:
return true
}

Expand All @@ -112,6 +178,10 @@ export function isPeerMessage (o: unknown): o is PeerMessage {
'peer' in o && isNodeID(o.peer) &&
'payload' in o && serialization.isEncoded(o.payload)
)
case type.SharedModel:
return (
'model' in o && serialization.isEncoded(o.model)
)
}

return false
Expand Down
Loading
Loading