Harshithtd committed on
Commit
b1c60a9
1 Parent(s): a188584

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import CLIPProcessor, CLIPModel
3
+ from PIL import Image
4
+ import torch
5
+ import numpy as np
6
+ import cv2
7
+
8
# Pre-trained CLIP checkpoint; the processor and the model are both loaded
# once, at import time, from the same checkpoint identifier.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
11
+
12
def apply_gradcam(image, text):
    """Blend a Grad-CAM style heatmap for ``text`` over ``image``.

    The image/text cosine similarity produced by the module-level CLIP
    ``model`` is differentiated with respect to the vision tower's
    penultimate hidden state, and the per-patch relevance is rendered as a
    JET colour map on top of the input image.

    Args:
        image: A PIL image (any mode; it is converted to RGB internally).
        text: The text description to compare the image against.

    Returns:
        An RGB ``uint8`` NumPy array with the same width and height as
        ``image``: 60% original image, 40% heatmap.

    Raises:
        ValueError: If the patch tokens cannot be arranged on a square grid.
    """
    # Normalise to 3-channel RGB so the final blend with the 3-channel
    # heatmap has matching shapes (RGBA / grayscale uploads would not).
    rgb_image = image.convert("RGB")

    inputs = processor(
        text=[text],
        images=rgb_image,
        return_tensors="pt",
        padding=True,
        truncation=True,  # keep long prompts within the text encoder's limit
    )
    outputs = model(**inputs, output_hidden_states=True)

    similarity = torch.nn.functional.cosine_similarity(
        outputs.image_embeds, outputs.text_embeds
    )

    # Use the hidden state that feeds the final vision encoder layer. The
    # image embedding is pooled from the class token only, so the patch
    # tokens of the very last hidden state would receive no gradient.
    # NOTE(review): assumes layout (batch, 1 + num_patches, hidden) with the
    # class token at index 0 -- confirm against the installed transformers.
    tokens = outputs.vision_model_output.hidden_states[-2]

    # autograd.grad returns the gradient directly instead of accumulating
    # into every model parameter's .grad across requests.
    gradients = torch.autograd.grad(similarity.sum(), tokens)[0]

    patch_activations = tokens[0, 1:, :].detach()
    patch_gradients = gradients[0, 1:, :]

    # Grad-CAM: weight each channel by its mean gradient over all patches,
    # sum over channels, and keep only positive evidence.
    channel_weights = patch_gradients.mean(dim=0)
    cam = torch.relu((patch_activations * channel_weights).sum(dim=-1))

    num_patches = cam.shape[0]
    side = int(round(num_patches ** 0.5))
    if side * side != num_patches:
        raise ValueError(
            f"Cannot arrange {num_patches} patch tokens on a square grid"
        )

    heatmap = cam.reshape(side, side).detach().cpu().numpy()

    # Scale to [0, 1]; an all-zero map is left untouched to avoid 0 / 0.
    peak = float(heatmap.max())
    if peak > 0:
        heatmap = heatmap / peak

    # PIL's size and cv2.resize's dsize are both (width, height).
    width, height = rgb_image.size
    heatmap = cv2.resize(heatmap, (width, height))
    heatmap = np.uint8(255 * np.clip(heatmap, 0.0, 1.0))

    # applyColorMap yields BGR; convert so it matches the RGB input image.
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    superimposed_img = cv2.addWeighted(np.array(rgb_image), 0.6, heatmap, 0.4, 0)
    return superimposed_img
37
+
38
def highlight_image(image, text):
    """Gradio callback: return ``image`` with the region matching ``text`` highlighted.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.
        text: The text description entered by the user.

    Returns:
        A PIL image built from the array returned by ``apply_gradcam``.

    Raises:
        gr.Error: If no image was uploaded or the description is blank, so
            the user sees a readable message instead of an internal error.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not text or not text.strip():
        raise gr.Error("Please provide a text description.")

    highlighted_image = apply_gradcam(image, text)
    return Image.fromarray(highlighted_image)
41
+
42
# Assemble the Gradio UI: one PIL image plus a text box go in, one PIL image
# comes out, and highlight_image does the work in between.
image_input = gr.Image(type="pil")
text_input = gr.Textbox(label="Text Description")
image_output = gr.Image(type="pil")

iface = gr.Interface(
    fn=highlight_image,
    inputs=[image_input, text_input],
    outputs=image_output,
    title="Image Text Highlight",
    description="Upload an image and provide a text description to highlight the relevant part of the image.",
)

# Start serving the app
iface.launch()