"""Gradio web app that classifies Pokémon images with an exported fastai model.

Loads a fastai ``Learner`` from ``poke_model.pkl`` and serves it through a
Gradio interface that shows per-class probabilities for an uploaded image.
"""
from fastai.vision.all import *
import gradio as gr

# Pokémon class names used by ``get_y`` to map integer labels to names.
# NOTE(review): despite being described as the 1st-generation list, this holds
# a subset of Gen-1 Pokémon (86 names, not 151); presumably it mirrors the
# label order of the training dataset -- verify against the dataset.
class_names = [
    "Bulbasaur", "Ivysaur", "Venusaur", "Charmander", "Charmeleon", "Charizard",
    "Squirtle", "Wartortle", "Blastoise", "Caterpie", "Metapod", "Butterfree",
    "Weedle", "Kakuna", "Beedrill", "Pidgey", "Pidgeotto", "Pidgeot",
    "Rattata", "Raticate", "Spearow", "Fearow", "Ekans", "Arbok",
    "Psyduck", "Golduck", "Machop", "Machoke", "Machamp", "Bellsprout",
    "Weepinbell", "Victreebel", "Tentacool", "Tentacruel", "Geodude",
    "Graveler", "Golem", "Ponyta", "Rapidash", "Magnemite", "Magneton",
    "Krabby", "Kingler", "Exeggcute", "Exeggutor", "Cubone", "Marowak",
    "Hitmonlee", "Hitmonchan", "Lickitung", "Koffing", "Weezing", "Rhyhorn",
    "Rhydon", "Chansey", "Tangela", "Kangaskhan", "Horsea", "Seadra",
    "Goldeen", "Seaking", "Staryu", "Starmie", "Mr. Mime", "Scyther", "Jynx",
    "Electabuzz", "Magmar", "Pinsir", "Tauros", "Magikarp", "Gyarados",
    "Lapras", "Ditto", "Eevee", "Vaporeon", "Jolteon", "Flareon", "Porygon",
    "Omanyte", "Omastar", "Kabuto", "Kabutops", "Aerodactyl", "Mew", "Mewtwo",
]


# NOTE(review): get_x/get_y are not called in this file; presumably they are
# kept so that ``load_learner`` can unpickle a DataBlock that referenced them
# at training time. Do not remove or rename without re-exporting the model.
def get_x(item):
    """Return the image stored under the ``'image'`` key of a dataset item."""
    return item['image']


def get_y(item):
    """Map a dataset item's integer ``'labels'`` value to its class name."""
    return class_names[item['labels']]


# Load the exported model.
learn = load_learner('poke_model.pkl')

# Class labels in the order the model reports probabilities, taken from the
# vocabulary of the DataLoaders the model was exported with.
categories = learn.dls.vocab


def classify_pokemon(img):
    """Classify one image and return ``{class_name: probability}``.

    Args:
        img: The image supplied by the Gradio ``Image`` component, or ``None``
            when the user submits without uploading anything.

    Returns:
        A dict mapping every class in ``categories`` to its predicted
        probability as a ``float``, or ``None`` when ``img`` is ``None``
        (which clears the ``Label`` output instead of crashing).
    """
    if img is None:
        return None
    # Learner.predict returns (decoded label, class index, probabilities);
    # only the probability tensor is needed here.
    _, _, probs = learn.predict(img)
    return dict(zip(categories, map(float, probs)))


# Input component. The legacy ``gr.inputs``/``gr.outputs`` modules were
# removed in Gradio 4, so use the top-level components instead.
# Gradio 3.x resizes/crops uploads to ``shape``; Gradio 4 dropped that argument
# and rejects it with TypeError, so fall back to an unresized image there.
# NOTE(review): in the fallback, resizing relies on the model's own item
# transforms -- confirm training used a Resize to this size (e.g. 128x128).
try:
    image = gr.Image(shape=(128, 128))
except TypeError:
    image = gr.Image()
label = gr.Label()

# Gradio interface wiring the classifier to the image input and label output.
intf = gr.Interface(
    fn=classify_pokemon,
    inputs=image,
    outputs=label,
)

if __name__ == "__main__":
    # share=True additionally creates a temporary public link to the app.
    intf.launch(share=True)