Skip to content

Commit 4b43e8c

Browse files
committed
client train params are updatable via args, and args parsing is moved to a helper script
1 parent 3fb479f commit 4b43e8c

5 files changed

Lines changed: 35 additions & 23 deletions

File tree

fed_learn/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from .args_helper import get_args
12
from .data_sampling import iid_data_indices, non_iid_data_indices
23
from .fed_client import Client
34
from .fed_server import Server
45
from .models import create_model, set_model_weights
5-
from .weight_summarizer import FedAvg, WeightSummarizer
66
from .utils import get_rid_of_the_models
7+
from .weight_summarizer import FedAvg, WeightSummarizer

fed_learn/args_helper.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import argparse


def get_args(argv=None):
    """Parse the command-line options for a federated-learning run.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. Defaults to ``None``, in which case
        ``argparse`` falls back to ``sys.argv[1:]`` (the original behavior).
        Passing an explicit list makes the function usable from tests and
        other scripts without touching the process arguments.

    Returns
    -------
    argparse.Namespace
        Parsed arguments: ``global_epochs``, ``clients``, ``debug``,
        ``learning_rate``, ``batch_size``, ``client_epochs``.
    """
    parser = argparse.ArgumentParser()
    # Server-side options.
    parser.add_argument("-e", "--global-epochs", help="Number of global (server) epochs", type=int, default=5)
    parser.add_argument("-c", "--clients", help="Number of clients", type=int, default=10)
    parser.add_argument("-d", "--debug", help="Debugging", action="store_true")
    # Client-side training options (forwarded to Server.update_client_train_params).
    parser.add_argument("-lr", "--learning-rate", help="Learning rate", type=float, default=0.01)
    parser.add_argument("-b", "--batch-size", help="Batch Size", type=int, default=32)
    parser.add_argument("-ce", "--client-epochs", help="Number of epochs for the clients", type=int, default=5)
    # Note: `required=False` is the default for optional arguments, so the
    # original per-argument `required=False` flags were dropped as redundant.
    return parser.parse_args(argv)

fed_learn/fed_server.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,10 @@ def __init__(self, model_fn: Callable,
5353
self.client_model_weights = []
5454

5555
# Training parameters used by the clients
56-
# TODO: this should be configurable
57-
self.train_dict = {"batch_size": 32,
58-
"epochs": 5,
59-
"verbose": 1,
60-
"shuffle": True}
61-
62-
if only_debugging:
63-
# TODO: remove me
64-
self.train_dict["epochs"] = 1
56+
self.client_train_params_dict = {"batch_size": 32,
57+
"epochs": 5,
58+
"verbose": 1,
59+
"shuffle": True}
6560

6661
def _generate_data_indices(self):
6762
self.client_data_indices = fed_learn.iid_data_indices(self.nb_clients, len(self.x_train))
@@ -105,7 +100,10 @@ def summarize_weights(self):
105100
self.global_model_weights = new_weights
106101

107102
def get_client_train_param_dict(self):
108-
return self.train_dict
103+
return self.client_train_params_dict
104+
105+
def update_client_train_params(self, param_dict: dict):
106+
self.client_train_params_dict.update(param_dict)
109107

110108
def test_global_model(self):
111109
model = self.model_fn()

fed_learn/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from keras.applications.vgg16 import VGG16
44

55

6-
def create_model(input_shape: tuple, nb_classes: int, init_with_imagenet: bool = False):
6+
def create_model(input_shape: tuple, nb_classes: int, init_with_imagenet: bool = False, learning_rate: float = 0.01):
77
weights = None
88
if init_with_imagenet:
99
weights = "imagenet"
@@ -21,7 +21,7 @@ def create_model(input_shape: tuple, nb_classes: int, init_with_imagenet: bool =
2121
model = models.Model(model.input, x)
2222

2323
loss = losses.categorical_crossentropy
24-
optimizer = optimizers.Adam(lr=0.001)
24+
optimizer = optimizers.SGD(lr=learning_rate, decay=0.99)
2525

2626
model.compile(optimizer, loss, metrics=["accuracy"])
2727
return model

fl.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,27 @@
1-
import argparse
21
import json
32

43
import numpy as np
54

65
import fed_learn
76

8-
parser = argparse.ArgumentParser()
9-
parser.add_argument("-e", "--global-epochs", help="Number of global (server) epochs", type=int, default=5,
10-
required=False)
11-
parser.add_argument("-c", "--clients", help="Number of clients", type=int, default=10, required=False)
12-
parser.add_argument("-d", "--debug", help="Debugging", action="store_true", required=False)
13-
args = parser.parse_args()
7+
args = fed_learn.get_args()
148

159
nb_clients = args.clients
16-
nb_epochs = args.global_epochs
10+
nb_global_epochs = args.global_epochs
1711
debug = args.debug
1812

13+
client_train_params = {"epochs": args.client_epochs, "batch_size": args.batch_size}
14+
1915

2016
def model_fn():
21-
return fed_learn.create_model((32, 32, 3), 10, init_with_imagenet=False)
17+
return fed_learn.create_model((32, 32, 3), 10, init_with_imagenet=False, learning_rate=args.learning_rate)
2218

2319

2420
weight_summarizer = fed_learn.FedAvg()
2521
server = fed_learn.Server(model_fn, nb_clients, weight_summarizer, debug)
22+
server.update_client_train_params(client_train_params)
2623

27-
for epoch in range(nb_epochs):
24+
for epoch in range(nb_global_epochs):
2825
print("Global Epoch {0} is starting".format(epoch))
2926
server.init_for_new_epoch()
3027
server.create_clients()

0 commit comments

Comments (0)