|
1 | 1 | import json |
2 | 2 | import os |
3 | 3 | from pathlib import Path |
4 | | - |
| 4 | +from swiss_army_tensorboard import tfboard_loggers |
5 | 5 | import numpy as np |
6 | 6 |
|
7 | 7 | import fed_learn |
|
16 | 16 |
|
17 | 17 | args_json_path = EXPERIMENT_FOLDER_PATH / "args.json" |
18 | 18 | fed_learn.save_args_as_json(args, EXPERIMENT_FOLDER_PATH / args_json_path) |
| 19 | +tfboard_loggers.TFBoardTextLogger(EXPERIMENT_FOLDER_PATH).log_markdown("args", "```\n{0}\n```".format( |
| 20 | + json.dumps(args.__dict__, indent=4, sort_keys=True)), -1) |
19 | 21 |
|
20 | 22 | train_hist_path = EXPERIMENT_FOLDER_PATH / "fed_learn_global_test_results.json" |
21 | 23 | global_weight_path = EXPERIMENT_FOLDER_PATH / "global_weights.h5" |
22 | 24 |
|
| 25 | +tf_scalar_logger = tfboard_loggers.TFBoardScalarLogger(EXPERIMENT_FOLDER_PATH) |
| 26 | + |
23 | 27 | client_train_params = {"epochs": args.client_epochs, "batch_size": args.batch_size} |
24 | 28 |
|
25 | 29 |
|
@@ -58,12 +62,17 @@ def model_fn(): |
58 | 62 |
|
59 | 63 | epoch_mean_loss = np.mean(server.epoch_losses) |
60 | 64 | server.global_train_losses.append(epoch_mean_loss) |
| 65 | + tf_scalar_logger.log_scalar("train_loss/client_mean_loss", server.global_train_losses[-1], epoch) |
61 | 66 | print("Loss (client mean): {0}".format(server.global_train_losses[-1])) |
62 | 67 |
|
63 | 68 | global_test_results = server.test_global_model() |
64 | 69 | print("--- Global test ---") |
65 | | - for metric_name, value in global_test_results.items(): |
66 | | - print("{0}: {1}".format(metric_name, value)) |
| 70 | + test_loss = global_test_results["loss"] |
| 71 | + test_acc = global_test_results["acc"] |
| 72 | + print("{0}: {1}".format("Loss", test_loss)) |
| 73 | + print("{0}: {1}".format("Accuracy", test_acc)) |
| 74 | + tf_scalar_logger.log_scalar("test_loss/global_loss", test_loss, epoch) |
| 75 | + tf_scalar_logger.log_scalar("test_acc/global_acc", test_acc, epoch) |
67 | 76 |
|
68 | 77 | with open(str(train_hist_path), 'w') as f: |
69 | 78 | json.dump(server.global_test_metrics_dict, f) |
|
0 commit comments