Skip to content

Commit 71a6fce

Browse files
committed
fedavg test added
1 parent 3ebe538 commit 71a6fce

1 file changed

Lines changed: 33 additions & 0 deletions

File tree

tests/test_weight_summarizer.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import fed_learn
2+
import numpy as np
3+
import unittest
4+
5+
6+
class TestFedAvgAlgorithm(unittest.TestCase):
7+
8+
def setUp(self) -> None:
9+
self.weight_summarizer = fed_learn.FedAvg()
10+
11+
nb_clients = 3
12+
nb_weight_arrays = 6
13+
14+
self.all_clients_weights = []
15+
16+
for i in range(nb_clients):
17+
client_weight_arrays = []
18+
for k in range(nb_weight_arrays):
19+
rnd_weight_array = np.ones((8, 12))
20+
rnd_weight_array += i
21+
client_weight_arrays.append(rnd_weight_array)
22+
self.all_clients_weights.append(client_weight_arrays)
23+
24+
self.avg_weights = self.weight_summarizer.process(self.all_clients_weights)
25+
26+
def test_basic_averaging_mean(self):
27+
self.assertAlmostEqual(np.mean(self.avg_weights), 2.0)
28+
29+
def test_basic_averaging_min(self):
30+
self.assertAlmostEqual(np.min(self.avg_weights), 2.0)
31+
32+
def test_basic_averaging_max(self):
33+
self.assertAlmostEqual(np.max(self.avg_weights), 2.0)

0 commit comments

Comments
 (0)