Skip to content

Commit ac061d7

Browse files
committed
losses introduced to server class
1 parent 0911fa0 commit ac061d7

2 files changed

Lines changed: 14 additions & 8 deletions

File tree

fed_learn/fed_server.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@ def __init__(self, model_fn: Callable,
1515
self.nb_clients = nb_clients
1616
self.weight_summarizer = weight_summarizer
1717

18+
# Initialize the global model's weights
1819
self.model_fn = model_fn
1920
model = self.model_fn()
2021
self.global_model_weights = model.get_weights()
2122
fed_learn.get_rid_of_the_models(model)
2223

24+
self.global_losses = []
25+
self.epoch_losses = []
26+
2327
(x_train, y_train), _ = datasets.cifar10.load_data()
2428

2529
if only_debugging:
@@ -37,7 +41,6 @@ def __init__(self, model_fn: Callable,
3741

3842
self.client_data_indices = None
3943
self.clients = []
40-
4144
self.client_model_weights = []
4245

4346
# Training parameters used by the clients
@@ -68,9 +71,11 @@ def send_model(self, client):
6871

6972
def init_for_new_epoch(self):
7073
# Reset clients
71-
self.clients = []
74+
self.clients.clear()
7275
# Reset the collected weights
73-
self.client_model_weights = []
76+
self.client_model_weights.clear()
77+
# Reset epoch losses
78+
self.epoch_losses.clear()
7479
# Generate new data indices for the clients
7580
self._generate_data_indices()
7681

fl.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,23 @@ def model_fn():
2828
server.create_clients()
2929
server.init_for_new_epoch()
3030

31-
loss = []
32-
3331
for client in server.clients:
3432
print("Client {0} is starting the training".format(client.id))
3533

3634
server.send_model(client)
3735
server.send_train_data(client)
3836

3937
hist = client.edge_train(server.get_client_train_param_dict())
40-
loss.append(hist.history["loss"][-1])
38+
server.epoch_losses.append(hist.history["loss"][-1])
4139

4240
server.receive_results(client)
4341

4442
server.summarize_weights()
45-
print("Loss (mean): {0}".format(np.mean(loss)))
46-
loss.clear()
43+
44+
epoch_mean_loss = np.mean(server.epoch_losses)
45+
server.global_losses.append(epoch_mean_loss)
46+
print("Loss (mean): {0}".format(server.global_losses[-1]))
47+
4748
print("-" * 30)
4849

4950
# TODO: test the base model with the aggregated weights

0 commit comments

Comments
 (0)