@@ -18,32 +18,42 @@ def __init__(self, model_fn: Callable,
1818 # Initialize the global model's weights
1919 self .model_fn = model_fn
2020 model = self .model_fn ()
21+ self .global_test_metrics_dict = {k : [] for k in model .metrics_names }
2122 self .global_model_weights = model .get_weights ()
2223 fed_learn .get_rid_of_the_models (model )
2324
24- self .global_losses = []
25+ self .global_train_losses = []
2526 self .epoch_losses = []
2627
27- (x_train , y_train ), _ = datasets .cifar10 .load_data ()
28+ (x_train , y_train ), ( x_test , y_test ) = datasets .cifar10 .load_data ()
2829
2930 if only_debugging :
3031 # TODO: remove me
3132 x_train = x_train [:100 ]
3233 y_train = y_train [:100 ]
34+ x_test = x_test [:100 ]
35+ y_test = y_test [:100 ]
3336
3437 # TODO: separate preprocessor for the data transformations
3538 y_train = utils .to_categorical (y_train , len (np .unique (y_train )))
3639 x_train = x_train .astype (np .float32 )
3740 x_train /= 255.0
3841
42+ y_test = utils .to_categorical (y_test , len (np .unique (y_test )))
43+ x_test = x_test .astype (np .float32 )
44+ x_test /= 255.0
45+
3946 self .x_train = x_train
4047 self .y_train = y_train
48+ self .x_test = x_test
49+ self .y_test = y_test
4150
4251 self .client_data_indices = None
4352 self .clients = []
4453 self .client_model_weights = []
4554
4655 # Training parameters used by the clients
56+ # TODO: this should be configurable
4757 self .train_dict = {"batch_size" : 32 ,
4858 "epochs" : 5 ,
4959 "verbose" : 1 ,
@@ -96,3 +106,16 @@ def summarize_weights(self):
96106
97107 def get_client_train_param_dict (self ):
98108 return self .train_dict
109+
110+ def test_global_model (self ):
111+ model = self .model_fn ()
112+ fed_learn .models .set_model_weights (model , self .global_model_weights )
113+ results = model .evaluate (self .x_test , self .y_test , batch_size = 32 , verbose = 1 )
114+
115+ results_dict = dict (zip (model .metrics_names , results ))
116+ for metric_name , value in results_dict .items ():
117+ self .global_test_metrics_dict [metric_name ].append (value )
118+
119+ fed_learn .get_rid_of_the_models (model )
120+
121+ return results_dict
0 commit comments