Skip to content

Commit d7474ed

Browse files
committed
working GPU is configurable
1 parent 47f6ea5 commit d7474ed

2 files changed

Lines changed: 5 additions & 0 deletions

File tree

fed_learn/args_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def get_args():
1919
parser.add_argument("-b", "--batch-size", help="Batch Size", type=int, default=32, required=False)
2020
parser.add_argument("-ce", "--client-epochs", help="Number of epochs for the clients", type=int, default=1,
2121
required=False)
22+
parser.add_argument("-g", "--gpu", help="GPU to use (-1 is CPU)", type=int, default=0, required=False)
2223
args = parser.parse_args()
2324
return args
2425

federated_learning.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import os
23
from pathlib import Path
34

45
import numpy as np
@@ -7,6 +8,9 @@
78

89
args = fed_learn.get_args()
910

11+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
12+
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
13+
1014
EXPERIMENT_FOLDER_PATH = Path(__file__).resolve().parent / "experiments" / args.name
1115
EXPERIMENT_FOLDER_PATH.mkdir(parents=True, exist_ok=args.overwrite_experiment)
1216

0 commit comments

Comments
 (0)