Skip to content

Commit fff74a1

Browse files
committed
shallow vgg16 model created for the cifar10 trainings
1 parent caa6f72 commit fff74a1

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

fed_learn/models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ def create_model(input_shape: tuple, nb_classes: int, init_with_imagenet: bool =
1212
classes=nb_classes,
1313
weights=weights,
1414
include_top=False)
15-
x = model.layers[-1].output
16-
x = layers.Flatten()(x)
15+
# "Shallow" VGG for Cifar10
16+
x = model.get_layer('block3_pool').output
17+
x = layers.Flatten(name='Flatten')(x)
18+
x = layers.Dense(512, activation='relu')(x)
1719
x = layers.Dense(nb_classes)(x)
1820
x = layers.Softmax()(x)
1921
model = models.Model(model.input, x)

0 commit comments

Comments
 (0)