Skip to content

Commit 7bbd6e0

Browse files
committed
all the clients are created and then only a selected fraction is used for the trainings
1 parent 2cbc14c commit 7bbd6e0

5 files changed

Lines changed: 41 additions & 19 deletions

File tree

fed_learn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
from .fed_client import Client
44
from .fed_server import Server
55
from .models import create_model, set_model_weights
6-
from .utils import get_rid_of_the_models
6+
from .utils import get_rid_of_the_models, print_selected_clients
77
from .weight_summarizer import FedAvg, WeightSummarizer

fed_learn/args_helper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33

44
def get_args():
55
parser = argparse.ArgumentParser()
6-
parser.add_argument("-e", "--global-epochs", help="Number of global (server) epochs", type=int, default=5,
6+
parser.add_argument("-e", "--global-epochs", help="Number of global (server) epochs", type=int, default=10,
7+
required=False)
8+
parser.add_argument("-c", "--clients", help="Number of clients", type=int, default=100, required=False)
9+
parser.add_argument("-f", "--fraction", help="Client fraction to use", type=float, default=0.2,
710
required=False)
8-
parser.add_argument("-c", "--clients", help="Number of clients", type=int, default=10, required=False)
911
parser.add_argument("-d", "--debug", help="Debugging", action="store_true", required=False)
1012

1113
parser.add_argument("-lr", "--learning-rate", help="Learning rate", type=float, default=0.01, required=False)

fed_learn/fed_server.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99

1010
class Server:
1111
def __init__(self, model_fn: Callable,
12-
nb_clients: int,
1312
weight_summarizer: WeightSummarizer,
13+
nb_clients: int = 100,
14+
client_fraction: float = 0.2,
1415
only_debugging: bool = True):
1516
self.nb_clients = nb_clients
17+
self.client_fraction = client_fraction
1618
self.weight_summarizer = weight_summarizer
1719

1820
# Initialize the global model's weights
@@ -64,33 +66,34 @@ def _generate_data_indices(self):
6466
def _get_data_indices_for_client(self, client: int):
6567
return self.client_data_indices[client]
6668

67-
def send_train_data(self, client):
69+
def _send_train_data_to_client(self, client):
6870
relevant_data_point_indices = self._get_data_indices_for_client(client.id)
6971
x = self.x_train[relevant_data_point_indices]
7072
y = self.y_train[relevant_data_point_indices]
7173
client.receive_data(x, y)
7274
return x, y
7375

76+
def send_train_data(self):
77+
self._generate_data_indices()
78+
for c in self.clients:
79+
self._send_train_data_to_client(c)
80+
7481
def send_model(self, client):
7582
client.receive_and_init_model(self.model_fn, self.global_model_weights)
7683

7784
def init_for_new_epoch(self):
78-
# Reset clients
79-
self.clients.clear()
8085
# Reset the collected weights
8186
self.client_model_weights.clear()
8287
# Reset epoch losses
8388
self.epoch_losses.clear()
84-
# Generate new data indices for the clients
85-
self._generate_data_indices()
8689

8790
def receive_results(self, client):
8891
client_weights = client.model.get_weights()
8992
self.client_model_weights.append(client_weights)
9093
client.reset_model()
9194

9295
def create_clients(self):
93-
# Create new ones
96+
# Create all the clients
9497
for i in range(self.nb_clients):
9598
client = fed_learn.Client(i)
9699
self.clients.append(client)
@@ -117,3 +120,10 @@ def test_global_model(self):
117120
fed_learn.get_rid_of_the_models(model)
118121

119122
return results_dict
123+
124+
def select_clients(self):
125+
nb_clients_to_use = max(int(self.nb_clients * self.client_fraction), 1)
126+
client_indices = np.arange(self.nb_clients)
127+
np.random.shuffle(client_indices)
128+
selected_client_indices = client_indices[:nb_clients_to_use]
129+
return np.asarray(self.clients)[selected_client_indices]

fed_learn/utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
from keras import backend as K
21
import gc
2+
from typing import List
3+
4+
from keras import backend as K
5+
6+
import fed_learn
37

48

59
def get_rid_of_the_models(model=None):
6-
# TODO: somehow this does not free up the GPU memory (after a while you will get OOM) (tested on Windows 10...)
710
K.clear_session()
811
if model is not None:
912
del model
1013
gc.collect()
14+
15+
16+
def print_selected_clients(clients: List[fed_learn.fed_client.Client]):
17+
client_ids = [c.id for c in clients]
18+
print("Selected clients for epoch: {0}".format("| ".join(map(str, client_ids))))

fl.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
args = fed_learn.get_args()
88

99
nb_clients = args.clients
10+
client_fraction = args.fraction
1011
nb_global_epochs = args.global_epochs
1112
debug = args.debug
1213

@@ -18,20 +19,22 @@ def model_fn():
1819

1920

2021
weight_summarizer = fed_learn.FedAvg()
21-
server = fed_learn.Server(model_fn, nb_clients, weight_summarizer, debug)
22+
server = fed_learn.Server(model_fn, weight_summarizer, nb_clients, client_fraction, debug)
2223
server.update_client_train_params(client_train_params)
24+
server.create_clients()
25+
server.send_train_data()
2326

2427
for epoch in range(nb_global_epochs):
2528
print("Global Epoch {0} is starting".format(epoch))
2629
server.init_for_new_epoch()
27-
server.create_clients()
30+
selected_clients = server.select_clients()
2831

29-
for client in server.clients:
32+
fed_learn.print_selected_clients(selected_clients)
33+
34+
for client in selected_clients:
3035
print("Client {0} is starting the training".format(client.id))
3136

3237
server.send_model(client)
33-
server.send_train_data(client)
34-
3538
hist = client.edge_train(server.get_client_train_param_dict())
3639
server.epoch_losses.append(hist.history["loss"][-1])
3740

@@ -44,9 +47,8 @@ def model_fn():
4447
print("Loss (client mean): {0}".format(server.global_train_losses[-1]))
4548

4649
global_test_results = server.test_global_model()
47-
print("Global test|")
50+
print("--- Global test ---")
4851
for metric_name, value in global_test_results.items():
49-
print("_" * 10)
5052
print("{0}: {1}".format(metric_name, value))
5153

5254
with open("fed_learn_global_test_results.json", 'w') as f:

0 commit comments

Comments
 (0)