Skip to content

Commit a7ed062

Browse files
committed
debugging option added for much faster trainings
1 parent 859775b commit a7ed062

1 file changed

Lines changed: 15 additions & 8 deletions

File tree

fed_learn/fed_server.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88

99
class Server:
10-
def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: WeightSummarizer):
10+
def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: WeightSummarizer,
11+
only_debugging: bool = True):
1112
self.nb_clients = nb_clients
1213
self.weight_summarizer = weight_summarizer
1314

@@ -17,6 +18,11 @@ def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: Weigh
1718
fed_learn.get_rid_of_the_models(model)
1819

1920
(x_train, y_train), (_, _) = datasets.cifar10.load_data()
21+
22+
if only_debugging:
23+
x_train = x_train[:100]
24+
y_train = y_train[:100]
25+
2026
y_train = utils.to_categorical(y_train, len(np.unique(y_train)))
2127

2228
self.x_train = x_train
@@ -27,6 +33,12 @@ def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: Weigh
2733

2834
self.client_model_weights = []
2935

36+
# Training parameters used by the clients
37+
self.train_dict = {"batch_size": 32,
38+
"epochs": 5,
39+
"verbose": 1,
40+
"shuffle": True}
41+
3042
def _generate_data_indices(self):
3143
self.client_data_indices = fed_learn.iid_data_indices(self.nb_clients, len(self.x_train))
3244

@@ -60,10 +72,5 @@ def summarize_weights(self):
6072
new_weights = self.weight_summarizer.process(self.client_model_weights)
6173
self.model_weights = new_weights
6274

63-
@staticmethod
64-
def get_client_train_param_dict():
65-
train_dict = {"batch_size": 32,
66-
"epochs": 5,
67-
"verbose": 1,
68-
"shuffle": True}
69-
return train_dict
75+
def get_client_train_param_dict(self):
76+
return self.train_dict

0 commit comments

Comments
 (0)