MobileNet has about 4.2M parameters, while VGG16 has about 138M. That makes MobileNet roughly 32× smaller and around 10× faster than VGG16, with only a small drop in accuracy, which is why it is a good backbone for the cats-vs-dogs transfer-learning script below.
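Those parameter counts are easy to check directly from Keras. Here is a small sketch (not part of the main script below; it downloads both sets of ImageNet weights the first time it runs) that instantiates the two models and prints count_params() for each:

import tensorflow as tf

# Instantiate both ImageNet-pretrained models and compare their sizes.
mobilenet = tf.keras.applications.mobilenet.MobileNet()
vgg16 = tf.keras.applications.vgg16.VGG16()

print(f"MobileNet parameters: {mobilenet.count_params():,}")   # ~4.25 million
print(f"VGG16 parameters:     {vgg16.count_params():,}")       # ~138 million
print(f"Size ratio: {vgg16.count_params() / mobilenet.count_params():.1f}x")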
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Dense, Flatten, Conv2D, MaxPool2D, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import confusion_matrix
import itertools
import shutil
import random
import os
import glob
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
import numpy as np
#run on cpu
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
#train on GPU
physical_devices = tf.config.experimental.list_physical_devices('GPU')
#print("Num GPUs Available: ", len(physical_devices))
if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
train_batches = ImageDataGenerator(preprocessing_function=tf.keras.applications.mobilenet.preprocess_input) \
    .flow_from_directory(directory='C:/Users/bob/keras/data/train', target_size=(224, 224),
                         classes=['cat', 'dog'], batch_size=32)
valid_batches = ImageDataGenerator(preprocessing_function=tf.keras.applications.mobilenet.preprocess_input) \
    .flow_from_directory(directory='C:/Users/bob/keras/data/valid', target_size=(224, 224),
                         classes=['cat', 'dog'], batch_size=32)
test_batches = ImageDataGenerator(preprocessing_function=tf.keras.applications.mobilenet.preprocess_input) \
    .flow_from_directory(directory='C:/Users/bob/keras/data/test', target_size=(224, 224),
                         classes=['cat', 'dog'], batch_size=32, shuffle=False)
imgs, labels = next(train_batches)
def plotImages(image_arr):
    fig, axes = plt.subplots(1, 10, figsize=(20, 20))
    axes = axes.flatten()
    for img, ax in zip(image_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
#plotImages(imgs)
#print(labels)
mobilenet_model = tf.keras.applications.mobilenet.MobileNet()
mobilenet_model.summary()
#copy all layers from mobilenet except last one
model = Sequential()
for layer in mobilenet_model.layers[:-1]:
    model.add(layer)
#model.summary()
#freeze copied layers
for layer in model.layers:
    layer.trainable = False
#add last trainable layer
model.add(Dense(units=2, activation='softmax'))
model.summary()
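# Note (sketch, not part of the original run): the same transfer model can be
# built with the functional API instead of copying layers into a Sequential:
#   x = mobilenet_model.layers[-2].output
#   output = Dense(units=2, activation='softmax')(x)
#   model = tf.keras.Model(inputs=mobilenet_model.input, outputs=output)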
#train model
model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x=train_batches, validation_data=valid_batches, epochs=10, verbose=2)
test_imgs, test_labels = next(test_batches)
#plotImages(test_imgs)
#print(test_labels)
#print( test_batches.classes)
predictions = model.predict(x=test_batches, verbose=0)
cm = confusion_matrix(y_true=test_batches.classes, y_pred=np.argmax(predictions, axis=-1))
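# Optional sanity check: overall test accuracy is the trace of the confusion
# matrix divided by the total number of test samples, e.g.:
#   print('test accuracy:', np.trace(cm) / cm.sum())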
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()
cm_plot_labels = ['cat', 'dog']
plot_confusion_matrix(cm=cm, classes=cm_plot_labels, title='Confusion Matrix')
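The script imports load_model but never uses it, and saving the fine-tuned model is a natural last step. Here is a small sketch (the file names 'mobilenet_cats_vs_dogs.h5' and 'some_image.jpg' are placeholders, not from the original post) that saves the model, reloads it, and classifies a single image with the same MobileNet preprocessing used by the generators:

#save the fine-tuned model and reload it later
model.save('mobilenet_cats_vs_dogs.h5')
restored = load_model('mobilenet_cats_vs_dogs.h5')

#predict a single image with the same preprocessing as the generators
from tensorflow.keras.preprocessing import image
img = image.load_img('some_image.jpg', target_size=(224, 224))
x = tf.keras.applications.mobilenet.preprocess_input(np.expand_dims(image.img_to_array(img), axis=0))
probs = restored.predict(x)
print(dict(zip(['cat', 'dog'], probs[0])))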
--------------------------
#logs
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 224, 224, 3)] 0
_________________________________________________________________
conv1 (Conv2D) (None, 112, 112, 32) 864
_________________________________________________________________
conv1_bn (BatchNormalization (None, 112, 112, 32) 128
_________________________________________________________________
conv1_relu (ReLU) (None, 112, 112, 32) 0
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D) (None, 112, 112, 32) 288
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 112, 112, 32) 128
_________________________________________________________________
conv_dw_1_relu (ReLU) (None, 112, 112, 32) 0
_________________________________________________________________
conv_pw_1 (Conv2D) (None, 112, 112, 64) 2048
_________________________________________________________________
conv_pw_1_bn (BatchNormaliza (None, 112, 112, 64) 256
_________________________________________________________________
conv_pw_1_relu (ReLU) (None, 112, 112, 64) 0
_________________________________________________________________
conv_pad_2 (ZeroPadding2D) (None, 113, 113, 64) 0
_________________________________________________________________
conv_dw_2 (DepthwiseConv2D) (None, 56, 56, 64) 576
_________________________________________________________________
conv_dw_2_bn (BatchNormaliza (None, 56, 56, 64) 256
_________________________________________________________________
conv_dw_2_relu (ReLU) (None, 56, 56, 64) 0
_________________________________________________________________
conv_pw_2 (Conv2D) (None, 56, 56, 128) 8192
_________________________________________________________________
conv_pw_2_bn (BatchNormaliza (None, 56, 56, 128) 512
_________________________________________________________________
conv_pw_2_relu (ReLU) (None, 56, 56, 128) 0
_________________________________________________________________
conv_dw_3 (DepthwiseConv2D) (None, 56, 56, 128) 1152
_________________________________________________________________
conv_dw_3_bn (BatchNormaliza (None, 56, 56, 128) 512
_________________________________________________________________
conv_dw_3_relu (ReLU) (None, 56, 56, 128) 0
_________________________________________________________________
conv_pw_3 (Conv2D) (None, 56, 56, 128) 16384
_________________________________________________________________
conv_pw_3_bn (BatchNormaliza (None, 56, 56, 128) 512
_________________________________________________________________
conv_pw_3_relu (ReLU) (None, 56, 56, 128) 0
_________________________________________________________________
conv_pad_4 (ZeroPadding2D) (None, 57, 57, 128) 0
_________________________________________________________________
conv_dw_4 (DepthwiseConv2D) (None, 28, 28, 128) 1152
_________________________________________________________________
conv_dw_4_bn (BatchNormaliza (None, 28, 28, 128) 512
_________________________________________________________________
conv_dw_4_relu (ReLU) (None, 28, 28, 128) 0
_________________________________________________________________
conv_pw_4 (Conv2D) (None, 28, 28, 256) 32768
_________________________________________________________________
conv_pw_4_bn (BatchNormaliza (None, 28, 28, 256) 1024
_________________________________________________________________
conv_pw_4_relu (ReLU) (None, 28, 28, 256) 0
_________________________________________________________________
conv_dw_5 (DepthwiseConv2D) (None, 28, 28, 256) 2304
_________________________________________________________________
conv_dw_5_bn (BatchNormaliza (None, 28, 28, 256) 1024
_________________________________________________________________
conv_dw_5_relu (ReLU) (None, 28, 28, 256) 0
_________________________________________________________________
conv_pw_5 (Conv2D) (None, 28, 28, 256) 65536
_________________________________________________________________
conv_pw_5_bn (BatchNormaliza (None, 28, 28, 256) 1024
_________________________________________________________________
conv_pw_5_relu (ReLU) (None, 28, 28, 256) 0
_________________________________________________________________
conv_pad_6 (ZeroPadding2D) (None, 29, 29, 256) 0
_________________________________________________________________
conv_dw_6 (DepthwiseConv2D) (None, 14, 14, 256) 2304
_________________________________________________________________
conv_dw_6_bn (BatchNormaliza (None, 14, 14, 256) 1024
_________________________________________________________________
conv_dw_6_relu (ReLU) (None, 14, 14, 256) 0
_________________________________________________________________
conv_pw_6 (Conv2D) (None, 14, 14, 512) 131072
_________________________________________________________________
conv_pw_6_bn (BatchNormaliza (None, 14, 14, 512) 2048
_________________________________________________________________
conv_pw_6_relu (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv_dw_7 (DepthwiseConv2D) (None, 14, 14, 512) 4608
_________________________________________________________________
conv_dw_7_bn (BatchNormaliza (None, 14, 14, 512) 2048
_________________________________________________________________
conv_dw_7_relu (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv_pw_7 (Conv2D) (None, 14, 14, 512) 262144
_________________________________________________________________
conv_pw_7_bn (BatchNormaliza (None, 14, 14, 512) 2048
_________________________________________________________________
conv_pw_7_relu (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv_dw_8 (DepthwiseConv2D) (None, 14, 14, 512) 4608
_________________________________________________________________
conv_dw_8_bn (BatchNormaliza (None, 14, 14, 512) 2048
_________________________________________________________________
conv_dw_8_relu (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv_pw_8 (Conv2D) (None, 14, 14, 512) 262144
_________________________________________________________________
conv_pw_8_bn (BatchNormaliza (None, 14, 14, 512) 2048
_________________________________________________________________
conv_pw_8_relu (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv_dw_9 (DepthwiseConv2D) (None, 14, 14, 512) 4608
_________________________________________________________________
conv_dw_9_bn (BatchNormaliza (None, 14, 14, 512) 2048
_________________________________________________________________
conv_dw_9_relu (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv_pw_9 (Conv2D) (None, 14, 14, 512) 262144
_________________________________________________________________
conv_pw_9_bn (BatchNormaliza (None, 14, 14, 512) 2048
_________________________________________________________________
conv_pw_9_relu (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv_dw_10 (DepthwiseConv2D) (None, 14, 14, 512) 4608
_________________________________________________________________
conv_dw_10_bn (BatchNormaliz (None, 14, 14, 512) 2048
_________________________________________________________________
conv_dw_10_relu (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv_pw_10 (Conv2D) (None, 14, 14, 512) 262144
_________________________________________________________________
conv_pw_10_bn (BatchNormaliz (None, 14, 14, 512) 2048
_________________________________________________________________
conv_pw_10_relu (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv_dw_11 (DepthwiseConv2D) (None, 14, 14, 512) 4608
_________________________________________________________________
conv_dw_11_bn (BatchNormaliz (None, 14, 14, 512) 2048
_________________________________________________________________
conv_dw_11_relu (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv_pw_11 (Conv2D) (None, 14, 14, 512) 262144
_________________________________________________________________
conv_pw_11_bn (BatchNormaliz (None, 14, 14, 512) 2048
_________________________________________________________________
conv_pw_11_relu (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv_pad_12 (ZeroPadding2D) (None, 15, 15, 512) 0
_________________________________________________________________
conv_dw_12 (DepthwiseConv2D) (None, 7, 7, 512) 4608
_________________________________________________________________
conv_dw_12_bn (BatchNormaliz (None, 7, 7, 512) 2048
_________________________________________________________________
conv_dw_12_relu (ReLU) (None, 7, 7, 512) 0
_________________________________________________________________
conv_pw_12 (Conv2D) (None, 7, 7, 1024) 524288
_________________________________________________________________
conv_pw_12_bn (BatchNormaliz (None, 7, 7, 1024) 4096
_________________________________________________________________
conv_pw_12_relu (ReLU) (None, 7, 7, 1024) 0
_________________________________________________________________
conv_dw_13 (DepthwiseConv2D) (None, 7, 7, 1024) 9216
_________________________________________________________________
conv_dw_13_bn (BatchNormaliz (None, 7, 7, 1024) 4096
_________________________________________________________________
conv_dw_13_relu (ReLU) (None, 7, 7, 1024) 0
_________________________________________________________________
conv_pw_13 (Conv2D) (None, 7, 7, 1024) 1048576
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz (None, 7, 7, 1024) 4096
_________________________________________________________________
conv_pw_13_relu (ReLU) (None, 7, 7, 1024) 0
_________________________________________________________________
global_average_pooling2d (Gl (None, 1024) 0
_________________________________________________________________
reshape_1 (Reshape) (None, 1, 1, 1024) 0
_________________________________________________________________
dropout (Dropout) (None, 1, 1, 1024) 0
_________________________________________________________________
conv_preds (Conv2D) (None, 1, 1, 1000) 1025000
_________________________________________________________________
reshape_2 (Reshape) (None, 1000) 0
_________________________________________________________________
predictions (Activation) (None, 1000) 0
=================================================================
Total params: 4,253,864
Trainable params: 4,231,976
Non-trainable params: 21,888
_________________________________________________________________
Epoch 1/10
32/32 - 5s - loss: 1.6811 - accuracy: 0.6460 - val_loss: 0.7809 - val_accuracy: 0.7600
Epoch 2/10
32/32 - 2s - loss: 0.4972 - accuracy: 0.8330 - val_loss: 0.3682 - val_accuracy: 0.8600
Epoch 3/10
32/32 - 2s - loss: 0.2633 - accuracy: 0.9140 - val_loss: 0.2485 - val_accuracy: 0.9050
Epoch 4/10
32/32 - 2s - loss: 0.1945 - accuracy: 0.9290 - val_loss: 0.1954 - val_accuracy: 0.9300
Epoch 5/10
32/32 - 2s - loss: 0.1546 - accuracy: 0.9390 - val_loss: 0.1614 - val_accuracy: 0.9300
Epoch 6/10
32/32 - 2s - loss: 0.1304 - accuracy: 0.9530 - val_loss: 0.1418 - val_accuracy: 0.9300
Epoch 7/10
32/32 - 2s - loss: 0.1162 - accuracy: 0.9580 - val_loss: 0.1225 - val_accuracy: 0.9350
Epoch 8/10
32/32 - 2s - loss: 0.1050 - accuracy: 0.9640 - val_loss: 0.1165 - val_accuracy: 0.9400
Epoch 9/10
32/32 - 2s - loss: 0.0941 - accuracy: 0.9700 - val_loss: 0.1054 - val_accuracy: 0.9500
Epoch 10/10
32/32 - 2s - loss: 0.0846 - accuracy: 0.9720 - val_loss: 0.0994 - val_accuracy: 0.9600