orhir commited on
Commit
6ca8d3a
·
1 Parent(s): 7aa7d05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -94
app.py CHANGED
@@ -162,6 +162,122 @@ def process(query_img, state,
162
  return out, state
163
 
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  with gr.Blocks() as demo:
166
  state = gr.State({
167
  'kp_src': [],
@@ -182,7 +298,8 @@ with gr.Blocks() as demo:
182
  this crucial structural information, our method enhances the accuracy of
183
  keypoint localization, marking a significant departure from conventional
184
  CAPE techniques that treat keypoints as isolated entities.
185
- ### [Paper](https://arxiv.org/abs/2311.17891) | [Official Repo](https://github.com/orhir/PoseAnything)
 
186
  ## Instructions
187
  1. Upload an image of the object you want to pose on the **left** image.
188
  2. Click on the **left** image to mark keypoints.
@@ -207,103 +324,65 @@ with gr.Blocks() as demo:
207
  eval_btn = gr.Button(value="Evaluate")
208
  with gr.Row():
209
  output_img = gr.Plot(label="Output Image", height=400, width=400)
210
-
211
-
212
- def get_select_coords(kp_support,
213
- limb_support,
214
- state,
215
- evt: gr.SelectData,
216
- r=0.015):
217
- # global original_support_image
218
- # if len(kp_src) == 0:
219
- # original_support_image = np.array(kp_support)[:, :,
220
- # ::-1].copy()
221
- pixels_in_queue = set()
222
- pixels_in_queue.add((evt.index[1], evt.index[0]))
223
- while len(pixels_in_queue) > 0:
224
- pixel = pixels_in_queue.pop()
225
- if pixel[0] is not None and pixel[
226
- 1] is not None and pixel not in state['kp_src']:
227
- state['kp_src'].append(pixel)
228
- else:
229
- print("Invalid pixel")
230
- if limb_support is None:
231
- canvas_limb = kp_support
232
- else:
233
- canvas_limb = limb_support
234
- canvas_kp = kp_support
235
- w, h = canvas_kp.size
236
- draw_pose = ImageDraw.Draw(canvas_kp)
237
- draw_limb = ImageDraw.Draw(canvas_limb)
238
- r = int(r * w)
239
- leftUpPoint = (pixel[1] - r, pixel[0] - r)
240
- rightDownPoint = (pixel[1] + r, pixel[0] + r)
241
- twoPointList = [leftUpPoint, rightDownPoint]
242
- draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))
243
- draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
244
-
245
- return canvas_kp, canvas_limb, state
246
-
247
-
248
- def get_limbs(kp_support,
249
- state,
250
- evt: gr.SelectData,
251
- r=0.02, width=0.02):
252
- curr_pixel = (evt.index[1], evt.index[0])
253
- pixels_in_queue = set()
254
- pixels_in_queue.add((evt.index[1], evt.index[0]))
255
- canvas_kp = kp_support
256
- w, h = canvas_kp.size
257
- r = int(r * w)
258
- width = int(width * w)
259
- while len(pixels_in_queue) > 0 and curr_pixel != state['prev_clicked']:
260
- pixel = pixels_in_queue.pop()
261
- state['prev_clicked'] = pixel
262
- closest_point = min(state['kp_src'],
263
- key=lambda p: (p[0] - pixel[0]) ** 2 +
264
- (p[1] - pixel[1]) ** 2)
265
- closest_point_index = state['kp_src'].index(closest_point)
266
- draw_limb = ImageDraw.Draw(canvas_kp)
267
- if state['color_idx'] < len(COLORS):
268
- c = COLORS[state['color_idx']]
269
- else:
270
- c = random.choices(range(256), k=3)
271
- leftUpPoint = (closest_point[1] - r, closest_point[0] - r)
272
- rightDownPoint = (closest_point[1] + r, closest_point[0] + r)
273
- twoPointList = [leftUpPoint, rightDownPoint]
274
- draw_limb.ellipse(twoPointList, fill=tuple(c))
275
- if state['count'] == 0:
276
- state['prev_pt'] = closest_point[1], closest_point[0]
277
- state['prev_pt_idx'] = closest_point_index
278
- state['count'] = state['count'] + 1
279
- else:
280
- if state['prev_pt_idx'] != closest_point_index:
281
- # Create Line and add Limb
282
- draw_limb.line(
283
- [state['prev_pt'], (closest_point[1], closest_point[0])],
284
- fill=tuple(c),
285
- width=width)
286
- state['skeleton'].append((state['prev_pt_idx'], closest_point_index))
287
- state['color_idx'] = state['color_idx'] + 1
288
- else:
289
- draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
290
- state['count'] = 0
291
- return canvas_kp, state
292
-
293
-
294
- def set_qery(support_img, state):
295
- state['skeleton'].clear()
296
- state['kp_src'].clear()
297
- state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()
298
- width, height = support_img.size
299
- support_img = support_img.resize((width // 4, width // 4), Image.Resampling.LANCZOS)
300
- return support_img, support_img, state
301
-
302
 
303
  support_img.select(get_select_coords,
304
  [support_img, posed_support, state],
305
  [support_img, posed_support, state])
306
- support_img.upload(set_qery,
307
  inputs=[support_img, state],
308
  outputs=[support_img, posed_support, state])
309
  posed_support.select(get_limbs,
 
162
  return out, state
163
 
164
 
165
+ def update_examples(support_img, posed_support, query_img, state, r=0.015, width=0.02):
166
+ state['color_idx'] = 0
167
+ state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()
168
+ support_img, posed_support, _ = set_query(support_img, state, example=True)
169
+ w, h = support_img.size
170
+ draw_pose = ImageDraw.Draw(support_img)
171
+ draw_limb = ImageDraw.Draw(posed_support)
172
+ r = int(r * w)
173
+ width = int(width * w)
174
+ for pixel in state['kp_src']:
175
+ leftUpPoint = (pixel[1] - r, pixel[0] - r)
176
+ rightDownPoint = (pixel[1] + r, pixel[0] + r)
177
+ twoPointList = [leftUpPoint, rightDownPoint]
178
+ draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))
179
+ draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
180
+ for limb in state['skeleton']:
181
+ point_a = state['kp_src'][limb[0]][::-1]
182
+ point_b = state['kp_src'][limb[1]][::-1]
183
+ if state['color_idx'] < len(COLORS):
184
+ c = COLORS[state['color_idx']]
185
+ state['color_idx'] += 1
186
+ else:
187
+ c = random.choices(range(256), k=3)
188
+ draw_limb.line([point_a, point_b], fill=tuple(c), width=width)
189
+ return support_img, posed_support, query_img, state
190
+
191
+
192
+ def get_select_coords(kp_support,
193
+ limb_support,
194
+ state,
195
+ evt: gr.SelectData,
196
+ r=0.015):
197
+ pixels_in_queue = set()
198
+ pixels_in_queue.add((evt.index[1], evt.index[0]))
199
+ while len(pixels_in_queue) > 0:
200
+ pixel = pixels_in_queue.pop()
201
+ if pixel[0] is not None and pixel[1] is not None and pixel not in \
202
+ state['kp_src']:
203
+ state['kp_src'].append(pixel)
204
+ else:
205
+ continue
206
+ if limb_support is None:
207
+ canvas_limb = kp_support
208
+ else:
209
+ canvas_limb = limb_support
210
+ canvas_kp = kp_support
211
+ w, h = canvas_kp.size
212
+ draw_pose = ImageDraw.Draw(canvas_kp)
213
+ draw_limb = ImageDraw.Draw(canvas_limb)
214
+ r = int(r * w)
215
+ leftUpPoint = (pixel[1] - r, pixel[0] - r)
216
+ rightDownPoint = (pixel[1] + r, pixel[0] + r)
217
+ twoPointList = [leftUpPoint, rightDownPoint]
218
+ draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))
219
+ draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
220
+ return canvas_kp, canvas_limb, state
221
+
222
+
223
+ def get_limbs(kp_support,
224
+ state,
225
+ evt: gr.SelectData,
226
+ r=0.02, width=0.02):
227
+ curr_pixel = (evt.index[1], evt.index[0])
228
+ pixels_in_queue = set()
229
+ pixels_in_queue.add((evt.index[1], evt.index[0]))
230
+ canvas_kp = kp_support
231
+ w, h = canvas_kp.size
232
+ r = int(r * w)
233
+ width = int(width * w)
234
+ while len(pixels_in_queue) > 0 and curr_pixel != state['prev_clicked']:
235
+ pixel = pixels_in_queue.pop()
236
+ state['prev_clicked'] = pixel
237
+ closest_point = min(state['kp_src'],
238
+ key=lambda p: (p[0] - pixel[0]) ** 2 +
239
+ (p[1] - pixel[1]) ** 2)
240
+ closest_point_index = state['kp_src'].index(closest_point)
241
+ draw_limb = ImageDraw.Draw(canvas_kp)
242
+ if state['color_idx'] < len(COLORS):
243
+ c = COLORS[state['color_idx']]
244
+ else:
245
+ c = random.choices(range(256), k=3)
246
+ leftUpPoint = (closest_point[1] - r, closest_point[0] - r)
247
+ rightDownPoint = (closest_point[1] + r, closest_point[0] + r)
248
+ twoPointList = [leftUpPoint, rightDownPoint]
249
+ draw_limb.ellipse(twoPointList, fill=tuple(c))
250
+ if state['count'] == 0:
251
+ state['prev_pt'] = closest_point[1], closest_point[0]
252
+ state['prev_pt_idx'] = closest_point_index
253
+ state['count'] = state['count'] + 1
254
+ else:
255
+ if state['prev_pt_idx'] != closest_point_index:
256
+ # Create Line and add Limb
257
+ draw_limb.line(
258
+ [state['prev_pt'], (closest_point[1], closest_point[0])],
259
+ fill=tuple(c),
260
+ width=width)
261
+ state['skeleton'].append(
262
+ (state['prev_pt_idx'], closest_point_index))
263
+ state['color_idx'] = state['color_idx'] + 1
264
+ else:
265
+ draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
266
+ state['count'] = 0
267
+ return canvas_kp, state
268
+
269
+
270
+ def set_query(support_img, state, example=False):
271
+ if not example:
272
+ state['skeleton'].clear()
273
+ state['kp_src'].clear()
274
+ state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()
275
+ width, height = support_img.size
276
+ support_img = support_img.resize((width // 4, width // 4),
277
+ Image.Resampling.LANCZOS)
278
+ return support_img, support_img, state
279
+
280
+
281
  with gr.Blocks() as demo:
282
  state = gr.State({
283
  'kp_src': [],
 
298
  this crucial structural information, our method enhances the accuracy of
299
  keypoint localization, marking a significant departure from conventional
300
  CAPE techniques that treat keypoints as isolated entities.
301
+ ### [Paper](https://arxiv.org/abs/2311.17891) | [Official Repo](
302
+ https://github.com/orhir/PoseAnything)
303
  ## Instructions
304
  1. Upload an image of the object you want to pose on the **left** image.
305
  2. Click on the **left** image to mark keypoints.
 
324
  eval_btn = gr.Button(value="Evaluate")
325
  with gr.Row():
326
  output_img = gr.Plot(label="Output Image", height=400, width=400)
327
+ with gr.Row():
328
+ gr.Markdown("## Examples")
329
+ with gr.Row():
330
+ gr.Examples(
331
+ examples=[
332
+ ['examples/dog2.png',
333
+ 'examples/dog2.png',
334
+ 'examples/dog1.png',
335
+ {'kp_src': [(50, 58), (51, 78), (66, 57), (118, 79),
336
+ (154, 79), (217, 74), (218, 103), (156, 104),
337
+ (152, 151), (215, 162), (213, 191),
338
+ (152, 174), (108, 171)],
339
+ 'skeleton': [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5),
340
+ (3, 7), (7, 6), (3, 12), (12, 8), (8, 9),
341
+ (12, 11), (11, 10)], 'count': 0,
342
+ 'color_idx': 0, 'prev_pt': (174, 152),
343
+ 'prev_pt_idx': 11, 'prev_clicked': (207, 186),
344
+ 'original_support_image': None,
345
+ }
346
+ ],
347
+ ['examples/sofa1.jpg',
348
+ 'examples/sofa1.jpg',
349
+ 'examples/sofa2.jpg',
350
+ {
351
+ 'kp_src': [(82, 28), (65, 30), (52, 26), (65, 50),
352
+ (84, 52), (53, 54), (43, 52), (45, 71),
353
+ (81, 69), (77, 39), (57, 43), (58, 64),
354
+ (46, 42), (49, 65)],
355
+ 'skeleton': [(0, 1), (3, 1), (3, 4), (10, 9), (11, 8),
356
+ (1, 10), (10, 11), (11, 3), (1, 2), (7, 6),
357
+ (5, 13), (5, 3), (13, 11), (12, 10), (12, 2),
358
+ (6, 10), (7, 11)], 'count': 0,
359
+ 'color_idx': 23, 'prev_pt': (71, 45), 'prev_pt_idx': 7,
360
+ 'prev_clicked': (56, 63),
361
+ 'original_support_image': None,
362
+ }],
363
+ ['examples/person1.jpeg',
364
+ 'examples/person1.jpeg',
365
+ 'examples/person2.jpeg',
366
+ {
367
+ 'kp_src': [(121, 95), (122, 160), (154, 130), (184, 106),
368
+ (181, 153)],
369
+ 'skeleton': [(0, 1), (1, 2), (0, 2), (2, 3), (2, 4),
370
+ (4, 3)], 'count': 0, 'color_idx': 6,
371
+ 'prev_pt': (153, 181), 'prev_pt_idx': 4,
372
+ 'prev_clicked': (181, 108),
373
+ 'original_support_image': None,
374
+ }]
375
+ ],
376
+ inputs=[support_img, posed_support, query_img, state],
377
+ outputs=[support_img, posed_support, query_img, state],
378
+ fn=update_examples,
379
+ run_on_click=True,
380
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
  support_img.select(get_select_coords,
383
  [support_img, posed_support, state],
384
  [support_img, posed_support, state])
385
+ support_img.upload(set_query,
386
  inputs=[support_img, state],
387
  outputs=[support_img, posed_support, state])
388
  posed_support.select(get_limbs,