# a-guy-from-burma's picture
# Update app.py
# 357f978 verified
import gradio as gr
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
import numpy as np
# Load the MobileNet V2 ImageNet classifier from TensorFlow Hub.
# NOTE: this Hub module expects float RGB pixels in [0, 1] and returns
# unnormalized logits (softmax is applied in classify_image).
model_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
model = hub.KerasLayer(model_url, input_shape=(224, 224, 3))  # (height, width, channels) for MobileNet V2

# Download the ImageNet label list (one label per line) matching the model's
# output classes.
labels_path = tf.keras.utils.get_file(
    'ImageNetLabels.txt',
    'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt',
)
# Use a context manager so the file handle is closed instead of leaked, and
# pin the encoding rather than depending on the platform locale.
with open(labels_path, encoding="utf-8") as labels_file:
    imagenet_labels = np.array(labels_file.read().splitlines())
def classify_image(img):
    """Classify an uploaded image and return its top-3 ImageNet labels.

    Args:
        img: A PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        dict[str, float]: Mapping of label name -> softmax probability for the
        three highest-scoring classes, in the form Gradio's ``"label"`` output
        component expects (float confidences).

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image.")

    # Force 3 channels: RGBA / grayscale / palette uploads would otherwise give
    # 4 or 1 channels and not match the model's (224, 224, 3) input.
    img = img.convert("RGB").resize((224, 224))

    # The TF Hub MobileNet V2 module expects float pixels in [0, 1].
    # keras' mobilenet_v2.preprocess_input rescales to [-1, 1] instead (the
    # Keras Applications convention), which would skew this model's outputs.
    img_array = image.img_to_array(img) / 255.0
    img_array = np.expand_dims(img_array, axis=0)  # add batch dimension

    # The model returns logits; softmax converts them to probabilities.
    logits = model(img_array)
    probabilities = tf.nn.softmax(logits)[0]

    top = tf.math.top_k(probabilities, k=3)  # sorted by descending score
    results = {}
    for index, score in zip(top.indices.numpy(), top.values.numpy()):
        # Label names are not guaranteed unique; setdefault keeps the first
        # (highest-scoring) occurrence rather than overwriting it.
        results.setdefault(str(imagenet_labels[index]), float(score))
    return results
# Build the web UI: an uploaded PIL image is passed to classify_image, and its
# label -> confidence mapping is rendered by Gradio's "label" output component.
# NOTE(review): the model is a general ImageNet classifier, so predictions are
# not restricted to fruits and vegetables despite the title.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="label",
    title="Fruit and Vegetable Classification",
    description=(
        "Upload an image of a fruit or vegetable, and the model will classify it."
    ),
)

# Serve the app.
iface.launch()