11from typing import Callable
22
33import numpy as np
4- from keras import datasets , utils
4+ from keras import datasets , utils , models
55
66import fed_learn
77from fed_learn .weight_summarizer import WeightSummarizer
@@ -73,6 +73,11 @@ def _send_train_data_to_client(self, client):
7373 client .receive_data (x , y )
7474 return x , y
7575
76+ def _create_model_with_updated_weights (self ) -> models .Model :
77+ model = self .model_fn ()
78+ fed_learn .models .set_model_weights (model , self .global_model_weights )
79+ return model
80+
7681 def send_train_data (self ):
7782 self ._generate_data_indices ()
7883 for c in self .clients :
@@ -109,8 +114,7 @@ def update_client_train_params(self, param_dict: dict):
109114 self .client_train_params_dict .update (param_dict )
110115
111116 def test_global_model (self ):
112- model = self .model_fn ()
113- fed_learn .models .set_model_weights (model , self .global_model_weights )
117+ model = self ._create_model_with_updated_weights ()
114118 results = model .evaluate (self .x_test , self .y_test , batch_size = 32 , verbose = 1 )
115119
116120 results_dict = dict (zip (model .metrics_names , results ))
@@ -127,3 +131,14 @@ def select_clients(self):
127131 np .random .shuffle (client_indices )
128132 selected_client_indices = client_indices [:nb_clients_to_use ]
129133 return np .asarray (self .clients )[selected_client_indices ]
134+
135+ def save_model_weights (self , path : str ):
136+ model = self ._create_model_with_updated_weights ()
137+ model .save_weights (str (path ), overwrite = True )
138+ fed_learn .get_rid_of_the_models (model )
139+
140+ def load_model_weights (self , path : str , by_name : bool = False ):
141+ model = self ._create_model_with_updated_weights ()
142+ model .load_weights (str (path ), by_name = by_name )
143+ self .global_model_weights = model .get_weights ()
144+ fed_learn .get_rid_of_the_models (model )
0 commit comments