# Spaces: Sleeping
# (Status text captured from the hosting page; not part of the program.)
from fastai.vision.all import *
import gradio as gr
# Pokémon class names, indexed by integer label.
#
# get_y() looks a name up as class_names[item['labels']], so the ORDER of this
# list is the label-index -> name mapping and must not be changed.
# NOTE(review): the list holds 86 names, not the full first-generation roster
# of 151 — presumably it mirrors the label set of the training dataset;
# confirm against the dataset before adding, removing, or reordering entries.
class_names = [
    "Bulbasaur", "Ivysaur", "Venusaur",
    "Charmander", "Charmeleon", "Charizard",
    "Squirtle", "Wartortle", "Blastoise",
    "Caterpie", "Metapod", "Butterfree",
    "Weedle", "Kakuna", "Beedrill",
    "Pidgey", "Pidgeotto", "Pidgeot",
    "Rattata", "Raticate", "Spearow",
    "Fearow", "Ekans", "Arbok",
    "Psyduck", "Golduck", "Machop",
    "Machoke", "Machamp", "Bellsprout",
    "Weepinbell", "Victreebel", "Tentacool",
    "Tentacruel", "Geodude", "Graveler",
    "Golem", "Ponyta", "Rapidash",
    "Magnemite", "Magneton", "Krabby",
    "Kingler", "Exeggcute", "Exeggutor",
    "Cubone", "Marowak", "Hitmonlee",
    "Hitmonchan", "Lickitung", "Koffing",
    "Weezing", "Rhyhorn", "Rhydon",
    "Chansey", "Tangela", "Kangaskhan",
    "Horsea", "Seadra", "Goldeen",
    "Seaking", "Staryu", "Starmie",
    "Mr. Mime", "Scyther", "Jynx",
    "Electabuzz", "Magmar", "Pinsir",
    "Tauros", "Magikarp", "Gyarados",
    "Lapras", "Ditto", "Eevee",
    "Vaporeon", "Jolteon", "Flareon",
    "Porygon", "Omanyte", "Omastar",
    "Kabuto", "Kabutops", "Aerodactyl",
    "Mew", "Mewtwo",
]
def get_x(item):
    """Return the image stored under the ``'image'`` key of ``item``.

    NOTE(review): nothing in this script calls this function directly;
    presumably it must stay defined (under this exact name) so that
    ``load_learner`` can restore a learner whose data pipeline referenced
    it — confirm before renaming or removing.

    Args:
        item: A mapping with an ``'image'`` entry.

    Returns:
        The value of ``item['image']``, unchanged.

    Raises:
        KeyError: If ``item`` has no ``'image'`` key.
    """
    return item['image']
def get_y(item):
    """Return the class name for the label index stored in ``item``.

    The value under ``item['labels']`` is used as an index into the
    module-level ``class_names`` list.

    NOTE(review): nothing in this script calls this function directly;
    presumably it must stay defined (under this exact name) so that
    ``load_learner`` can restore a learner whose data pipeline referenced
    it — confirm before renaming or removing.

    Args:
        item: A mapping with a ``'labels'`` entry usable as a list index.

    Returns:
        The entry of ``class_names`` at that index.

    Raises:
        KeyError: If ``item`` has no ``'labels'`` key.
        IndexError: If the label index is outside ``class_names``.
    """
    return class_names[item['labels']]
# Load the exported learner. The path is relative to the current working
# directory.
# NOTE(security): load_learner restores the learner by unpickling the file,
# which can execute arbitrary code — only load 'poke_model.pkl' from a
# trusted source.
learn = load_learner('poke_model.pkl')

# Class labels, taken from the vocabulary of the learner's dataloaders; these
# become the keys of the dictionary returned by classify_pokemon().
categories = learn.dls.vocab
# Prediction function used as the Gradio callback.
def classify_pokemon(img):
    """Classify ``img`` and return a probability for every known category.

    Args:
        img: The image to classify, passed straight to ``learn.predict``.

    Returns:
        A dict pairing each entry of ``categories`` with the corresponding
        predicted probability converted to a Python ``float``.
    """
    # learn.predict yields three values; only the per-class probabilities
    # (the third) are needed here.
    _, _, probs = learn.predict(img)
    return dict(zip(categories, map(float, probs)))
# Gradio interface setup.
#
# The legacy ``gr.inputs`` / ``gr.outputs`` namespaces were removed in
# Gradio 4. Use them when the installed Gradio still provides them (same
# behaviour as before) and fall back to the top-level components otherwise,
# so the app starts on either version.
if hasattr(gr, "inputs") and hasattr(gr, "outputs"):
    # Legacy API: the upload is resized to 128x128 before it reaches the
    # prediction function (the size should match the training data size).
    image = gr.inputs.Image(shape=(128, 128))
    label = gr.outputs.Label()
else:
    # Current API: ``gr.Image`` no longer accepts ``shape``.
    # NOTE(review): the image is therefore passed at its uploaded size;
    # presumably the learner's own transforms resize it at predict time —
    # confirm, or resize to 128x128 inside the prediction function.
    image = gr.Image()
    label = gr.Label()

# Wire the classifier to one image input and one label output.
intf = gr.Interface(
    fn=classify_pokemon,
    inputs=image,
    outputs=label,
)

# Launch the app; share=True requests a public share link.
intf.launch(share=True)