Skip to content

Commit f64623c

Browse files
committed
images are minmax scaled
1 parent e60fe2a commit f64623c

2 files changed

Lines changed: 12 additions & 4 deletions

File tree

fed_learn/fed_server.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Callable
2+
23
import numpy as np
34
from keras import datasets, utils
45

@@ -7,7 +8,9 @@
78

89

910
class Server:
10-
def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: WeightSummarizer,
11+
def __init__(self, model_fn: Callable,
12+
nb_clients: int,
13+
weight_summarizer: WeightSummarizer,
1114
only_debugging: bool = True):
1215
self.nb_clients = nb_clients
1316
self.weight_summarizer = weight_summarizer
@@ -17,13 +20,17 @@ def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: Weigh
1720
self.model_weights = model.get_weights()
1821
fed_learn.get_rid_of_the_models(model)
1922

20-
(x_train, y_train), (_, _) = datasets.cifar10.load_data()
23+
(x_train, y_train), _ = datasets.cifar10.load_data()
2124

2225
if only_debugging:
26+
# TODO: remove me
2327
x_train = x_train[:100]
2428
y_train = y_train[:100]
2529

30+
# TODO: separate preprocessor for the data transformations
2631
y_train = utils.to_categorical(y_train, len(np.unique(y_train)))
32+
x_train = x_train.astype(np.float32)
33+
x_train /= 255.0
2734

2835
self.x_train = x_train
2936
self.y_train = y_train
@@ -40,6 +47,7 @@ def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: Weigh
4047
"shuffle": True}
4148

4249
if only_debugging:
50+
# TODO: remove me
4351
self.train_dict["epochs"] = 1
4452

4553
def _generate_data_indices(self):

single_model_learn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
model = fed_learn.create_model((32, 32, 3), 10)
88

9-
(x_train, y_train), (_, _) = datasets.cifar10.load_data()
9+
(x_train, y_train), _ = datasets.cifar10.load_data()
1010
y_train = utils.to_categorical(y_train, len(np.unique(y_train)))
1111
x_train = x_train.astype(np.float32)
12-
x_train = x_train / 255.0
12+
x_train /= 255.0
1313

1414
model.fit(x_train, y_train, batch_size=32, epochs=20, verbose=1)

0 commit comments

Comments
 (0)