11from typing import Callable
2+
23import numpy as np
34from keras import datasets , utils
45
78
89
910class Server :
10- def __init__ (self , model_fn : Callable , nb_clients : int , weight_summarizer : WeightSummarizer ,
11+ def __init__ (self , model_fn : Callable ,
12+ nb_clients : int ,
13+ weight_summarizer : WeightSummarizer ,
1114 only_debugging : bool = True ):
1215 self .nb_clients = nb_clients
1316 self .weight_summarizer = weight_summarizer
@@ -17,13 +20,17 @@ def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: Weigh
1720 self .model_weights = model .get_weights ()
1821 fed_learn .get_rid_of_the_models (model )
1922
20- (x_train , y_train ), ( _ , _ ) = datasets .cifar10 .load_data ()
23+ (x_train , y_train ), _ = datasets .cifar10 .load_data ()
2124
2225 if only_debugging :
26+ # TODO: remove me
2327 x_train = x_train [:100 ]
2428 y_train = y_train [:100 ]
2529
30+ # TODO: separate preprocessor for the data transformations
2631 y_train = utils .to_categorical (y_train , len (np .unique (y_train )))
32+ x_train = x_train .astype (np .float32 )
33+ x_train /= 255.0
2734
2835 self .x_train = x_train
2936 self .y_train = y_train
@@ -40,6 +47,7 @@ def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: Weigh
4047 "shuffle" : True }
4148
4249 if only_debugging :
50+ # TODO: remove me
4351 self .train_dict ["epochs" ] = 1
4452
4553 def _generate_data_indices (self ):
0 commit comments