@@ -74,7 +74,7 @@ def get_neighbor_model_wts(self) -> List[Dict[str, TorchModelType]]:
7474 from all the neighbors because that's how
7575 most dynamic topologies work.
7676 """
77- neighbor_models = self .comm_utils .all_gather (ignore_super_node = True )
77+ neighbor_models : List [ Dict [ str , TorchModelType ]] = self .comm_utils .all_gather (ignore_super_node = True )
7878 return neighbor_models
7979
8080 def get_neighbor_similarity (self , others_wts : List [Dict [str , TorchModelType ]]) -> List [float ]:
@@ -95,11 +95,12 @@ def get_neighbor_similarity(self, others_wts: List[Dict[str, TorchModelType]]) -
9595 raise ValueError ("Similarity metric {} not implemented" .format (self .similarity ))
9696 return similarity_wts
9797
98- def sample_neighbours (self , k : int ) -> List [int ]:
98+ def sample_neighbours (self , k : int , mode : str | None = None ) -> List [int ]:
9999 """
100100 We perform neighbor sampling after
101101 we have the similarity weights of all the neighbors.
102102 """
103+ assert mode is None or mode == "pull" , "Only pull mode is supported for dynamic topology"
103104 if self .sampling == "closest" :
104105 return select_smallest_k (self .similarity_wts , k )
105106 else :
@@ -154,7 +155,7 @@ def __init__(
154155 self .topology = DynamicTopology (config , comm_utils , self )
155156 self .topology .initialize ()
156157
157- def get_representation (self , ** kwargs : Any ) -> TorchModelType :
158+ def get_representation (self , ** kwargs : Any ) -> Dict [ str , int | Dict [ str , Any ]] :
158159 """
159160 Returns the model weights as representation.
160161 """
@@ -172,24 +173,20 @@ def run_protocol(self) -> None:
172173 epochs_per_round = self .config .get ("epochs_per_round" , 1 )
173174
174175 for it in range (start_round , total_rounds ):
176+ self .round_init ()
177+
175178 # Train locally and send the representation to the server
176179 stats ["train_loss" ], stats ["train_acc" ], stats ["train_time" ] = self .local_train (
177180 it , epochs_per_round
178181 )
179182 self .local_round_done ()
180183
181184 # Collect the representations from all other nodes from the server
182- neighbors = self .topology .recv_and_agg (self .num_collaborators )
183- # TODO: Log the neighbors
184- stats ["neighbors" ] = neighbors
185-
186- stats ["bytes_received" ], stats ["bytes_sent" ] = self .comm_utils .get_comm_cost ()
185+ collabs = self .topology .recv_and_agg (self .num_collaborators )
187186
188- # evaluate the model on the test data
189- # Inside FedStaticNode.run_protocol()
190- stats ["test_loss" ], stats ["test_acc" ] = self .local_test ()
191- stats .update (self .get_memory_metrics ())
192- self .log_metrics (stats = stats , iteration = it )
187+ self .stats ["neighbors" ] = collabs
188+ self .local_test ()
189+ self .round_finalize ()
193190
194191
195192
0 commit comments