Skip to content

Commit b985f6d

Browse files
committed
y_train is converted to onehot vectors
1 parent 71a6fce commit b985f6d

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

fed_learn/fed_server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Callable
2-
3-
from keras import datasets
2+
import numpy as np
3+
from keras import datasets, utils
44

55
import fed_learn
66
from fed_learn.weight_summarizer import WeightSummarizer
@@ -17,6 +17,7 @@ def __init__(self, model_fn: Callable, nb_clients: int, weight_summarizer: Weigh
1717
fed_learn.get_rid_of_the_models(model)
1818

1919
(x_train, y_train), (_, _) = datasets.cifar10.load_data()
20+
y_train = utils.to_categorical(y_train, len(np.unique(y_train)))
2021

2122
self.x_train = x_train
2223
self.y_train = y_train

0 commit comments

Comments
 (0)