Skip to content

Commit ac303ee

Browse files
committed
loading the model weights
1 parent d82567b commit ac303ee

2 files changed

Lines changed: 6 additions & 0 deletions

File tree

fed_learn/args_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ def get_args():
77
parser.add_argument("-n", "--name", help="Name of the experiment", type=str, required=True)
88
parser.add_argument("-oe", "--overwrite-experiment", help="Overwrite existing experiment", action="store_true",
99
required=False)
10+
parser.add_argument("-w", "--weights-file", help="Weights file path to load", type=str, required=False)
1011
parser.add_argument("-e", "--global-epochs", help="Number of global (server) epochs", type=int, default=1000,
1112
required=False)
1213
parser.add_argument("-c", "--clients", help="Number of clients", type=int, default=100, required=False)

federated_learning.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ def model_fn():
2525

2626
weight_summarizer = fed_learn.FedAvg()
2727
server = fed_learn.Server(model_fn, weight_summarizer, args.clients, args.fraction, args.debug)
28+
29+
weight_path = args.weights_file
30+
if weight_path is not None:
31+
server.load_model_weights(weight_path)
32+
2833
server.update_client_train_params(client_train_params)
2934
server.create_clients()
3035
server.send_train_data()

0 commit comments

Comments
 (0)