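# NOTE: the following lines are residue from a web file viewer that was captured
# together with the source. They are not part of the program and are kept only
# as comments so the file remains valid Python.
# victor
# added txt
# 7095d4d
# raw
# history blame
# 2.15 kB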
from fastai.vision.all import *
import gradio as gr
# Pokémon class names, looked up by integer label index in get_y().
# The list holds 86 first-generation names (not the full set), one family per
# line where possible. The ORDER is significant: position i is the name for
# label i, so entries must not be reordered, inserted, or removed.
class_names = [
    "Bulbasaur", "Ivysaur", "Venusaur",
    "Charmander", "Charmeleon", "Charizard",
    "Squirtle", "Wartortle", "Blastoise",
    "Caterpie", "Metapod", "Butterfree",
    "Weedle", "Kakuna", "Beedrill",
    "Pidgey", "Pidgeotto", "Pidgeot",
    "Rattata", "Raticate",
    "Spearow", "Fearow",
    "Ekans", "Arbok",
    "Psyduck", "Golduck",
    "Machop", "Machoke", "Machamp",
    "Bellsprout", "Weepinbell", "Victreebel",
    "Tentacool", "Tentacruel",
    "Geodude", "Graveler", "Golem",
    "Ponyta", "Rapidash",
    "Magnemite", "Magneton",
    "Krabby", "Kingler",
    "Exeggcute", "Exeggutor",
    "Cubone", "Marowak",
    "Hitmonlee", "Hitmonchan",
    "Lickitung",
    "Koffing", "Weezing",
    "Rhyhorn", "Rhydon",
    "Chansey", "Tangela", "Kangaskhan",
    "Horsea", "Seadra",
    "Goldeen", "Seaking",
    "Staryu", "Starmie",
    "Mr. Mime", "Scyther", "Jynx",
    "Electabuzz", "Magmar", "Pinsir", "Tauros",
    "Magikarp", "Gyarados",
    "Lapras", "Ditto",
    "Eevee", "Vaporeon", "Jolteon", "Flareon",
    "Porygon",
    "Omanyte", "Omastar",
    "Kabuto", "Kabutops",
    "Aerodactyl",
    "Mew", "Mewtwo",
]
def get_x(item):
    """Return the image stored in *item*.

    Args:
        item: A mapping-like record exposing an ``'image'`` key
            (presumably one dataset row -- confirm against the training code).

    Returns:
        The value found under ``item['image']``, unchanged.

    Raises:
        KeyError: If *item* has no ``'image'`` key.
    """
    return item['image']
def get_y(item):
    """Return the class name for the label stored in *item*.

    Args:
        item: A mapping-like record exposing a ``'labels'`` key whose value is
            an index into the module-level ``class_names`` list.

    Returns:
        The entry of ``class_names`` at position ``item['labels']``.

    Raises:
        KeyError: If *item* has no ``'labels'`` key.
        IndexError: If the label index is outside ``class_names``.
    """
    # Map the numeric label to its human-readable class name.
    return class_names[item['labels']]
# Load the exported learner from 'poke_model.pkl' (path relative to the working directory).
# SECURITY: fastai's load_learner unpickles the file, so only load model files from a
# trusted source.
# NOTE(review): get_x / get_y above are not called anywhere in this file; presumably they
# are kept because the pickled learner refers to them by name -- confirm before removing.
learn = load_learner('poke_model.pkl')
# Class labels come from the vocabulary of the DataLoaders stored in the learner;
# classify_pokemon pairs them with the predicted probabilities.
categories = learn.dls.vocab # These are the class labels used in the trained model
# Prediction function used as the Gradio callback
def classify_pokemon(img):
    """Classify *img* and return a probability for every known category.

    Args:
        img: The image to classify, passed straight to ``learn.predict``.

    Returns:
        dict: Maps each entry of the module-level ``categories`` to the
        corresponding predicted probability as a plain ``float``.
    """
    # learn.predict returns three values; only the per-class probabilities
    # (the third one) are needed here.
    _, _, probs = learn.predict(img)
    # Convert to built-in floats so the result is a plain Python dict.
    return {category: float(prob) for category, prob in zip(categories, probs)}
# Gradio interface setup: one image input, one label output.
# NOTE(review): `gr.inputs` / `gr.outputs` are Gradio's legacy component namespaces.
# Newer Gradio releases expose the components as `gr.Image` / `gr.Label` instead and
# have dropped the legacy namespaces -- confirm the pinned Gradio version before upgrading.
image = gr.inputs.Image(shape=(128, 128)) # Image size should match your training data size (e.g., 128x128)
label = gr.outputs.Label()
# Wire the prediction function to the input and output components.
intf = gr.Interface(
fn=classify_pokemon,
inputs=image,
outputs=label,
)
# Launch the app. This runs at import time, so importing this module starts the server.
intf.launch(share=True) # share=True allows you to share the interface via a public link