|
1 | | -import argparse |
2 | 1 | import json |
3 | 2 |
|
4 | 3 | import numpy as np |
5 | 4 |
|
6 | 5 | import fed_learn |
7 | 6 |
|
8 | | -parser = argparse.ArgumentParser() |
9 | | -parser.add_argument("-e", "--global-epochs", help="Number of global (server) epochs", type=int, default=5, |
10 | | - required=False) |
11 | | -parser.add_argument("-c", "--clients", help="Number of clients", type=int, default=10, required=False) |
12 | | -parser.add_argument("-d", "--debug", help="Debugging", action="store_true", required=False) |
13 | | -args = parser.parse_args() |
| 7 | +args = fed_learn.get_args() |
14 | 8 |
|
15 | 9 | nb_clients = args.clients |
16 | | -nb_epochs = args.global_epochs |
| 10 | +nb_global_epochs = args.global_epochs |
17 | 11 | debug = args.debug |
18 | 12 |
|
| 13 | +client_train_params = {"epochs": args.client_epochs, "batch_size": args.batch_size} |
| 14 | + |
19 | 15 |
|
20 | 16 | def model_fn(): |
21 | | - return fed_learn.create_model((32, 32, 3), 10, init_with_imagenet=False) |
| 17 | + return fed_learn.create_model((32, 32, 3), 10, init_with_imagenet=False, learning_rate=args.learning_rate) |
22 | 18 |
|
23 | 19 |
|
24 | 20 | weight_summarizer = fed_learn.FedAvg() |
25 | 21 | server = fed_learn.Server(model_fn, nb_clients, weight_summarizer, debug) |
| 22 | +server.update_client_train_params(client_train_params) |
26 | 23 |
|
27 | | -for epoch in range(nb_epochs): |
| 24 | +for epoch in range(nb_global_epochs): |
28 | 25 | print("Global Epoch {0} is starting".format(epoch)) |
29 | 26 | server.init_for_new_epoch() |
30 | 27 | server.create_clients() |
|
0 commit comments