Create `train`, `valid`, and `test` folders under `data`, each with a `cat` and a `dog` subfolder.
Copy 500 cat images into `train/cat` and 500 dog images into `train/dog`.
Copy 100 cat and 100 dog images into `valid`, and 50 of each into `test`.
Data source: https://www.kaggle.com/tongpython/cat-and-dog
#cat_dog.py
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Dense, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import confusion_matrix
import itertools
import shutil
import random
import os
import glob
import matplotlib.pyplot as plt
# Train on the GPU when one is present.
# Memory growth makes TensorFlow allocate GPU memory on demand instead of
# reserving it all up front. Looping over the (possibly empty) device list
# means a machine without a GPU falls back to the CPU instead of crashing
# with an IndexError on pysical_devices[0].
# The original variable spelling is kept so existing references still work.
pysical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu in pysical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)
"""
print(os.path.exists('C://Users//bob//keras//data'))
for c in random.sample(glob.glob('C://Users//bob//keras//data//cats//cat*'), 500):
shutil.move(c, 'C://Users//bob//keras//data//cats//train')
"""
# Root folder holding the train/valid/test splits; each split has a `cat`
# and a `dog` subfolder, matching the `classes` list below.
DATA_DIR = 'C://Users//bob//keras//data'


def _make_batches(split, shuffle=True):
    """Return a directory iterator over the ``split`` folder under DATA_DIR.

    Images are resized to 224x224 and passed through VGG16's
    ``preprocess_input``; labels are one-hot over ['cat', 'dog'] and
    served in batches of 32.
    """
    generator = ImageDataGenerator(
        preprocessing_function=tf.keras.applications.vgg16.preprocess_input)
    return generator.flow_from_directory(
        directory=os.path.join(DATA_DIR, split),
        target_size=(224, 224),
        classes=['cat', 'dog'],
        batch_size=32,
        shuffle=shuffle,
    )


train_batches = _make_batches('train')
valid_batches = _make_batches('valid')
# The test set is not shuffled so batch order stays stable, e.g. for lining
# predictions up with labels in a confusion matrix later.
test_batches = _make_batches('test', shuffle=False)

# Pull one batch of preprocessed training images and their one-hot labels.
imgs, labels = next(train_batches)
def plotImages(image_arr, num_images=10):
    """Show up to ``num_images`` images from ``image_arr`` in a single row.

    Args:
        image_arr: Sequence of images accepted by ``Axes.imshow``. Images
            already passed through VGG16 ``preprocess_input`` fall outside
            imshow's valid range, so matplotlib clips them (and warns).
        num_images: Number of subplot columns to draw; defaults to 10, the
            previously hard-coded value. Extra images are not shown.
    """
    # squeeze=False keeps `axes` a 2-D array even when num_images == 1,
    # so flatten() always works.
    _, axes = plt.subplots(1, num_images, figsize=(20, 20), squeeze=False)
    axes = axes.flatten()
    # Hide every axis up front so unused slots stay blank when image_arr
    # holds fewer than num_images images.
    for ax in axes:
        ax.axis('off')
    for img, ax in zip(image_arr, axes):
        ax.imshow(img)
    plt.tight_layout()
    plt.show()
# Preview the sampled training batch, then print its one-hot labels
# (one row per image, columns ordered as ['cat', 'dog']).
plotImages(image_arr=imgs)
print(labels)
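Images after VGG16 preprocessing. `preprocess_input` converts RGB to BGR and subtracts the ImageNet channel means, so the colors look distorted and `imshow` clips the values (hence the warnings in the log below):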
#logs
Found 1000 images belonging to 2 classes.
Found 200 images belonging to 2 classes.
Found 100 images belonging to 2 classes.
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
...
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
[[1. 0.]
[1. 0.]
[0. 1.]
...
[0. 1.]
[1. 0.]]
reference:
No comments:
Post a Comment