|
import gradio as gr |
|
import tensorflow as tf |
|
import tensorflow_hub as hub |
|
from tensorflow.keras.preprocessing import image |
|
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input |
|
import numpy as np |
|
|
|
|
|
# Pretrained MobileNetV2 ImageNet classifier from TF Hub, wrapped as a Keras
# layer that accepts a batch of 224x224 images with 3 channels.
model_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
model = hub.KerasLayer(model_url, input_shape=(224, 224, 3))

# Human-readable class names, one per line; classify_image indexes this array
# with the model's predicted class indices.
# NOTE(review): presumably 1001 entries ("background" first) to line up with
# the model's 1001 logits — confirm against the downloaded file.
labels_path = tf.keras.utils.get_file(
    'ImageNetLabels.txt',
    'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt',
)
# Use a context manager so the file handle is closed deterministically
# instead of being left for the garbage collector.
with open(labels_path, encoding="utf-8") as labels_file:
    imagenet_labels = np.array(labels_file.read().splitlines())
|
|
|
|
|
def classify_image(img, *, top_k=3):
    """Classify an image and return the ``top_k`` most likely ImageNet labels.

    Args:
        img: PIL image supplied by the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.
        top_k: Number of highest-probability classes to return. Keyword-only,
            so Gradio still calls this function with the image alone.

    Returns:
        Dict mapping class name to softmax probability (a float), which is
        the format ``gr.Label`` expects. An empty dict when ``img`` is None.
    """
    if img is None:
        return {}

    # Force 3-channel RGB: uploads may be RGBA, grayscale or palette images,
    # and the model was built with input_shape=(224, 224, 3).
    img = img.convert("RGB").resize((224, 224))

    # TF Hub image modules expect float pixel values in [0, 1]. Keras'
    # mobilenet_v2.preprocess_input must NOT be used here: it maps pixels to
    # [-1, 1], the convention for keras.applications models, not Hub modules.
    img_array = image.img_to_array(img) / 255.0
    img_array = np.expand_dims(img_array, axis=0)  # add the batch dimension

    # The model emits logits; softmax converts them to probabilities.
    # Row 0 is the single image in the batch.
    probabilities = tf.nn.softmax(model(img_array))[0].numpy()

    # Indices of the top_k probabilities, highest first.
    top_indices = np.argsort(probabilities)[::-1][:top_k]

    # gr.Label expects numeric confidences, so return real floats rather
    # than pre-formatted strings.
    return {str(imagenet_labels[i]): float(probabilities[i]) for i in top_indices}
|
|
|
|
|
# Gradio web UI: one PIL image upload in, a label/confidence panel out.
# NOTE(review): the title and description promise fruit/vegetable
# classification, but classify_image uses a generic ImageNet MobileNetV2
# model and will return any of its ImageNet classes — consider rewording.
iface = gr.Interface(
    classify_image,
    inputs=gr.Image(type="pil"),
    outputs="label",
    title="Fruit and Vegetable Classification",
    description=(
        "Upload an image of a fruit or vegetable, "
        "and the model will classify it."
    ),
)

# Start the Gradio web server for the interface defined above.
iface.launch()
|
|