Skip to content

Commit 53031c1

Browse files
authored
files are stored in the experiments folder (#4)
1 parent 96ea2cb commit 53031c1

4 files changed

Lines changed: 22 additions & 8 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.idea/
22
__pycache__
3+
experiments/
34
*.json

fed_learn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .args_helper import get_args
1+
from .args_helper import get_args, save_args_as_json
22
from .data_sampling import iid_data_indices, non_iid_data_indices
33
from .fed_client import Client
44
from .fed_server import Server

fed_learn/args_helper.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import argparse
2+
import json
23

34

45
def get_args():
56
parser = argparse.ArgumentParser()
7+
parser.add_argument("-n", "--name", help="Name of the experiment", type=str, required=True)
8+
parser.add_argument("-oe", "--overwrite-experiment", help="Overwrite existing experiment", action="store_true",
9+
required=False)
610
parser.add_argument("-e", "--global-epochs", help="Number of global (server) epochs", type=int, default=10,
711
required=False)
812
parser.add_argument("-c", "--clients", help="Number of clients", type=int, default=100, required=False)
@@ -16,3 +20,8 @@ def get_args():
1620
required=False)
1721
args = parser.parse_args()
1822
return args
23+
24+
25+
def save_args_as_json(args, path):
26+
with open(str(path), "w") as f:
27+
json.dump(args.__dict__, f, sort_keys=True, indent=4)

federated_learning.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
import json
2+
from pathlib import Path
23

34
import numpy as np
45

56
import fed_learn
67

78
args = fed_learn.get_args()
89

9-
nb_clients = args.clients
10-
client_fraction = args.fraction
11-
nb_global_epochs = args.global_epochs
12-
debug = args.debug
10+
EXPERIMENT_FOLDER_PATH = Path(__file__).resolve().parent / "experiments" / args.name
11+
EXPERIMENT_FOLDER_PATH.mkdir(parents=True, exist_ok=args.overwrite_experiment)
12+
13+
args_json_path = EXPERIMENT_FOLDER_PATH / "args.json"
14+
fed_learn.save_args_as_json(args, EXPERIMENT_FOLDER_PATH / args_json_path)
15+
16+
train_hist_path = EXPERIMENT_FOLDER_PATH / "fed_learn_global_test_results.json"
1317

1418
client_train_params = {"epochs": args.client_epochs, "batch_size": args.batch_size}
1519

@@ -19,12 +23,12 @@ def model_fn():
1923

2024

2125
weight_summarizer = fed_learn.FedAvg()
22-
server = fed_learn.Server(model_fn, weight_summarizer, nb_clients, client_fraction, debug)
26+
server = fed_learn.Server(model_fn, weight_summarizer, args.clients, args.fraction, args.debug)
2327
server.update_client_train_params(client_train_params)
2428
server.create_clients()
2529
server.send_train_data()
2630

27-
for epoch in range(nb_global_epochs):
31+
for epoch in range(args.global_epochs):
2832
print("Global Epoch {0} is starting".format(epoch))
2933
server.init_for_new_epoch()
3034
selected_clients = server.select_clients()
@@ -51,7 +55,7 @@ def model_fn():
5155
for metric_name, value in global_test_results.items():
5256
print("{0}: {1}".format(metric_name, value))
5357

54-
with open("fed_learn_global_test_results.json", 'w') as f:
58+
with open(str(train_hist_path), 'w') as f:
5559
json.dump(server.global_test_metrics_dict, f)
5660

5761
print("_" * 30)

0 commit comments

Comments
 (0)