Skip to content

Commit 356d29b

Browse files
committed
epoch, clients, debug arguments added
1 parent a10a97c commit 356d29b

1 file changed

Lines changed: 17 additions & 5 deletions

File tree

fl.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,36 @@
1+
import argparse
2+
13
import fed_learn
4+
import numpy as np
5+
6+
parser = argparse.ArgumentParser()
7+
parser.add_argument("-ge", "--global-epochs", help="Number of global (server) epochs", type=int, default=5,
8+
required=False)
9+
parser.add_argument("-c", "--clients", help="Number of clients", type=int, default=10, required=False)
10+
parser.add_argument("-d", "--debug", help="Debugging", action="store_true", required=False)
11+
args = parser.parse_args()
212

3-
NB_CLIENTS = 10
4-
NB_EPOCHS = 3
13+
nb_clients = args.clients
14+
nb_epochs = args.global_epochs
15+
debug = args.debug
516

617

718
def model_fn():
819
return fed_learn.create_model((32, 32, 3), 10)
920

1021

1122
weight_summarizer = fed_learn.FedAvg()
12-
server = fed_learn.Server(model_fn, NB_CLIENTS, weight_summarizer)
23+
server = fed_learn.Server(model_fn, nb_clients, weight_summarizer, debug)
1324

14-
for epoch in range(NB_EPOCHS):
25+
for epoch in range(nb_epochs):
1526
print("Global Epoch {0} is starting".format(epoch))
1627
server.create_clients()
1728
server.init_for_new_epoch()
1829

1930
loss = []
2031

2132
for client in server.clients:
22-
print("Client {0} is starting the training".format({client.id}))
33+
print("Client {0} is starting the training".format(client.id))
2334

2435
server.send_model(client)
2536
server.send_train_data(client)
@@ -30,6 +41,7 @@ def model_fn():
3041
server.receive_results(client)
3142

3243
server.summarize_weights()
44+
print("Loss (mean): {0}".format(np.mean(loss)))
3345
print("-" * 30)
3446

3547
# TODO: test the base model with the aggregated weights

0 commit comments

Comments
 (0)