friesti1 commited on
Commit
be1f8d6
1 Parent(s): 63ece31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -15
app.py CHANGED
@@ -2,34 +2,36 @@ import gradio as gr
2
  import tensorflow as tf
3
  from PIL import Image
4
  import numpy as np
5
-
6
  labels = ['Cubone', 'Ditto', 'Psyduck', 'Snorlax', 'Weedle']
7
-
8
  def predict_pokemon_type(uploaded_file):
9
-
10
  if uploaded_file is None:
11
- return "No file uploaded."
12
-
13
  model = tf.keras.models.load_model('pokemon-model_transferlearning.keras')
 
14
  # Load the image from the file path
15
  with Image.open(uploaded_file) as img:
16
- img = img.resize((150, 150)).convert('RGB') # Convert image to RGB
17
  img_array = np.array(img)
18
-
19
  prediction = model.predict(np.expand_dims(img_array, axis=0))
20
  confidences = {labels[i]: np.round(float(prediction[0][i]), 2) for i in range(len(labels))}
21
-
22
- return confidences
23
-
24
-
 
 
25
  # Define the Gradio interface
26
  iface = gr.Interface(
27
  fn=predict_pokemon_type, # Function to process the input
28
  inputs=gr.File(label="Upload File"), # File upload widget
29
- outputs="text", # Output type
30
  title="Pokemon Classifier", # Title of the interface
31
- description="Upload a picture of a pokemon (preferably Cubone, Ditto, Psyduck, Snorlax or Weedle)" # Description of the interface
32
  )
33
-
34
  # Launch the interface
35
- iface.launch()
 
2
  import tensorflow as tf
3
  from PIL import Image
4
  import numpy as np
5
+
6
  labels = ['Cubone', 'Ditto', 'Psyduck', 'Snorlax', 'Weedle']
7
+
8
  def predict_pokemon_type(uploaded_file):
 
9
  if uploaded_file is None:
10
+ return "No file uploaded.", None, "No prediction"
11
+
12
  model = tf.keras.models.load_model('pokemon-model_transferlearning.keras')
13
+
14
  # Load the image from the file path
15
  with Image.open(uploaded_file) as img:
16
+ img = img.resize((150, 150))
17
  img_array = np.array(img)
18
+
19
  prediction = model.predict(np.expand_dims(img_array, axis=0))
20
  confidences = {labels[i]: np.round(float(prediction[0][i]), 2) for i in range(len(labels))}
21
+
22
+ # Identify the most confident prediction
23
+ confidences = {labels[i]: np.round(float(prediction[0][i]), 2) for i in range(len(labels))}
24
+
25
+ return img, confidences
26
+
27
  # Define the Gradio interface
28
  iface = gr.Interface(
29
  fn=predict_pokemon_type, # Function to process the input
30
  inputs=gr.File(label="Upload File"), # File upload widget
31
+ outputs=["image", "text"], # Output types for image and text
32
  title="Pokemon Classifier", # Title of the interface
33
+ description="Upload a picture of a Pokemon (preferably Cubone, Ditto, Psyduck, Snorlax, or Weedle) to see its type and confidence level." # Description of the interface
34
  )
35
+
36
  # Launch the interface
37
+ iface.launch()