Spaces:
Runtime error
Runtime error
Update pokemon-deploy.py
Browse files- pokemon-deploy.py +34 -16
pokemon-deploy.py
CHANGED
@@ -3,38 +3,56 @@ import tensorflow as tf
|
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
|
6 |
-
|
7 |
model_path = "pokemon-predict-model_transferlearning.keras"
|
8 |
model = tf.keras.models.load_model(model_path)
|
9 |
|
10 |
# Define the core prediction function
|
11 |
def predict_pokemon(image):
|
12 |
# Preprocess image
|
13 |
-
print(type(image))
|
14 |
image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image
|
15 |
-
image = image.resize((150, 150))
|
16 |
image = np.array(image)
|
17 |
-
image = np.expand_dims(image, axis=0)
|
18 |
-
|
19 |
# Predict
|
20 |
prediction = model.predict(image)
|
21 |
-
|
22 |
# Apply softmax to get probabilities for each class
|
23 |
prediction = tf.nn.softmax(prediction)
|
24 |
-
|
25 |
# Create a dictionary with the probabilities for each Pokemon
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
|
|
33 |
|
|
|
34 |
input_image = gr.Image()
|
35 |
iface = gr.Interface(
|
36 |
fn=predict_pokemon,
|
37 |
-
inputs=input_image,
|
38 |
outputs=gr.Label(),
|
39 |
-
description="A simple
|
40 |
-
|
|
|
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
|
6 |
+
# Load the model
|
7 |
model_path = "pokemon-predict-model_transferlearning.keras"
|
8 |
model = tf.keras.models.load_model(model_path)
|
9 |
|
10 |
# Define the core prediction function
|
11 |
def predict_pokemon(image):
|
12 |
# Preprocess image
|
|
|
13 |
image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image
|
14 |
+
image = image.resize((150, 150)) # Resize the image to 150x150
|
15 |
image = np.array(image)
|
16 |
+
image = np.expand_dims(image, axis=0) # Same as image[None, ...]
|
17 |
+
|
18 |
# Predict
|
19 |
prediction = model.predict(image)
|
20 |
+
|
21 |
# Apply softmax to get probabilities for each class
|
22 |
prediction = tf.nn.softmax(prediction)
|
23 |
+
|
24 |
# Create a dictionary with the probabilities for each Pokemon
|
25 |
+
pokemon_classes = [
|
26 |
+
'Abra', 'Aerodactyl', 'Alakazam', 'Arbok', 'Arcanine', 'Articuno', 'Beedrill', 'Bellsprout',
|
27 |
+
'Blastoise', 'Bulbasaur', 'Butterfree', 'Caterpie', 'Chansey', 'Charizard', 'Charmander',
|
28 |
+
'Charmeleon', 'Clefable', 'Clefairy', 'Cloyster', 'Cubone', 'Dewgong', 'Diglett', 'Ditto',
|
29 |
+
'Dodrio', 'Doduo', 'Dragonair', 'Dragonite', 'Dratini', 'Drowzee', 'Dugtrio', 'Eevee', 'Ekans',
|
30 |
+
'Electabuzz', 'Electrode', 'Exeggcute', 'Exeggutor', 'Farfetchd', 'Fearow', 'Flareon', 'Gastly',
|
31 |
+
'Gengar', 'Geodude', 'Gloom', 'Golbat', 'Goldeen', 'Golduck', 'Graveler', 'Grimer', 'Growlithe',
|
32 |
+
'Gyarados', 'Haunter', 'Hitmonchan', 'Hitmonlee', 'Horsea', 'Hypno', 'Ivysaur', 'Jigglypuff',
|
33 |
+
'Jolteon', 'Jynx', 'Kabutops', 'Kadabra', 'Kakuna', 'Kangaskhan', 'Kingler', 'Koffing', 'Lapras',
|
34 |
+
'Lickitung', 'Machamp', 'Machoke', 'Machop', 'Magikarp', 'Magmar', 'Magnemite', 'Magneton', 'Mankey',
|
35 |
+
'Marowak', 'Meowth', 'Metapod', 'Mew', 'Mewtwo', 'Moltres', 'Mr. Mime', 'MrMime', 'Nidoking', 'Nidoqueen',
|
36 |
+
'Nidorina', 'Nidorino', 'Ninetales', 'Oddish', 'Omanyte', 'Omastar', 'Parasect', 'Pidgeot', 'Pidgeotto',
|
37 |
+
'Pidgey', 'Pikachu', 'Pinsir', 'Poliwag', 'Poliwhirl', 'Poliwrath', 'Ponyta', 'Porygon', 'Primeape',
|
38 |
+
'Psyduck', 'Raichu', 'Rapidash', 'Raticate', 'Rattata', 'Rhydon', 'Rhyhorn', 'Sandshrew', 'Sandslash',
|
39 |
+
'Scyther', 'Seadra', 'Seaking', 'Seel', 'Shellder', 'Slowbro', 'Slowpoke', 'Snorlax', 'Spearow', 'Squirtle',
|
40 |
+
'Starmie', 'Staryu', 'Tangela', 'Tauros', 'Tentacool', 'Tentacruel', 'Vaporeon', 'Venomoth', 'Venonat',
|
41 |
+
'Venusaur', 'Victreebel', 'Vileplume', 'Voltorb', 'Vulpix', 'Wartortle', 'Weedle', 'Weepinbell', 'Weezing',
|
42 |
+
'Wigglytuff', 'Zapdos', 'Zubat'
|
43 |
+
]
|
44 |
+
|
45 |
+
probabilities = [np.round(float(prediction[0][i]), 2) for i in range(len(pokemon_classes))]
|
46 |
+
pokemon_probabilities = dict(zip(pokemon_classes, probabilities))
|
47 |
|
48 |
+
return pokemon_probabilities
|
49 |
|
50 |
+
# Interface setup
|
51 |
input_image = gr.Image()
|
52 |
iface = gr.Interface(
|
53 |
fn=predict_pokemon,
|
54 |
+
inputs=input_image,
|
55 |
outputs=gr.Label(),
|
56 |
+
description="A simple MLP classification model for image classification using the Pokémon dataset."
|
57 |
+
)
|
58 |
+
iface.launch(share=True)
|