Manu101 commited on
Commit
47f1ab8
1 Parent(s): 190826e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -22
app.py CHANGED
@@ -40,31 +40,23 @@ def resize_image_pil(image, new_width, new_height):
40
 
41
  return resized
42
 
43
- def inference(input_image, transparency=0.5, target_layer_number=-1):
44
- input_image = resize_image_pil(input_image, 32, 32)
45
- input_image = np.array(input_image)
46
- org_img = input_image
47
- input_image = input_image.reshape((32, 32, 3))
48
  transform = transforms.ToTensor()
49
- input_image = transform(input_image)
50
- input_image = input_image.unsqueeze(0)
51
- outputs = model(input_image)
52
- softmax = torch.nn.Softmax(dim=0)
53
- o = softmax(outputs.flatten())
54
- confidences = {classes[i]: float(o[i]) for i in range(10)}
55
  _, prediction = torch.max(outputs, 1)
56
- target_layers = [model.layer2[target_layer_number]]
57
- cam = GradCAM(model= model, target_layers = target_layers)
58
- grayscale_cam = cam(input_tensor=input_image, targets=None)
59
  grayscale_cam = grayscale_cam[0, :]
60
- visualization = show_cam_on_image(
61
- org_img/255,
62
- grayscale_cam,
63
- use_rgb=True,
64
- image_weight = transparency
65
- )
66
-
67
- return classes[prediction[0].item(), visualization, confidences]
68
 
69
  demo = gr.Interface(
70
  inference,
 
40
 
41
  return resized
42
 
43
+ def inference(input_img, transparency):
 
 
 
 
44
  transform = transforms.ToTensor()
45
+ input_img = transform(input_img)
46
+ input_img = input_img.to(device)
47
+ input_img = input_img.unsqueeze(0)
48
+ outputs = model(input_img)
 
 
49
  _, prediction = torch.max(outputs, 1)
50
+ target_layers = [model.layer2[-2]]
51
+ cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
52
+ grayscale_cam = cam(input_tensor=input_img, targets=targets)
53
  grayscale_cam = grayscale_cam[0, :]
54
+ img = input_img.squeeze(0).to('cpu')
55
+ img = inv_normalize(img)
56
+ rgb_img = np.transpose(img, (1, 2, 0))
57
+ rgb_img = rgb_img.numpy()
58
+ visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
59
+ return classes[prediction[0].item()], visualization
 
 
60
 
61
  demo = gr.Interface(
62
  inference,