# Hugging Face Spaces app (page status at capture time: Sleeping)
from fastai.vision.all import * | |
import gradio as gr | |
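# get_x / get_y are presumably the DataBlock getters used at training time; they are
# redefined here so load_learner can unpickle the exported model, which refers to them
# by name. NOTE: the list of 1st-generation Pokémon `class_names` that get_y indexes is
# NOT defined in this file; nothing in this file calls get_y.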
def get_x(item):
    """Return the image stored under the ``'image'`` key of a dataset record.

    NOTE(review): presumably the ``get_x`` of the training DataBlock; kept at
    module level so the pickled Learner can resolve it by name when loaded.
    """
    picture = item['image']
    return picture
def get_y(item):
    """Translate a record's ``'labels'`` index into its class name.

    The index is looked up in the module-level ``class_names`` sequence.
    NOTE(review): ``class_names`` is not defined in this file, so calling this
    raises NameError unless that name is provided elsewhere. Nothing in this
    file calls it; it appears to exist only so the exported Learner unpickles.
    """
    label_index = item['labels']
    return class_names[label_index]
# Restore the exported fastai Learner. The relative path is resolved against
# the current working directory.
# NOTE(review): load_learner unpickles this file, so only load trusted exports;
# the unpickling presumably also needs get_x / get_y above to be defined first.
learn = load_learner(fname='poke_model.pkl')

# Class labels taken from the vocab of the DataLoaders the model was trained
# with. classify_pokemon zips these with the predicted probabilities, so they
# are assumed to be in the model's output order.
categories = learn.dls.vocab
# Define the function for prediction
def classify_pokemon(img):
    """Classify one image and return a probability for every class.

    Args:
        img: The image supplied by the ``gr.Image(type='pil')`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        A dict mapping each label in ``categories`` to its probability as a
        plain ``float``, the format ``gr.Label`` displays. The dict is empty
        when ``img`` is ``None``.
    """
    # Gradio passes None for an empty image input. learn.predict cannot handle
    # that, so return an empty result, which clears the Label, instead of raising.
    if img is None:
        return {}
    # learn.predict returns (decoded label, label index, probabilities); only
    # the probabilities are needed.
    _, _, probs = learn.predict(img)
    # float() turns each (presumably 0-d tensor) probability into a Python float.
    return dict(zip(categories, map(float, probs)))
# Gradio interface setup
title = "Pokémon first gen classifier"
description = "Based on the famous scene (Who's that Pokémon) of the Pokémon TV show, this neural network accurately classifies a Pokémon image."
# type='pil' makes Gradio hand classify_pokemon a PIL image. No resizing is done
# here; learn.predict runs the image through the Learner's own item transforms.
# NOTE(review): assumes the exported Learner includes a resize transform.
image = gr.Image(type='pil')
# Displays the {class name: probability} dict returned by classify_pokemon.
label = gr.Label()
# Clickable example input(s); the relative path is resolved against the working directory.
examples = ['zapdos.jpg']

# Set up Gradio interface
intf = gr.Interface(
    fn=classify_pokemon,
    inputs=image,
    outputs=label,
    examples=examples,
    title=title,
    description=description,
)

# Launch the app. inline=False stops the UI from being embedded in a notebook
# output cell. Add share=True to the call to also get a temporary public link.
intf.launch(inline=False)