File size: 2,948 Bytes
8ee89fb
56fce04
 
ce36f5c
8ee89fb
71d89ec
1b66cf8
ce36f5c
 
71d89ec
ce36f5c
 
 
71d89ec
ce36f5c
71d89ec
ce36f5c
71d89ec
ce36f5c
 
 
 
 
71d89ec
 
 
 
 
ce36f5c
71d89ec
ce36f5c
71d89ec
 
ce36f5c
71d89ec
56fce04
ce36f5c
 
71d89ec
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image

# Deserialize the trained Keras classifier a single time, at import, so that
# every prediction request reuses the same in-memory model object.
model_path = "pokemon-model_2_transferlearning.keras"
model = tf.keras.models.load_model(filepath=model_path)

# Resolution images are resized to before prediction; per the original code
# this must match the model's expected input size.
IMAGE_SIZE = (150, 150)

# Class labels. Position i names the i-th entry of the model's output vector,
# so this order must match the class order used during training.
# NOTE(review): both 'Mr. Mime' and 'MrMime' are present — presumably the
# training data had two folders for the same Pokémon. Do not merge them:
# removing one would shift every later label out of alignment with the model.
CLASS_NAMES = (
    'Abra', 'Aerodactyl', 'Alakazam', 'Arbok', 'Arcanine', 'Articuno',
    'Beedrill', 'Bellsprout', 'Blastoise', 'Bulbasaur', 'Butterfree',
    'Caterpie', 'Chansey', 'Charizard', 'Charmander', 'Charmeleon',
    'Clefable', 'Clefairy', 'Cloyster', 'Cubone', 'Dewgong', 'Diglett',
    'Ditto', 'Dodrio', 'Doduo', 'Dragonair', 'Dragonite', 'Dratini',
    'Drowzee', 'Dugtrio', 'Eevee', 'Ekans', 'Electabuzz', 'Electrode',
    'Exeggcute', 'Exeggutor', 'Farfetchd', 'Fearow', 'Flareon', 'Gastly',
    'Gengar', 'Geodude', 'Gloom', 'Golbat', 'Goldeen', 'Golduck',
    'Graveler', 'Grimer', 'Growlithe', 'Gyarados', 'Haunter', 'Hitmonchan',
    'Hitmonlee', 'Horsea', 'Hypno', 'Ivysaur', 'Jigglypuff', 'Jolteon',
    'Jynx', 'Kabutops', 'Kadabra', 'Kakuna', 'Kangaskhan', 'Kingler',
    'Koffing', 'Lapras', 'Lickitung', 'Machamp', 'Machoke', 'Machop',
    'Magikarp', 'Magmar', 'Magnemite', 'Magneton', 'Mankey', 'Marowak',
    'Meowth', 'Metapod', 'Mew', 'Mewtwo', 'Moltres', 'Mr. Mime', 'MrMime',
    'Nidoking', 'Nidoqueen', 'Nidorina', 'Nidorino', 'Ninetales', 'Oddish',
    'Omanyte', 'Omastar', 'Parasect', 'Pidgeot', 'Pidgeotto', 'Pidgey',
    'Pikachu', 'Pinsir', 'Poliwag', 'Poliwhirl', 'Poliwrath', 'Ponyta',
    'Porygon', 'Primeape', 'Psyduck', 'Raichu', 'Rapidash', 'Raticate',
    'Rattata', 'Rhydon', 'Rhyhorn', 'Sandshrew', 'Sandslash', 'Scyther',
    'Seadra', 'Seaking', 'Seel', 'Shellder', 'Slowbro', 'Slowpoke',
    'Snorlax', 'Spearow', 'Squirtle', 'Starmie', 'Staryu', 'Tangela',
    'Tauros', 'Tentacool', 'Tentacruel', 'Vaporeon', 'Venomoth', 'Venonat',
    'Venusaur', 'Victreebel', 'Vileplume', 'Voltorb', 'Vulpix', 'Wartortle',
    'Weedle', 'Weepinbell', 'Weezing', 'Wigglytuff', 'Zapdos', 'Zubat',
)


# Define the core prediction function
def predict_pokemon(image):
    """Classify a Pokémon image and return a score for every known class.

    Args:
        image: The uploaded image as a numpy array (the Gradio image input's
            numpy mode), or ``None`` when the user submitted without an image.

    Returns:
        A dict mapping each name in ``CLASS_NAMES`` to its score rounded to two
        decimals, in the shape ``gr.Label`` accepts. An empty dict when
        ``image`` is ``None``.

    Raises:
        ValueError: If the model's output size differs from ``len(CLASS_NAMES)``,
            which would otherwise attach scores to the wrong labels.
    """
    if image is None:
        # Nothing uploaded: return an empty result instead of crashing on .astype.
        return {}

    # Force 3-channel RGB so grayscale or RGBA (e.g. PNG with alpha) uploads
    # produce the same array shape as colour photos.
    # NOTE(review): assumes the model was trained on RGB input — confirm.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    pil_image = pil_image.resize(IMAGE_SIZE)
    # Pixels stay in 0-255, as in the original; presumably the model rescales
    # internally — confirm if predictions look off.
    batch = np.expand_dims(np.array(pil_image), axis=0)  # add batch dimension

    # verbose=0 avoids printing a progress bar to the server log per request.
    raw_scores = model.predict(batch, verbose=0)

    # NOTE(review): this assumes the model outputs raw logits. If its last layer
    # already applies softmax, this second softmax flattens the distribution
    # toward uniform — confirm against the model definition and drop it if so.
    probabilities = tf.nn.softmax(raw_scores).numpy()[0]

    if len(probabilities) != len(CLASS_NAMES):
        raise ValueError(
            f"Model produced {len(probabilities)} outputs but "
            f"{len(CLASS_NAMES)} class names are defined."
        )

    return {name: round(float(p), 2) for name, p in zip(CLASS_NAMES, probabilities)}

# Image upload component. The legacy `gr.inputs` / `gr.outputs` namespaces were
# deprecated in Gradio 3 and removed in Gradio 4 (they raise AttributeError
# there), so use the top-level components. type="numpy" hands predict_pokemon
# a numpy array, which is what it converts with Image.fromarray.
input_image = gr.Image(type="numpy")

# Build and launch the interface; share=True also requests a public share link.
# NOTE(review): the model file is named "...transferlearning", so the "MLP"
# wording in the description may be inaccurate — confirm the architecture.
iface = gr.Interface(
    fn=predict_pokemon,
    inputs=input_image,
    outputs=gr.Label(),
    description="A simple MLP classification model for Pokémon image classification.")
iface.launch(share=True)