File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 11from keras import backend as K
2- from keras import optimizers , losses , models
2+ from keras import optimizers , losses , models , layers
33from keras .applications .vgg16 import VGG16
44
55
6- def create_model (input_shape : tuple ,
7- nb_classes : int ,
8- optimizer = optimizers .Adam (lr = 0.001 ),
9- loss = losses .categorical_crossentropy ):
6+ def create_model (input_shape : tuple , nb_classes : int , init_with_imagenet : bool = False ):
7+ weights = None
8+ if init_with_imagenet :
9+ weights = "imagenet"
10+
1011 model = VGG16 (input_shape = input_shape ,
1112 classes = nb_classes ,
12- weights = 'imagenet' ,
13+ weights = weights ,
1314 include_top = False )
15+ x = model .layers [- 1 ].output
16+ x = layers .Flatten ()(x )
17+ x = layers .Dense (nb_classes )(x )
18+ x = layers .Softmax ()(x )
19+ model = models .Model (model .input , x )
20+
21+ loss = losses .categorical_crossentropy
22+ optimizer = optimizers .Adam (lr = 0.001 )
23+
1424 model .compile (optimizer , loss , metrics = ["accuracy" ])
1525 return model
1626
You can’t perform that action at this time.
0 commit comments