kritsg commited on
Commit
7dd304e
·
1 Parent(s): c38a7bb

removing hardcoded aspects

Browse files
Files changed (3) hide show
  1. app.py +5 -1
  2. bayes/models.py +0 -1
  3. image_posterior.py +2 -2
app.py CHANGED
@@ -24,12 +24,16 @@ def get_image_data(image_name):
24
  if (image_name == "imagenet_diego.png"):
25
  image = get_dataset_by_name("imagenet_diego", get_label=False)
26
  model_and_data = process_imagenet_get_model(image)
 
 
 
 
27
 
28
  return image, model_and_data
29
 
30
 
31
  def segmentation_generation(image_name, c_width, n_top, n_gif_imgs):
32
- print("GRADIO INPUTS:", image_name, c_width, n_top, n_gif_imgs)
33
 
34
  cred_width = c_width
35
  n_top_segs = n_top
 
24
  if (image_name == "imagenet_diego.png"):
25
  image = get_dataset_by_name("imagenet_diego", get_label=False)
26
  model_and_data = process_imagenet_get_model(image)
27
+ elif (image_name == "mnist_9.png"):
28
+ image = get_dataset_by_name(image_name, get_label=False)
29
+ model_and_data = process_mnist_get_model(image)
30
+
31
 
32
  return image, model_and_data
33
 
34
 
35
  def segmentation_generation(image_name, c_width, n_top, n_gif_imgs):
36
+ print("Inputs Received:", image_name, c_width, n_top, n_gif_imgs)
37
 
38
  cred_width = c_width
39
  n_top_segs = n_top
bayes/models.py CHANGED
@@ -39,7 +39,6 @@ def get_xtrain(segs):
39
 
40
  def process_imagenet_get_model(data):
41
  """Gets wrapped imagenet model."""
42
- print("DATA RECEIVED FROM APP:\n", data)
43
  # Get the vgg16 model, used in the experiments
44
  model = models.vgg16(pretrained=True)
45
  model.eval()
 
39
 
40
  def process_imagenet_get_model(data):
41
  """Gets wrapped imagenet model."""
 
42
  # Get the vgg16 model, used in the experiments
43
  model = models.vgg16(pretrained=True)
44
  model.eval()
image_posterior.py CHANGED
@@ -64,10 +64,10 @@ def create_gif(explanation_blr, img_name, segments, image, n_images=20, n_max=5)
64
  plt.imshow(c_image, alpha=0.3)
65
  paths.append(os.path.join(tmpdirname, f"{i}.png"))
66
  plt.savefig(paths[-1])
67
- print("CREATING GIF NOW")
68
  # Save to gif
69
  # https://stackoverflow.com/questions/61716066/creating-an-animation-out-of-matplotlib-pngs
70
- # print(f"Saving gif to {save_loc}")
71
 
72
  ims = [imageio.imread(f) for f in paths]
73
  imageio.mimwrite(f'{img_name}_explanation.gif', ims)
 
64
  plt.imshow(c_image, alpha=0.3)
65
  paths.append(os.path.join(tmpdirname, f"{i}.png"))
66
  plt.savefig(paths[-1])
67
+
68
  # Save to gif
69
  # https://stackoverflow.com/questions/61716066/creating-an-animation-out-of-matplotlib-pngs
70
+ print(f"Saving gif to {img_name}_explanation.gif")
71
 
72
  ims = [imageio.imread(f) for f in paths]
73
  imageio.mimwrite(f'{img_name}_explanation.gif', ims)