Skip to content

Commit 13be511

Browse files
committed
global model testing implemented
1 parent 59a3130 commit 13be511

3 files changed

Lines changed: 38 additions & 6 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
.idea/
22
__pycache__
3+
*.json

fed_learn/fed_server.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,32 +18,42 @@ def __init__(self, model_fn: Callable,
1818
# Initialize the global model's weights
1919
self.model_fn = model_fn
2020
model = self.model_fn()
21+
self.global_test_metrics_dict = {k: [] for k in model.metrics_names}
2122
self.global_model_weights = model.get_weights()
2223
fed_learn.get_rid_of_the_models(model)
2324

24-
self.global_losses = []
25+
self.global_train_losses = []
2526
self.epoch_losses = []
2627

27-
(x_train, y_train), _ = datasets.cifar10.load_data()
28+
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
2829

2930
if only_debugging:
3031
# TODO: remove me
3132
x_train = x_train[:100]
3233
y_train = y_train[:100]
34+
x_test = x_test[:100]
35+
y_test = y_test[:100]
3336

3437
# TODO: separate preprocessor for the data transformations
3538
y_train = utils.to_categorical(y_train, len(np.unique(y_train)))
3639
x_train = x_train.astype(np.float32)
3740
x_train /= 255.0
3841

42+
y_test = utils.to_categorical(y_test, len(np.unique(y_test)))
43+
x_test = x_test.astype(np.float32)
44+
x_test /= 255.0
45+
3946
self.x_train = x_train
4047
self.y_train = y_train
48+
self.x_test = x_test
49+
self.y_test = y_test
4150

4251
self.client_data_indices = None
4352
self.clients = []
4453
self.client_model_weights = []
4554

4655
# Training parameters used by the clients
56+
# TODO: this should be configurable
4757
self.train_dict = {"batch_size": 32,
4858
"epochs": 5,
4959
"verbose": 1,
@@ -96,3 +106,16 @@ def summarize_weights(self):
96106

97107
def get_client_train_param_dict(self):
98108
return self.train_dict
109+
110+
def test_global_model(self):
111+
model = self.model_fn()
112+
fed_learn.models.set_model_weights(model, self.global_model_weights)
113+
results = model.evaluate(self.x_test, self.y_test, batch_size=32, verbose=1)
114+
115+
results_dict = dict(zip(model.metrics_names, results))
116+
for metric_name, value in results_dict.items():
117+
self.global_test_metrics_dict[metric_name].append(value)
118+
119+
fed_learn.get_rid_of_the_models(model)
120+
121+
return results_dict

fl.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import json
23

34
import numpy as np
45

@@ -42,9 +43,16 @@ def model_fn():
4243
server.summarize_weights()
4344

4445
epoch_mean_loss = np.mean(server.epoch_losses)
45-
server.global_losses.append(epoch_mean_loss)
46-
print("Loss (mean): {0}".format(server.global_losses[-1]))
46+
server.global_train_losses.append(epoch_mean_loss)
47+
print("Loss (client mean): {0}".format(server.global_train_losses[-1]))
4748

48-
print("-" * 30)
49+
global_test_results = server.test_global_model()
50+
for metric_name, value in global_test_results.items():
51+
print("Global test|")
52+
print("_" * 10)
53+
print("{0}: {1}".format(metric_name, value))
4954

50-
# TODO: test the base model with the aggregated weights
55+
print("_" * 30)
56+
57+
with open("fed_learn_global_test_results.json", 'w') as f:
58+
json.dump(server.global_test_metrics_dict, f)

0 commit comments

Comments
 (0)