77
88
99class Server :
10- def __init__ (self , model_fn : Callable , nb_clients : int , weight_summarizer : WeightSummarizer ):
10+ def __init__ (self , model_fn : Callable , nb_clients : int , weight_summarizer : WeightSummarizer ,
11+ only_debugging : bool = True ):
1112 self .nb_clients = nb_clients
1213 self .weight_summarizer = weight_summarizer
1314
@@ -17,6 +18,11 @@ def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: Weigh
1718 fed_learn .get_rid_of_the_models (model )
1819
1920 (x_train , y_train ), (_ , _ ) = datasets .cifar10 .load_data ()
21+
22+ if only_debugging :
23+ x_train = x_train [:100 ]
24+ y_train = y_train [:100 ]
25+
2026 y_train = utils .to_categorical (y_train , len (np .unique (y_train )))
2127
2228 self .x_train = x_train
@@ -27,6 +33,12 @@ def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: Weigh
2733
2834 self .client_model_weights = []
2935
36+ # Training parameters used by the clients
37+ self .train_dict = {"batch_size" : 32 ,
38+ "epochs" : 5 ,
39+ "verbose" : 1 ,
40+ "shuffle" : True }
41+
3042 def _generate_data_indices (self ):
3143 self .client_data_indices = fed_learn .iid_data_indices (self .nb_clients , len (self .x_train ))
3244
@@ -60,10 +72,5 @@ def summarize_weights(self):
6072 new_weights = self .weight_summarizer .process (self .client_model_weights )
6173 self .model_weights = new_weights
6274
63- @staticmethod
64- def get_client_train_param_dict ():
65- train_dict = {"batch_size" : 32 ,
66- "epochs" : 5 ,
67- "verbose" : 1 ,
68- "shuffle" : True }
69- return train_dict
75+ def get_client_train_param_dict (self ):
76+ return self .train_dict
0 commit comments