99
1010class Server :
1111 def __init__ (self , model_fn : Callable ,
12- nb_clients : int ,
1312 weight_summarizer : WeightSummarizer ,
13+ nb_clients : int = 100 ,
14+ client_fraction : float = 0.2 ,
1415 only_debugging : bool = True ):
1516 self .nb_clients = nb_clients
17+ self .client_fraction = client_fraction
1618 self .weight_summarizer = weight_summarizer
1719
1820 # Initialize the global model's weights
@@ -64,33 +66,34 @@ def _generate_data_indices(self):
6466 def _get_data_indices_for_client (self , client : int ):
6567 return self .client_data_indices [client ]
6668
67- def send_train_data (self , client ):
69+ def _send_train_data_to_client (self , client ):
6870 relevant_data_point_indices = self ._get_data_indices_for_client (client .id )
6971 x = self .x_train [relevant_data_point_indices ]
7072 y = self .y_train [relevant_data_point_indices ]
7173 client .receive_data (x , y )
7274 return x , y
7375
76+ def send_train_data (self ):
77+ self ._generate_data_indices ()
78+ for c in self .clients :
79+ self ._send_train_data_to_client (c )
80+
7481 def send_model (self , client ):
7582 client .receive_and_init_model (self .model_fn , self .global_model_weights )
7683
7784 def init_for_new_epoch (self ):
78- # Reset clients
79- self .clients .clear ()
8085 # Reset the collected weights
8186 self .client_model_weights .clear ()
8287 # Reset epoch losses
8388 self .epoch_losses .clear ()
84- # Generate new data indices for the clients
85- self ._generate_data_indices ()
8689
8790 def receive_results (self , client ):
8891 client_weights = client .model .get_weights ()
8992 self .client_model_weights .append (client_weights )
9093 client .reset_model ()
9194
9295 def create_clients (self ):
93- # Create new ones
96+ # Create all the clients
9497 for i in range (self .nb_clients ):
9598 client = fed_learn .Client (i )
9699 self .clients .append (client )
@@ -117,3 +120,10 @@ def test_global_model(self):
117120 fed_learn .get_rid_of_the_models (model )
118121
119122 return results_dict
123+
124+ def select_clients (self ):
125+ nb_clients_to_use = max (int (self .nb_clients * self .client_fraction ), 1 )
126+ client_indices = np .arange (self .nb_clients )
127+ np .random .shuffle (client_indices )
128+ selected_client_indices = client_indices [:nb_clients_to_use ]
129+ return np .asarray (self .clients )[selected_client_indices ]
0 commit comments