captain-awesome committed on
Commit
ee5988c
1 Parent(s): 25a3c0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -30
app.py CHANGED
@@ -1,40 +1,84 @@
1
- from transformers import DetrImageProcessor, DetrForObjectDetection
2
- from transformers import BlipProcessor, BlipForConditionalGeneration
3
- import torch
4
- from PIL import Image
5
- import requests
6
  import gradio as gr
 
 
 
 
 
7
 
8
# Checkpoint used for object detection (bounding boxes).
_DETECTION_CHECKPOINT = "facebook/detr-resnet-50"
box_processor = DetrImageProcessor.from_pretrained(_DETECTION_CHECKPOINT)
box_model = DetrForObjectDetection.from_pretrained(_DETECTION_CHECKPOINT)

# Checkpoint used to generate a caption for the whole image.
_CAPTION_CHECKPOINT = "Salesforce/blip-image-captioning-large"
caption_processor = BlipProcessor.from_pretrained(_CAPTION_CHECKPOINT)
caption_model = BlipForConditionalGeneration.from_pretrained(_CAPTION_CHECKPOINT)
 
 
13
 
14
def predict_bounding_boxes(imageurl: str):
    """Download an image, then caption it and detect the objects in it.

    Args:
        imageurl: URL of the image to fetch.

    Returns:
        On success, a dict with two keys: ``"image label"`` (the generated
        caption) and ``"detections"`` (a list of dicts with ``"score"``,
        ``"label"`` and ``"box"`` for every detection scoring above 0.70).
        On any failure, a dict with a single ``"error"`` key holding the
        exception message, so the JSON output always has something to show.
    """
    try:
        # A timeout keeps a stalled server from hanging the request forever,
        # and the context manager releases the streamed connection.
        with requests.get(imageurl, stream=True, timeout=30) as response:
            response.raise_for_status()
            # The raw stream is not content-decoded by default; without this
            # a gzip/deflate-encoded body reaches PIL still compressed.
            response.raw.decode_content = True
            # convert() forces the full read while the connection is open and
            # gives both models a 3-channel image (RGBA/greyscale/palette
            # inputs were previously passed to the detector unconverted).
            raw_image = Image.open(response.raw).convert("RGB")

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            inputs = box_processor(images=raw_image, return_tensors="pt")
            outputs = box_model(**inputs)

            # PIL reports (width, height); post-processing wants (height, width).
            target_sizes = torch.tensor([raw_image.size[::-1]])
            results = box_processor.post_process_object_detection(
                outputs, target_sizes=target_sizes, threshold=0.70
            )[0]

            caption_inputs = caption_processor(raw_image, return_tensors="pt")
            out = caption_model.generate(**caption_inputs)

        detections = [
            {
                "score": score.item(),
                "label": box_model.config.id2label[label.item()],
                "box": box.tolist(),
            }
            for score, label, box in zip(
                results["scores"], results["labels"], results["boxes"]
            )
        ]
        label = caption_processor.decode(out[0], skip_special_tokens=True)
        return {"image label": label, "detections": detections}

    except Exception as e:
        # Deliberate catch-all at the API boundary: report instead of crashing.
        return {"error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
# Text in (the image URL), JSON out (whatever predict_bounding_boxes returns).
app = gr.Interface(
    fn=predict_bounding_boxes,
    inputs="text",
    outputs="json",
)
# NOTE(review): presumably meant to expose the interface as an API — confirm
# this attribute has any effect in the installed Gradio version.
app.api = True
app.launch()
 
 
 
 
 
 
1
  import gradio as gr
2
+ import pandas as pd
3
+ from transformers import AutoImageProcessor, AutoModelForObjectDetection
4
+ from PIL import Image, ImageDraw
5
+ import torch
6
+ from transformers import DetrImageProcessor, DetrForObjectDetection
7
 
 
 
8
 
9
# Alternative checkpoint, currently disabled:
#   image_processor = AutoImageProcessor.from_pretrained('hustvl/yolos-small')
#   model = AutoModelForObjectDetection.from_pretrained('hustvl/yolos-small')
_CHECKPOINT = "facebook/detr-resnet-50"
_REVISION = "no_timm"
image_processor = DetrImageProcessor.from_pretrained(_CHECKPOINT, revision=_REVISION)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT, revision=_REVISION)

# Outline colours; detect() assigns them to labels by frequency rank.
colors = [
    "red", "orange", "yellow", "green", "blue",
    "indigo", "violet", "brown", "black", "slategray",
]

# Width, in pixels, that every input image is resized to before detection.
WIDTH = 900
 
28
 
29
def detect(image):
    """Detect objects in ``image`` and summarise the detections per label.

    The image is resized to ``WIDTH`` pixels wide (aspect ratio preserved)
    and run through the module-level ``model``. Detections scoring above 0.9
    are counted per label, and boxes for the most frequent labels (at most
    one label per entry in ``colors``) are drawn onto the resized copy.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when
            the form was submitted without an image.

    Returns:
        A tuple ``(annotated_image, df, count_results)``. ``df`` is a
        DataFrame with ``label`` and ``counts`` columns, and
        ``count_results`` maps each kept label to its number of detections,
        most frequent first.

    Raises:
        gr.Error: If no image was provided.
    """
    from collections import Counter

    if image is None:
        # Surface a readable message instead of an AttributeError traceback.
        raise gr.Error("Please provide an image.")

    width, height = image.size
    ratio = float(WIDTH) / float(width)
    # Clamp so an extremely wide image cannot collapse to zero height.
    new_h = max(1, int(height * ratio))
    image = image.resize((int(WIDTH), new_h), Image.Resampling.LANCZOS)

    inputs = image_processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing wants (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs, threshold=0.9, target_sizes=target_sizes
    )[0]

    # Resolve every detection's label name once; reused for counting and drawing.
    label_names = [model.config.id2label[label.item()] for label in results["labels"]]

    # Keep only as many labels as there are colours, so each gets its own
    # outline colour and the palette lookup can never go out of range.
    count_results = dict(Counter(label_names).most_common(len(colors)))
    label2color = {name: colors[idx] for idx, name in enumerate(count_results)}

    draw = ImageDraw.Draw(image)
    for label_name, box in zip(label_names, results["boxes"]):
        if label_name not in label2color:
            continue
        x1, y1, x2, y2 = (round(coord, 4) for coord in box.tolist())
        draw.rectangle((x1, y1, x2, y2), outline=label2color[label_name], width=2)
        draw.text((x1, y1), label_name, fill="white")

    df = pd.DataFrame({
        'label': list(count_results),
        'counts': list(count_results.values()),
    })

    return image, df, count_results
75
+
76
# Gradio UI: one PIL image in; annotated image, a bar plot of the label
# counts, and the counts as text out (the three values detect() returns).
_input_components = [gr.Image(label="Input image", type="pil")]
_output_components = [
    gr.Image(label="Output image"),
    gr.BarPlot(
        show_label=False,
        x="label",
        y="counts",
        x_title="Labels",
        y_title="Counts",
        vertical=False,
    ),
    gr.Textbox(show_label=False),
]

demo = gr.Interface(
    fn=detect,
    inputs=_input_components,
    outputs=_output_components,
    title="FB Object Detection",
    cache_examples=False,
)

demo.launch()