jaekookang commited on
Commit
bcaf154
β€’
1 Parent(s): dbb7b85
.ipynb_checkpoints/gradio_artist_classifier-checkpoint.py CHANGED
@@ -12,6 +12,7 @@ import seaborn as sns
12
 
13
  import io
14
  import json
 
15
  import skimage.io
16
  from loguru import logger
17
  from huggingface_hub import from_pretrained_keras
 
12
 
13
  import io
14
  import json
15
+ import numpy as np
16
  import skimage.io
17
  from loguru import logger
18
  from huggingface_hub import from_pretrained_keras
gradio_artist_classifier.py CHANGED
@@ -12,6 +12,7 @@ import seaborn as sns
12
 
13
  import io
14
  import json
 
15
  import skimage.io
16
  from loguru import logger
17
  from huggingface_hub import from_pretrained_keras
@@ -25,6 +26,7 @@ from gradcam_utils import get_img_4d_array, make_gradcam_heatmap, align_image_wi
25
  ARTIST_META = 'artist.json'
26
  TREND_META = 'trend.json'
27
  EXAMPLES = ['monet.jpg']
 
28
 
29
  # ---------- Logging ----------
30
  logger.add('app.log', mode='a')
@@ -68,7 +70,7 @@ def predict(input_image):
68
  img_4d_array,
69
  pred_idx=None)
70
  a_img_pil = align_image_with_heatmap(
71
- img_4d_array, a_heatmap, alpha=alpha, cmap='jet')
72
  a_img = np.asarray(a_img_pil).astype('float32')/255
73
  a_label = id2artist[a_pred_id]
74
  a_prob = a_pred_out[a_pred_id]
@@ -79,7 +81,7 @@ def predict(input_image):
79
  pred_idx=None)
80
 
81
  t_img_pil = align_image_with_heatmap(
82
- img_4d_array, t_heatmap, alpha=alpha, cmap='jet')
83
  t_img = np.asarray(t_img_pil).astype('float32')/255
84
  t_label = id2trend[t_pred_id]
85
  t_prob = t_pred_out[t_pred_id]
@@ -95,7 +97,7 @@ def predict(input_image):
95
  ax2.imshow(a_img)
96
  ax3.imshow(t_img)
97
 
98
- ax1.set_title(f'Artist: {artist}\nTrend: {trend}', ha='left', x=0, y=1.05)
99
  ax2.set_title(f'Artist Prediction:\n =>{a_label} ({a_prob:.2f})', ha='left', x=0, y=1.05)
100
  ax3.set_title(f'Trend Prediction:\n =>{t_label} ({t_prob:.2f})', ha='left', x=0, y=1.05)
101
  fig.tight_layout()
 
12
 
13
  import io
14
  import json
15
+ import numpy as np
16
  import skimage.io
17
  from loguru import logger
18
  from huggingface_hub import from_pretrained_keras
 
26
  ARTIST_META = 'artist.json'
27
  TREND_META = 'trend.json'
28
  EXAMPLES = ['monet.jpg']
29
+ ALPHA = 0.9
30
 
31
  # ---------- Logging ----------
32
  logger.add('app.log', mode='a')
 
70
  img_4d_array,
71
  pred_idx=None)
72
  a_img_pil = align_image_with_heatmap(
73
+ img_4d_array, a_heatmap, alpha=ALPHA, cmap='jet')
74
  a_img = np.asarray(a_img_pil).astype('float32')/255
75
  a_label = id2artist[a_pred_id]
76
  a_prob = a_pred_out[a_pred_id]
 
81
  pred_idx=None)
82
 
83
  t_img_pil = align_image_with_heatmap(
84
+ img_4d_array, t_heatmap, alpha=ALPHA, cmap='jet')
85
  t_img = np.asarray(t_img_pil).astype('float32')/255
86
  t_label = id2trend[t_pred_id]
87
  t_prob = t_pred_out[t_pred_id]
 
97
  ax2.imshow(a_img)
98
  ax3.imshow(t_img)
99
 
100
+ ax1.set_title(f'Input Image', ha='left', x=0, y=1.05)
101
  ax2.set_title(f'Artist Prediction:\n =>{a_label} ({a_prob:.2f})', ha='left', x=0, y=1.05)
102
  ax3.set_title(f'Trend Prediction:\n =>{t_label} ({t_prob:.2f})', ha='left', x=0, y=1.05)
103
  fig.tight_layout()