sfmig commited on
Commit
f9e4a95
1 Parent(s): c9a31c4

changed to DERT segmentation model

Browse files
Files changed (2) hide show
  1. .gitignore +2 -0
  2. app.py +213 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ scrap*
2
+ .DS_Store
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Using as reference:
3
+ - https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
4
+ - https://huggingface.co/spaces/chansung/segformer-tf-transformers/blob/main/app.py
5
+ - https://huggingface.co/facebook/detr-resnet-50-panoptic
6
+ """
7
+
8
+ from transformers import DetrFeatureExtractor, DetrForSegmentation
9
+ from PIL import Image
10
+ import gradio as gr
11
+ import numpy as np
12
+
13
+ # Returns a list with a color per ADE class (150 classes)
14
+ # from https://huggingface.co/spaces/chansung/segformer-tf-transformers/blob/main/app.py
15
+ def ade_palette():
16
+ """ADE20K palette that maps each class to RGB values."""
17
+ return [
18
+ [120, 120, 120],
19
+ [180, 120, 120],
20
+ [6, 230, 230],
21
+ [80, 50, 50],
22
+ [4, 200, 3],
23
+ [120, 120, 80],
24
+ [140, 140, 140],
25
+ [204, 5, 255],
26
+ [230, 230, 230],
27
+ [4, 250, 7],
28
+ [224, 5, 255],
29
+ [235, 255, 7],
30
+ [150, 5, 61],
31
+ [120, 120, 70],
32
+ [8, 255, 51],
33
+ [255, 6, 82],
34
+ [143, 255, 140],
35
+ [204, 255, 4],
36
+ [255, 51, 7],
37
+ [204, 70, 3],
38
+ [0, 102, 200],
39
+ [61, 230, 250],
40
+ [255, 6, 51],
41
+ [11, 102, 255],
42
+ [255, 7, 71],
43
+ [255, 9, 224],
44
+ [9, 7, 230],
45
+ [220, 220, 220],
46
+ [255, 9, 92],
47
+ [112, 9, 255],
48
+ [8, 255, 214],
49
+ [7, 255, 224],
50
+ [255, 184, 6],
51
+ [10, 255, 71],
52
+ [255, 41, 10],
53
+ [7, 255, 255],
54
+ [224, 255, 8],
55
+ [102, 8, 255],
56
+ [255, 61, 6],
57
+ [255, 194, 7],
58
+ [255, 122, 8],
59
+ [0, 255, 20],
60
+ [255, 8, 41],
61
+ [255, 5, 153],
62
+ [6, 51, 255],
63
+ [235, 12, 255],
64
+ [160, 150, 20],
65
+ [0, 163, 255],
66
+ [140, 140, 140],
67
+ [250, 10, 15],
68
+ [20, 255, 0],
69
+ [31, 255, 0],
70
+ [255, 31, 0],
71
+ [255, 224, 0],
72
+ [153, 255, 0],
73
+ [0, 0, 255],
74
+ [255, 71, 0],
75
+ [0, 235, 255],
76
+ [0, 173, 255],
77
+ [31, 0, 255],
78
+ [11, 200, 200],
79
+ [255, 82, 0],
80
+ [0, 255, 245],
81
+ [0, 61, 255],
82
+ [0, 255, 112],
83
+ [0, 255, 133],
84
+ [255, 0, 0],
85
+ [255, 163, 0],
86
+ [255, 102, 0],
87
+ [194, 255, 0],
88
+ [0, 143, 255],
89
+ [51, 255, 0],
90
+ [0, 82, 255],
91
+ [0, 255, 41],
92
+ [0, 255, 173],
93
+ [10, 0, 255],
94
+ [173, 255, 0],
95
+ [0, 255, 153],
96
+ [255, 92, 0],
97
+ [255, 0, 255],
98
+ [255, 0, 245],
99
+ [255, 0, 102],
100
+ [255, 173, 0],
101
+ [255, 0, 20],
102
+ [255, 184, 184],
103
+ [0, 31, 255],
104
+ [0, 255, 61],
105
+ [0, 71, 255],
106
+ [255, 0, 204],
107
+ [0, 255, 194],
108
+ [0, 255, 82],
109
+ [0, 10, 255],
110
+ [0, 112, 255],
111
+ [51, 0, 255],
112
+ [0, 194, 255],
113
+ [0, 122, 255],
114
+ [0, 255, 163],
115
+ [255, 153, 0],
116
+ [0, 255, 10],
117
+ [255, 112, 0],
118
+ [143, 255, 0],
119
+ [82, 0, 255],
120
+ [163, 255, 0],
121
+ [255, 235, 0],
122
+ [8, 184, 170],
123
+ [133, 0, 255],
124
+ [0, 255, 92],
125
+ [184, 0, 255],
126
+ [255, 0, 31],
127
+ [0, 184, 255],
128
+ [0, 214, 255],
129
+ [255, 0, 112],
130
+ [92, 255, 0],
131
+ [0, 224, 255],
132
+ [112, 224, 255],
133
+ [70, 184, 160],
134
+ [163, 0, 255],
135
+ [153, 0, 255],
136
+ [71, 255, 0],
137
+ [255, 0, 163],
138
+ [255, 204, 0],
139
+ [255, 0, 143],
140
+ [0, 255, 235],
141
+ [133, 255, 0],
142
+ [255, 0, 235],
143
+ [245, 0, 255],
144
+ [255, 0, 122],
145
+ [255, 245, 0],
146
+ [10, 190, 212],
147
+ [214, 255, 0],
148
+ [0, 204, 255],
149
+ [20, 0, 255],
150
+ [255, 255, 0],
151
+ [0, 153, 255],
152
+ [0, 41, 255],
153
+ [0, 255, 204],
154
+ [41, 0, 255],
155
+ [41, 255, 0],
156
+ [173, 0, 255],
157
+ [0, 245, 255],
158
+ [71, 0, 255],
159
+ [122, 0, 255],
160
+ [0, 255, 184],
161
+ [0, 92, 255],
162
+ [184, 255, 0],
163
+ [0, 133, 255],
164
+ [255, 214, 0],
165
+ [25, 194, 194],
166
+ [102, 255, 0],
167
+ [92, 0, 255],
168
+ ]
169
+
170
+ feature_extractor = DetrFeatureExtractor.from_pretrained('facebook/detr-resnet-50-panoptic')
171
+ model = DetrForSegmentation.from_pretrained('facebook/detr-resnet-50-panoptic')
172
+
173
+ # gradio components
174
+ input = gr.inputs.Image()
175
+ output = gr.outputs.Image()
176
+
177
+ def predict_animal_mask(image):
178
+ inputs = feature_extractor(images=image, return_tensors="pt") #pt=Pytorch, tf=TensorFlow
179
+ outputs = model(**inputs)
180
+ logits = outputs.logits
181
+ bboxes = outputs.pred_boxes
182
+ masks = outputs.pred_masks
183
+
184
+ # postprocess the image
185
+ label_per_pixel = torch.argmax(masks.squeeze(),dim=0).detach().numpy()
186
+ color_mask = np.zeros(image.size+(3,))
187
+ for lbl, color in enumerate(ade_palette()):
188
+ color_mask[label_per_pixel==lbl,:] = color
189
+
190
+ # Show image + mask
191
+ pred_img = np.array(image.convert('RGB'))*0.5 + color_mask*0.5
192
+ pred_img = pred_img.astype(np.uint8)
193
+
194
+
195
+ ####################################################
196
+ # Create user interface and launch
197
+ gr.Interface(predict_animal_mask,
198
+ inputs = input,
199
+ outputs = output,
200
+ title = 'Animals segmentation in images',
201
+ description = "An animal segmentation image webapp using DETR (End-to-End Object Detection) model with ResNet-50 backbone").launch()
202
+
203
+
204
+ ####################################
205
+ # url = "http://images.cocodataset.org/val2017/000000039769.jpg"
206
+ # image = Image.open(requests.get(url, stream=True).raw)
207
+
208
+ # inputs = feature_extractor(images=image, return_tensors="pt")
209
+ # outputs = model(**inputs)
210
+ # logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)
211
+
212
+
213
+