1+ import argparse
2+
13import fed_learn
4+ import numpy as np
5+
6+ parser = argparse .ArgumentParser ()
7+ parser .add_argument ("-ge" , "--global-epochs" , help = "Number of global (server) epochs" , type = int , default = 5 ,
8+ required = False )
9+ parser .add_argument ("-c" , "--clients" , help = "Number of clients" , type = int , default = 10 , required = False )
10+ parser .add_argument ("-d" , "--debug" , help = "Debugging" , action = "store_true" , required = False )
11+ args = parser .parse_args ()
212
3- NB_CLIENTS = 10
4- NB_EPOCHS = 3
13+ nb_clients = args .clients
14+ nb_epochs = args .global_epochs
15+ debug = args .debug
516
617
718def model_fn ():
819 return fed_learn .create_model ((32 , 32 , 3 ), 10 )
920
1021
1122weight_summarizer = fed_learn .FedAvg ()
12- server = fed_learn .Server (model_fn , NB_CLIENTS , weight_summarizer )
23+ server = fed_learn .Server (model_fn , nb_clients , weight_summarizer , debug )
1324
14- for epoch in range (NB_EPOCHS ):
25+ for epoch in range (nb_epochs ):
1526 print ("Global Epoch {0} is starting" .format (epoch ))
1627 server .create_clients ()
1728 server .init_for_new_epoch ()
1829
1930 loss = []
2031
2132 for client in server .clients :
22- print ("Client {0} is starting the training" .format ({ client .id } ))
33+ print ("Client {0} is starting the training" .format (client .id ))
2334
2435 server .send_model (client )
2536 server .send_train_data (client )
@@ -30,6 +41,7 @@ def model_fn():
3041 server .receive_results (client )
3142
3243 server .summarize_weights ()
44+ print ("Loss (mean): {0}" .format (np .mean (loss )))
3345 print ("-" * 30 )
3446
3547 # TODO: test the base model with the aggregated weights
0 commit comments