diff --git a/src/sampling.py b/src/sampling.py index ee0a5f2c..e3121e33 100755 --- a/src/sampling.py +++ b/src/sampling.py @@ -30,7 +30,7 @@ def mnist_noniid(dataset, num_users): :param num_users: :return: """ - # 60,000 training imgs --> 200 imgs/shard X 300 shards + # 60,000 training imgs --> 300 imgs/shard X 200 shards num_shards, num_imgs = 200, 300 idx_shard = [i for i in range(num_shards)] dict_users = {i: np.array([]) for i in range(num_users)}