thomen commited on
Commit
56fce04
1 Parent(s): bed9915

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -21
app.py CHANGED
@@ -1,31 +1,38 @@
 
1
  import gradio as gr
2
- import numpy as np
3
- from keras.models import load_model
4
  from PIL import Image
 
5
 
6
- model = load_model('pokemon-model_2_transferlearning.keras')
7
-
8
- def predict_pokemon(image):
9
- # Resize and preprocess the image
10
- image = Image.fromarray((image * 255).astype(np.uint8))
11
- image = image.resize((224, 224))
12
- image_array = np.asarray(image)
13
- image_array = image_array / 255.0
14
-
15
-
16
- prediction = model.predict(np.expand_dims(image_array, axis=0))
17
- predicted_class = np.argmax(prediction)
18
-
19
 
20
- pokemon_names = ['Pikachu', 'Charmander', 'Bulbasaur', ...]
21
- predicted_pokemon = pokemon_names[predicted_class]
 
22
 
23
- return predicted_pokemon
24
 
 
 
 
 
25
 
26
- input_component = gr.inputs.Image(shape=(224, 224))
 
 
 
 
27
 
 
28
 
29
- output_component = gr.outputs.Label(num_top_classes=1)
 
 
 
 
 
 
 
30
 
31
- gr.Interface(fn=predict_pokemon, inputs=input_component, outputs=output_component, title='Pokémon Classifier').launch()
 
 
1
+
2
  import gradio as gr
3
+ import tensorflow as tf
 
4
  from PIL import Image
5
+ import numpy as np
6
 
7
+ labels = ['Haunter', 'Gengar', 'Ditto', 'Vulpix']
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ def predict_pokemon_type(uploaded_file):
10
+ if uploaded_file is None:
11
+ return "No file uploaded.", None, "No prediction"
12
 
13
+ model = tf.keras.models.load_model('pokemon-model_2_transferlearning.keras')
14
 
15
+ # Load the image from the file path
16
+ with Image.open(uploaded_file) as img:
17
+ img = img.resize((150, 150))
18
+ img_array = np.array(img)
19
 
20
+ prediction = model.predict(np.expand_dims(img_array, axis=0))
21
+ confidences = {labels[i]: np.round(float(prediction[0][i]), 2) for i in range(len(labels))}
22
+
23
+ # Identify the most confident prediction
24
+ confidences = {labels[i]: np.round(float(prediction[0][i]), 2) for i in range(len(labels))}
25
 
26
+ return img, confidences
27
 
28
+ # Define the Gradio interface
29
+ iface = gr.Interface(
30
+ fn=predict_pokemon_type, # Function to process the input
31
+ inputs=gr.File(label="Upload File"), # File upload widget
32
+ outputs=["image", "text"], # Output types for image and text
33
+ title="Pokemon Classifier", # Title of the interface
34
+ description="Upload a picture of a Pokemon (preferably Cubone, Ditto, Psyduck, Snorlax, or Weedle) to see its type and confidence level. The trained model has an accuracy of 96%!" # Description of the interface
35
+ )
36
 
37
+ # Launch the interface
38
+ iface.launch()