Skip to content

Commit f5e5d27

Browse files
committed
tensorboard logging of the scalars and args
1 parent f40a257 commit f5e5d27

1 file changed

Lines changed: 12 additions & 3 deletions

File tree

federated_learning.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import os
33
from pathlib import Path
4-
4+
from swiss_army_tensorboard import tfboard_loggers
55
import numpy as np
66

77
import fed_learn
@@ -16,10 +16,14 @@
1616

1717
args_json_path = EXPERIMENT_FOLDER_PATH / "args.json"
1818
fed_learn.save_args_as_json(args, EXPERIMENT_FOLDER_PATH / args_json_path)
19+
tfboard_loggers.TFBoardTextLogger(EXPERIMENT_FOLDER_PATH).log_markdown("args", "```\n{0}\n```".format(
20+
json.dumps(args.__dict__, indent=4, sort_keys=True)), -1)
1921

2022
train_hist_path = EXPERIMENT_FOLDER_PATH / "fed_learn_global_test_results.json"
2123
global_weight_path = EXPERIMENT_FOLDER_PATH / "global_weights.h5"
2224

25+
tf_scalar_logger = tfboard_loggers.TFBoardScalarLogger(EXPERIMENT_FOLDER_PATH)
26+
2327
client_train_params = {"epochs": args.client_epochs, "batch_size": args.batch_size}
2428

2529

@@ -58,12 +62,17 @@ def model_fn():
5862

5963
epoch_mean_loss = np.mean(server.epoch_losses)
6064
server.global_train_losses.append(epoch_mean_loss)
65+
tf_scalar_logger.log_scalar("train_loss/client_mean_loss", server.global_train_losses[-1], epoch)
6166
print("Loss (client mean): {0}".format(server.global_train_losses[-1]))
6267

6368
global_test_results = server.test_global_model()
6469
print("--- Global test ---")
65-
for metric_name, value in global_test_results.items():
66-
print("{0}: {1}".format(metric_name, value))
70+
test_loss = global_test_results["loss"]
71+
test_acc = global_test_results["acc"]
72+
print("{0}: {1}".format("Loss", test_loss))
73+
print("{0}: {1}".format("Accuracy", test_acc))
74+
tf_scalar_logger.log_scalar("test_loss/global_loss", test_loss, epoch)
75+
tf_scalar_logger.log_scalar("test_acc/global_acc", test_acc, epoch)
6776

6877
with open(str(train_hist_path), 'w') as f:
6978
json.dump(server.global_test_metrics_dict, f)

0 commit comments

Comments
 (0)