# test-model / cat&dog.py
# Author: zegoop
# Last commit: ea3d8c3 ("Update cat&dog.py")
# -*- coding: utf-8 -*-
"""Cat&Dogs.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/17AFfKN67SFvxF7FdjjugeJGIa000SGU8
"""
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# NOTE: despite the file name, this trains a 10-class CIFAR-10 classifier,
# not a binary cat-vs-dog model (cat is class 3, dog is class 5 below).

# CIFAR-10 label index -> human-readable class name.
CLASS_NAMES = [
    "airplane",    # 0
    "automobile",  # 1
    "bird",        # 2
    "cat",         # 3
    "deer",        # 4
    "dog",         # 5
    "frog",        # 6
    "horse",       # 7
    "ship",        # 8
    "truck",       # 9
]
EPOCHS = 100      # training epochs (unchanged from the original notebook)
NUM_SAMPLES = 5   # number of random test images to visualise

# Load the datasets. Per the Keras docs, images are uint8 arrays of shape
# (N, 32, 32, 3) and labels are integer class indices of shape (N, 1).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Pre-process: scale pixel values from [0, 255] to [0, 1]. (An L2
# normalisation along one image axis, as used previously, is not a pixel
# scaling and made the images impossible to display correctly later.)
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Define the model: flatten each image and feed it through two dense layers.
# The output layer uses softmax so the 10 outputs form a probability
# distribution, which is what sparse_categorical_crossentropy (with the
# default from_logits=False) expects for single-label multi-class targets.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation=tf.nn.leaky_relu),
    tf.keras.layers.Dense(128, activation=tf.nn.leaky_relu),
    tf.keras.layers.Dense(len(CLASS_NAMES), activation=tf.nn.softmax),
])

# Compile the model; labels are integer indices, hence the "sparse" loss.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train the model.
model.fit(x_train, y_train, epochs=EPOCHS)

# Evaluate the model on the held-out test set.
val_loss, val_acc = model.evaluate(x_test, y_test)
print(val_loss)
print(val_acc)

# Pick NUM_SAMPLES distinct random test images and predict only for those
# (no need to run inference over the whole test set).
indices = np.random.choice(len(x_test), size=NUM_SAMPLES, replace=False)
images = x_test[indices]
true_labels = y_test[indices].ravel()
predictions = model.predict(images)

# Plot every sample in a single figure: image, predicted class with its
# softmax probability as the title, and the ground-truth class underneath.
plt.figure(figsize=(3 * NUM_SAMPLES, 3))
for i, (image, prediction, true_label) in enumerate(
        zip(images, predictions, true_labels)):
    label = int(np.argmax(prediction))
    probability = prediction[label]

    plt.subplot(1, NUM_SAMPLES, i + 1)
    # Float RGB images in [0, 1] can be passed to imshow directly.
    plt.imshow(image)
    plt.xticks([])
    plt.yticks([])
    plt.title("Prediction: {} ({:.2f})".format(CLASS_NAMES[label], probability))
    plt.xlabel("True: {}".format(CLASS_NAMES[int(true_label)]))

plt.tight_layout()
# Show the figure once, after all subplots have been drawn.
plt.show()