brxerq's picture
Update app.py
7331ba2
raw
history blame
1.23 kB
# -*- coding: utf-8 -*-
import gradio as gr
import numpy as np
import tensorflow_hub as hub
from tensorflow.keras.models import load_model
import cv2
# Keras needs to know which class the serialized name 'KerasLayer' refers to
# when it rebuilds the model, so map that name to the TensorFlow Hub layer.
custom_objects = dict(KerasLayer=hub.KerasLayer)
# Load the saved classifier; the path is relative to the working directory.
model = load_model(
    'bird_model.h5',
    custom_objects=custom_objects,
)
# Class labels, one per line of the label file. predict_image() indexes this
# list with the model's output index, so line order must match the model's
# class order; lines are stripped but never dropped, to keep indices aligned.
# The encoding is given explicitly so decoding does not depend on the
# platform's default locale.
with open('labelwithspace.txt', 'r', encoding='utf-8') as label_file:
    train_info = [line.strip() for line in label_file.read().splitlines()]
def predict_image(image):
    """Classify an image and return the label of the highest-scoring class.

    Args:
        image: The image as a NumPy array, as delivered by the Gradio image
            input (presumably height x width x channels -- confirm against
            the Gradio version in use).

    Returns:
        The entry of ``train_info`` at the index of the model's largest
        output score.

    Raises:
        ValueError: If ``image`` is None (no image was supplied).
    """
    if image is None:
        # Fail with a readable message instead of an opaque OpenCV error
        # when the function is called without an image.
        raise ValueError("No image provided; please upload an image.")

    # Resize to 224x224 and scale pixel values by 1/255 before inference.
    img = cv2.resize(image, (224, 224))
    img = img / 255.0
    # Add a leading batch axis for predict() and take the single result row.
    predictions = model.predict(img[np.newaxis, ...])[0]
    # Only the best class is reported, so argmax replaces sorting all scores.
    top_class = int(np.argmax(predictions))
    return train_info[top_class]
# Gradio UI: an image input requested at 224x224 feeding predict_image, with
# the returned label shown in a Label output component.
input_image = gr.inputs.Image(shape=(224, 224))
output_label = gr.outputs.Label()
# Build the interface and start serving it immediately.
gr.Interface(
    fn=predict_image,
    inputs=input_image,
    outputs=output_label,
    capture_session=True,
).launch()