kinsung committed on
Commit e05714c · 1 Parent(s): 3483d27
Files changed (1)
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
+ import gradio as gr
+ from PIL import Image, ImageOps
+ from transformers import DetrImageProcessor, DetrForObjectDetection
+ import torch
+
+ feature_extractor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101")
+ dmodel = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101")
+
+ i1 = gr.Image(type="pil", label="Input image")
+ i2 = gr.Textbox(label="Input text")
+ i3 = gr.Number(value=0.96, label="Threshold score")
+ i4 = gr.Number(value=200, label="Custom width (optional)")
+ i5 = gr.Number(value=200, label="Custom height (optional)")
+ o1 = gr.Image(type="pil", label="Cropped part")
+ o2 = gr.Textbox(label="Detection score")
+
+ def extract_image(image, text, prob, custom_width, custom_height):
+     # Run DETR over the input image
+     inputs = feature_extractor(images=image, return_tensors="pt")
+     outputs = dmodel(**inputs)
+
+     # Convert raw outputs to boxes/labels/scores in image coordinates,
+     # keeping only detections above the user-supplied threshold
+     target_sizes = torch.tensor([image.size[::-1]])
+     results = feature_extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=prob)[0]
+
+     # Retrieve coordinates of the first detection whose label matches the input text
+     key_object_coordinates = None
+     detection_score = 0.0
+     object_to_detect = text.lower()
+     for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+         label_name = dmodel.config.id2label[label.item()].lower()
+         if object_to_detect in label_name:
+             key_object_coordinates = box.tolist()
+             detection_score = score.item()
+             break
+
+     # Fall back to the full image when nothing matched
+     cropped_image = image
+
+     if key_object_coordinates:
+         # Clamp the box to the image bounds, shifting the opposite edge
+         # so the crop keeps its size where possible
+         xmin, ymin, xmax, ymax = key_object_coordinates
+         width, height = image.size
+         if xmax > width:
+             xmin -= xmax - width
+             xmax = width
+         if ymax > height:
+             ymin -= ymax - height
+             ymax = height
+         cropped_image = image.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
+
+     # Fit the crop to the requested custom dimensions (ImageOps.fit crops and
+     # resizes to the target size; the unused ImageOps import suggests this was intended)
+     if custom_width and custom_height:
+         cropped_image = ImageOps.fit(cropped_image, (int(custom_width), int(custom_height)))
+
+     return cropped_image, round(detection_score, 4)
+
+ title = "ClipnCrop"
+ description = "<p style='color:white'>Object detection and cropping with Facebook DETR on Hugging Face Transformers. If the detection score is not high enough, consider the prediction void.</p>"
+ examples = [['ex3.jpg', 'black bag', 0.96, 200, 200], ['ex2.jpg', 'man in red dress', 0.85, 300, 300]]
+ gr.Interface(fn=extract_image, inputs=[i1, i2, i3, i4, i5], outputs=[o1, o2], title=title, description=description, examples=examples).queue().launch()
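
For a quick sanity check without launching the Gradio UI, a minimal sketch of calling extract_image directly in the same session; 'ex3.jpg' and 'black bag' are taken from the examples list above, and the 0.9 threshold is illustrative:

# Assumes ex3.jpg sits next to app.py, as in the examples list
from PIL import Image

img = Image.open("ex3.jpg").convert("RGB")
crop, score = extract_image(img, "black bag", 0.9, 200, 200)
crop.save("cropped.jpg")
print("Detection score:", score)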