Skip to content

Commit 3e221d0

Browse files
authored
Dynamic algo fix (#167)
* little bit of typehinting * dynamic algo and config bug fix and type hint in comm_utils
1 parent ded2ba5 commit 3e221d0

4 files changed

Lines changed: 19 additions & 23 deletions

File tree

src/algos/base_class.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,8 +527,8 @@ def set_data_parameters(self, config: ConfigType) -> None:
527527
self.classes_of_interest = classes
528528
self.train_indices = train_indices
529529
self.train_dset = train_dset
530-
self.dloader = DataLoader(train_dset, batch_size=len(train_dset), shuffle=False)
531-
self._test_loader = DataLoader(test_dset, batch_size=len(test_dset), shuffle=False)
530+
self.dloader: DataLoader[Any] = DataLoader(train_dset, batch_size=len(train_dset), shuffle=False)
531+
self._test_loader: DataLoader[Any] = DataLoader(test_dset, batch_size=len(test_dset), shuffle=False)
532532
print("Using GIA data setup")
533533
print(self.labels)
534534
else:

src/algos/fl_dynamic.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def get_neighbor_model_wts(self) -> List[Dict[str, TorchModelType]]:
7474
from all the neighbors because that's how
7575
most dynamic topologies work.
7676
"""
77-
neighbor_models = self.comm_utils.all_gather(ignore_super_node=True)
77+
neighbor_models: List[Dict[str, TorchModelType]] = self.comm_utils.all_gather(ignore_super_node=True)
7878
return neighbor_models
7979

8080
def get_neighbor_similarity(self, others_wts: List[Dict[str, TorchModelType]]) -> List[float]:
@@ -95,11 +95,12 @@ def get_neighbor_similarity(self, others_wts: List[Dict[str, TorchModelType]]) -
9595
raise ValueError("Similarity metric {} not implemented".format(self.similarity))
9696
return similarity_wts
9797

98-
def sample_neighbours(self, k: int) -> List[int]:
98+
def sample_neighbours(self, k: int, mode: str|None = None) -> List[int]:
9999
"""
100100
We perform neighbor sampling after
101101
we have the similarity weights of all the neighbors.
102102
"""
103+
assert mode is None or mode == "pull", "Only pull mode is supported for dynamic topology"
103104
if self.sampling == "closest":
104105
return select_smallest_k(self.similarity_wts, k)
105106
else:
@@ -154,7 +155,7 @@ def __init__(
154155
self.topology = DynamicTopology(config, comm_utils, self)
155156
self.topology.initialize()
156157

157-
def get_representation(self, **kwargs: Any) -> TorchModelType:
158+
def get_representation(self, **kwargs: Any) -> Dict[str, int|Dict[str, Any]]:
158159
"""
159160
Returns the model weights as representation.
160161
"""
@@ -172,24 +173,20 @@ def run_protocol(self) -> None:
172173
epochs_per_round = self.config.get("epochs_per_round", 1)
173174

174175
for it in range(start_round, total_rounds):
176+
self.round_init()
177+
175178
# Train locally and send the representation to the server
176179
stats["train_loss"], stats["train_acc"], stats["train_time"] = self.local_train(
177180
it, epochs_per_round
178181
)
179182
self.local_round_done()
180183

181184
# Collect the representations from all other nodes from the server
182-
neighbors = self.topology.recv_and_agg(self.num_collaborators)
183-
# TODO: Log the neighbors
184-
stats["neighbors"] = neighbors
185-
186-
stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()
185+
collabs = self.topology.recv_and_agg(self.num_collaborators)
187186

188-
# evaluate the model on the test data
189-
# Inside FedStaticNode.run_protocol()
190-
stats["test_loss"], stats["test_acc"] = self.local_test()
191-
stats.update(self.get_memory_metrics())
192-
self.log_metrics(stats=stats, iteration=it)
187+
self.stats["neighbors"] = collabs
188+
self.local_test()
189+
self.round_finalize()
193190

194191

195192

src/configs/sys_config.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,10 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
159159
CIFAR10_DSET = "cifar10"
160160
CIAR10_DPATH = "./datasets/imgs/cifar10/"
161161

162-
NUM_COLLABORATORS = 1
163-
DUMP_DIR = "/tmp/"
162+
NUM_COLLABORATORS = 3
163+
DUMP_DIR = "/tmp/new_sonar/"
164164

165-
num_users = 3
165+
num_users = 9
166166
mpi_system_config: ConfigType = {
167167
"exp_id": "",
168168
"comm": {"type": "MPI"},
@@ -318,8 +318,6 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
318318
"exp_keys": [],
319319
}
320320

321-
num_users = 4
322-
323321
dropout_dict: Any = {
324322
"distribution_dict": { # leave dict empty to disable dropout
325323
"method": "uniform", # "uniform", "normal"
@@ -346,9 +344,10 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
346344
"dpath": CIAR10_DPATH,
347345
"seed": 2,
348346
"device_ids": get_device_ids(num_users, gpu_ids),
347+
"assign_based_on_host": True,
349348
# "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore
350-
"algos": get_algo_configs(num_users=num_users, algo_configs=[fed_dynamic_loss]), # type: ignore
351-
"samples_per_user": 10000 // num_users, # distributed equally
349+
"algos": get_algo_configs(num_users=num_users, algo_configs=[fed_dynamic_weights]), # type: ignore
350+
"samples_per_user": 500, # distributed equally
352351
"train_label_distribution": "non_iid",
353352
"alpha_data": 0.1,
354353
"test_label_distribution": "iid",

src/utils/communication/comm_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def receive(self, node_ids: List[int]) -> Any:
7272
def broadcast(self, data: Any, tag: int = 0):
7373
self.comm.broadcast(data)
7474

75-
def all_gather(self, tag: int = 0, ignore_super_node: bool = False):
75+
def all_gather(self, tag: int = 0, ignore_super_node: bool = False) -> List[Dict[str, Any]]:
7676
return self.comm.all_gather(ignore_super_node=ignore_super_node)
7777

7878
def send_quorum(self):

0 commit comments

Comments
 (0)