Skip to content

Commit 3ebe538

Browse files
committed
fedavg algorithm is implemented
1 parent ba3615e commit 3ebe538

2 files changed

Lines changed: 10 additions & 3 deletions

File tree

fed_learn/fed_server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from keras import datasets
44

55
import fed_learn
6+
from fed_learn.weight_summarizer import WeightSummarizer
67

78

89
class Server:
9-
def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: fed_learn.WeightSummarizer):
10+
def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: WeightSummarizer):
1011
self.nb_clients = nb_clients
1112
self.weight_summarizer = weight_summarizer
1213

fed_learn/weight_summarizer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,11 @@ def __init__(self):
1616
super().__init__()
1717

1818
def process(self, client_weight_list: List[List[np.ndarray]]) -> List[np.ndarray]:
19-
# TODO: implement simple averaging
20-
return client_weight_list[0]
19+
weights_average = [np.zeros_like(w) for w in client_weight_list[0]]
20+
21+
for i in range(len(weights_average)):
22+
w = weights_average[i]
23+
for k in range(len(client_weight_list)):
24+
w += client_weight_list[k][i]
25+
w /= len(client_weight_list)
26+
return weights_average

0 commit comments

Comments
 (0)