11import json
2+ from pathlib import Path
23
34import numpy as np
45
56import fed_learn
67
78args = fed_learn .get_args ()
89
9- nb_clients = args .clients
10- client_fraction = args .fraction
11- nb_global_epochs = args .global_epochs
12- debug = args .debug
10+ EXPERIMENT_FOLDER_PATH = Path (__file__ ).resolve ().parent / "experiments" / args .name
11+ EXPERIMENT_FOLDER_PATH .mkdir (parents = True , exist_ok = args .overwrite_experiment )
12+
13+ args_json_path = EXPERIMENT_FOLDER_PATH / "args.json"
14+ fed_learn .save_args_as_json (args , EXPERIMENT_FOLDER_PATH / args_json_path )
15+
16+ train_hist_path = EXPERIMENT_FOLDER_PATH / "fed_learn_global_test_results.json"
1317
1418client_train_params = {"epochs" : args .client_epochs , "batch_size" : args .batch_size }
1519
@@ -19,12 +23,12 @@ def model_fn():
1923
2024
2125weight_summarizer = fed_learn .FedAvg ()
22- server = fed_learn .Server (model_fn , weight_summarizer , nb_clients , client_fraction , debug )
26+ server = fed_learn .Server (model_fn , weight_summarizer , args . clients , args . fraction , args . debug )
2327server .update_client_train_params (client_train_params )
2428server .create_clients ()
2529server .send_train_data ()
2630
27- for epoch in range (nb_global_epochs ):
31+ for epoch in range (args . global_epochs ):
2832 print ("Global Epoch {0} is starting" .format (epoch ))
2933 server .init_for_new_epoch ()
3034 selected_clients = server .select_clients ()
@@ -51,7 +55,7 @@ def model_fn():
5155 for metric_name , value in global_test_results .items ():
5256 print ("{0}: {1}" .format (metric_name , value ))
5357
54- with open ("fed_learn_global_test_results.json" , 'w' ) as f :
58+ with open (str ( train_hist_path ) , 'w' ) as f :
5559 json .dump (server .global_test_metrics_dict , f )
5660
5761 print ("_" * 30 )
0 commit comments