Skip to content

Commit d82567b

Browse files
committed
saving of the global model weights implemented
1 parent 1c109b4 commit d82567b

3 files changed

Lines changed: 26 additions & 3 deletions

File tree

fed_learn/fed_server.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Callable
22

33
import numpy as np
4-
from keras import datasets, utils
4+
from keras import datasets, utils, models
55

66
import fed_learn
77
from fed_learn.weight_summarizer import WeightSummarizer
@@ -73,6 +73,11 @@ def _send_train_data_to_client(self, client):
7373
client.receive_data(x, y)
7474
return x, y
7575

76+
def _create_model_with_updated_weights(self) -> models.Model:
77+
model = self.model_fn()
78+
fed_learn.models.set_model_weights(model, self.global_model_weights)
79+
return model
80+
7681
def send_train_data(self):
7782
self._generate_data_indices()
7883
for c in self.clients:
@@ -109,8 +114,7 @@ def update_client_train_params(self, param_dict: dict):
109114
self.client_train_params_dict.update(param_dict)
110115

111116
def test_global_model(self):
112-
model = self.model_fn()
113-
fed_learn.models.set_model_weights(model, self.global_model_weights)
117+
model = self._create_model_with_updated_weights()
114118
results = model.evaluate(self.x_test, self.y_test, batch_size=32, verbose=1)
115119

116120
results_dict = dict(zip(model.metrics_names, results))
@@ -127,3 +131,14 @@ def select_clients(self):
127131
np.random.shuffle(client_indices)
128132
selected_client_indices = client_indices[:nb_clients_to_use]
129133
return np.asarray(self.clients)[selected_client_indices]
134+
135+
def save_model_weights(self, path: str):
136+
model = self._create_model_with_updated_weights()
137+
model.save_weights(str(path), overwrite=True)
138+
fed_learn.get_rid_of_the_models(model)
139+
140+
def load_model_weights(self, path: str, by_name: bool = False):
141+
model = self._create_model_with_updated_weights()
142+
model.load_weights(str(path), by_name=by_name)
143+
self.global_model_weights = model.get_weights()
144+
fed_learn.get_rid_of_the_models(model)

fed_learn/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77

88

99
def get_rid_of_the_models(model=None):
10+
"""
11+
This function clears the TF session from the model.
12+
This is needed as TF/Keras models are not automatically cleared, and the memory will be overloaded
13+
"""
14+
1015
K.clear_session()
1116
if model is not None:
1217
del model

federated_learning.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
fed_learn.save_args_as_json(args, EXPERIMENT_FOLDER_PATH / args_json_path)
1515

1616
train_hist_path = EXPERIMENT_FOLDER_PATH / "fed_learn_global_test_results.json"
17+
global_weight_path = EXPERIMENT_FOLDER_PATH / "global_weights.h5"
1718

1819
client_train_params = {"epochs": args.client_epochs, "batch_size": args.batch_size}
1920

@@ -58,4 +59,6 @@ def model_fn():
5859
with open(str(train_hist_path), 'w') as f:
5960
json.dump(server.global_test_metrics_dict, f)
6061

62+
server.save_model_weights(global_weight_path)
63+
6164
print("_" * 30)

0 commit comments

Comments
 (0)