thomen commited on
Commit
51b626b
·
verified ·
1 Parent(s): e138f09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -7
app.py CHANGED
@@ -3,10 +3,8 @@ import numpy as np
3
  from keras.models import load_model
4
  from PIL import Image
5
 
6
- # Load your Keras model
7
  model = load_model('pokemon-model_2_transferlearning.keras')
8
 
9
- # Define function to preprocess and predict on images
10
  def predict_pokemon(image):
11
  # Resize and preprocess the image
12
  image = Image.fromarray((image * 255).astype(np.uint8))
@@ -14,21 +12,20 @@ def predict_pokemon(image):
14
  image_array = np.asarray(image)
15
  image_array = image_array / 255.0
16
 
17
- # Make prediction
18
  prediction = model.predict(np.expand_dims(image_array, axis=0))
19
  predicted_class = np.argmax(prediction)
20
 
21
- # Example: Assuming you have a list of Pokémon names
22
  pokemon_names = ['Pikachu', 'Charmander', 'Bulbasaur', ...]
23
  predicted_pokemon = pokemon_names[predicted_class]
24
 
25
  return predicted_pokemon
26
 
27
- # Define input component for Gradio
28
  input_component = gr.inputs.Image(shape=(224, 224))
29
 
30
- # Define output component for Gradio
31
  output_component = gr.outputs.Label(num_top_classes=1)
32
 
33
- # Create the Gradio interface
34
  gr.Interface(fn=predict_pokemon, inputs=input_component, outputs=output_component, title='Pokémon Classifier').launch()
 
3
  from keras.models import load_model
4
  from PIL import Image
5
 
 
6
  model = load_model('pokemon-model_2_transferlearning.keras')
7
 
 
8
  def predict_pokemon(image):
9
  # Resize and preprocess the image
10
  image = Image.fromarray((image * 255).astype(np.uint8))
 
12
  image_array = np.asarray(image)
13
  image_array = image_array / 255.0
14
 
15
+
16
  prediction = model.predict(np.expand_dims(image_array, axis=0))
17
  predicted_class = np.argmax(prediction)
18
 
19
+
20
  pokemon_names = ['Pikachu', 'Charmander', 'Bulbasaur', ...]
21
  predicted_pokemon = pokemon_names[predicted_class]
22
 
23
  return predicted_pokemon
24
 
25
+
26
  input_component = gr.inputs.Image(shape=(224, 224))
27
 
28
+
29
  output_component = gr.outputs.Label(num_top_classes=1)
30
 
 
31
  gr.Interface(fn=predict_pokemon, inputs=input_component, outputs=output_component, title='Pokémon Classifier').launch()