IbrahimHasani committed
Commit 7a8c937
1 Parent(s): cb4cd83

Update app.py

Files changed (1)
  1. app.py +111 -531
app.py CHANGED
@@ -1,412 +1,98 @@
- import cv2
- import requests
-
- from PIL import Image
- import PIL
- from PIL import ImageDraw
-
- from matplotlib import pyplot as plt
- import matplotlib
- from matplotlib import rcParams
-
- import os
- import tempfile
- from io import BytesIO
- from pathlib import Path
- import argparse
- import random
- import numpy as np
  import torch
- import matplotlib.cm as cm
- import pandas as pd
-
-
  from transformers import OwlViTProcessor, OwlViTForObjectDetection
- from transformers.image_utils import ImageFeatureExtractionMixin
-
-
  from SuperGluePretrainedNetwork.models.matching import Matching
- from SuperGluePretrainedNetwork.models.utils import (compute_pose_error, compute_epipolar_error,
-                                                      estimate_pose,
-                                                      error_colormap, AverageTimer, pose_auc, read_image,
-                                                      rotate_intrinsics, rotate_pose_inplane,
-                                                      scale_intrinsics)
-
- torch.set_grad_enabled(False)
-


- mixin = ImageFeatureExtractionMixin()
- model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
- processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

- # Use GPU if available
- if torch.cuda.is_available():
-     device = torch.device("cuda")
- else:
-     device = torch.device("cpu")

- import requests
- from PIL import Image, ImageDraw
- from io import BytesIO
- import matplotlib.pyplot as plt
- import numpy as np
- import torch
- import cv2
- import tempfile

- def detect_and_crop2(target_image_path,
-                      query_image_path,
-                      model,
-                      processor,
-                      mixin,
-                      device,
-                      threshold=0.5,
-                      nms_threshold=0.3,
-                      visualize=True):
-
-     # Open target image
-     image = Image.open(target_image_path).convert('RGB')
-     image_size = model.config.vision_config.image_size + 5
-     image = mixin.resize(image, image_size)
-     target_sizes = torch.Tensor([image.size[::-1]])
-
-     # Open query image
-     query_image = Image.open(query_image_path).convert('RGB')
-     image_size = model.config.vision_config.image_size + 5
-     query_image = mixin.resize(query_image, image_size)
-
-     # Process input and query image
-     inputs = processor(images=image, query_images=query_image, return_tensors="pt").to(device)
-
-     # Get predictions
      with torch.no_grad():
          outputs = model.image_guided_detection(**inputs)
-
-     # Convert predictions to CPU
-     img = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
      outputs.logits = outputs.logits.cpu()
-     outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
-
-     # Post process the predictions
      results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes)
      boxes, scores = results[0]["boxes"], results[0]["scores"]

-     # If no boxes, return an empty list
-     if len(boxes) == 0 and visualize:
-         print(f"No boxes detected for image: {target_image_path}")
-         fig, ax = plt.subplots(figsize=(6, 6))
-         ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-         ax.set_title("Original Image")
-         ax.axis("off")
-         plt.show()
-         return []
-
-     # Filter boxes
-     img_with_all_boxes = img.copy()
      filtered_boxes = []
-     filtered_scores = []
-     img_width, img_height = img.shape[1], img.shape[0]
-     for box, score in zip(boxes, scores):
          x1, y1, x2, y2 = [int(i) for i in box.tolist()]
-         if x1 < 0 or y1 < 0 or x2 < 0 or y2 < 0:
-             continue
-         if (x2 - x1) / img_width >= 0.94 and (y2 - y1) / img_height >= 0.94:
-             continue
-         filtered_boxes.append([x1, y1, x2, y2])
-         filtered_scores.append(score)
-
-     # Draw boxes on original image
-     draw = ImageDraw.Draw(image)
-     for box in filtered_boxes:
-         draw.rectangle(box, outline="red",width=3)
-
-     cropped_images = []
-     for box in filtered_boxes:
-         x1, y1, x2, y2 = box
          cropped_img = img[y1:y2, x1:x2]
          if cropped_img.size != 0:
-             cropped_images.append(cropped_img)
-
-     if visualize:
-         # Visualization
-         if not filtered_boxes:
-             fig, ax = plt.subplots(figsize=(6, 6))
-             ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-             ax.set_title("Original Image")
-             ax.axis("off")
-             plt.show()
-         else:
-             fig, axs = plt.subplots(1, len(cropped_images) + 2, figsize=(15, 5))
-             axs[0].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-             axs[0].set_title("Original Image")
-             axs[0].axis("off")
-
-             for i, (box, score) in enumerate(zip(filtered_boxes, filtered_scores)):
-                 x1, y1, x2, y2 = box
-                 cropped_img = img[y1:y2, x1:x2]
-                 font = cv2.FONT_HERSHEY_SIMPLEX
-                 text = f"{score:.2f}"
-                 cv2.putText(cropped_img, text, (5, cropped_img.shape[0]-10), font, 0.5, (255,0,0), 1, cv2.LINE_AA)
-                 axs[i+2].imshow(cv2.cvtColor(cropped_img, cv2.COLOR_BGR2RGB))
-                 axs[i+2].set_title("Score: " + text)
-                 axs[i+2].axis("off")
-             plt.tight_layout()
-             plt.show()
-
-     return cropped_images, image  # return original image with boxes drawn

- def save_array_to_temp_image(arr):
-     # Convert the array to an image
-     img = Image.fromarray(arr)
-
-     # Create a temporary file for the image
-     temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir=tempfile.gettempdir())
-     temp_file_name = temp_file.name
-     temp_file.close()  # We close it because we're not writing to it directly, PIL will handle the writing
-
-     # Save the image to the temp file
-     img.save(temp_file_name)

-     return temp_file_name
-
- '''
- def process_resize(w: int, h: int, resize_dims: list) -> tuple:
-     if len(resize_dims) == 1 and resize_dims[0] > -1:
-         scale = resize_dims[0] / max(h, w)
-         w_new, h_new = int(round(w * scale)), int(round(h * scale))
-         return w_new, h_new
-     return w, h
- '''
-
- def plot_image_pair(imgs, dpi=100, size=6, pad=.5):
-     n = len(imgs)
-     assert n == 2, 'number of images must be two'
-     figsize = (size*n, size*3/4) if size is not None else None
-     _, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
-     for i in range(n):
-         ax[i].imshow(imgs[i], cmap=plt.get_cmap('gray'), vmin=0, vmax=255)
-         ax[i].get_yaxis().set_ticks([])
-         ax[i].get_xaxis().set_ticks([])
-         for spine in ax[i].spines.values():  # remove frame
-             spine.set_visible(False)
-     plt.tight_layout(pad=pad)
-
- def plot_keypoints(kpts0, kpts1, color='w', ps=2):
-     ax = plt.gcf().axes
-     ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
-     ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
-
- def plot_matches(kpts0, kpts1, color, lw=1.5, ps=4):
-     fig = plt.gcf()
-     ax = fig.axes
-     fig.canvas.draw()
-
-     transFigure = fig.transFigure.inverted()
-     fkpts0 = transFigure.transform(ax[0].transData.transform(kpts0))
-     fkpts1 = transFigure.transform(ax[1].transData.transform(kpts1))
-
-     fig.lines = [matplotlib.lines.Line2D(
-         (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), zorder=1,
-         transform=fig.transFigure, c=color[i], linewidth=lw)
-         for i in range(len(kpts0))]
-     ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
-     ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
-
- def unified_matching_plot2(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
-                            color, text, path=None, show_keypoints=False,
-                            fast_viz=False, opencv_display=False,
-                            opencv_title='matches', small_text=[]):
-
-     # Set the background color for the plot
-     plt.figure(facecolor='#eeeeee')
-     plot_image_pair([image0, image1])
-
-     # Elegant points and lines for matches
-     if show_keypoints:
-         plot_keypoints(kpts0, kpts1, color='k', ps=4)
-         plot_keypoints(kpts0, kpts1, color='w', ps=2)
-     plot_matches(mkpts0, mkpts1, color, lw=1)
-
-     fig = plt.gcf()
-
-     # Add text
-     fig.text(
-         0.01, 0.01, '\n'.join(small_text), transform=fig.axes[0].transAxes,
-         fontsize=10, va='bottom', ha='left', color='#333333', fontweight='bold',
-         bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle="round,pad=0.3"))
-
-     fig.text(
-         0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
-         fontsize=15, va='top', ha='left', color='#333333', fontweight='bold',
-         bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle="round,pad=0.3"))
-
-     # Optional: remove axis for a cleaner look
-     plt.axis('off')
-
-     # Convert the figure to an OpenCV image
-     buf = BytesIO()
-     plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
-     buf.seek(0)
-     img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
-     buf.close()
-     img = cv2.imdecode(img_arr, 1)
-     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-
-     # Close the figure to free memory
-     plt.close(fig)
-
-     return img
-
- def create_image_pyramid2(image_path, longest_side, scales=[0.25, 0.5, 1.0]):
-     original_image = cv2.imread(image_path)
-     oh, ow, _ = original_image.shape
-
-     # Determine the scaling factor based on the longest side
-     if oh > ow:
-         output_height = longest_side
-         output_width = int((ow / oh) * longest_side)
-     else:
-         output_width = longest_side
-         output_height = int((oh / ow) * longest_side)
-     output_size = (output_width, output_height)
-
-     pyramid = []
-
-     for scale in scales:
-         # Resize based on the scale factor
-         resized = cv2.resize(original_image, None, fx=scale, fy=scale)
-         rh, rw, _ = resized.shape
-
-         if scale < 1.0:  # downsampling
-             # Calculate the amount of padding required
-             dy_top = max((output_size[1] - rh) // 2, 0)
-             dy_bottom = output_size[1] - rh - dy_top
-             dx_left = max((output_size[0] - rw) // 2, 0)
-             dx_right = output_size[0] - rw - dx_left
-
-             # Create padded image
-             padded = cv2.copyMakeBorder(resized, dy_top, dy_bottom, dx_left, dx_right, cv2.BORDER_CONSTANT, value=[255, 255, 255])
-             pyramid.append(padded)
-         elif scale > 1.0:  # upsampling
-             # We need to crop the image to fit the desired output size
-             dy = (rh - output_size[1]) // 2
-             dx = (rw - output_size[0]) // 2
-             cropped = resized[dy:dy+output_size[1], dx:dx+output_size[0]]
-             pyramid.append(cropped)
-         else:  # scale == 1.0
-             pyramid.append(resized)
-
-     return pyramid
-
- # Example usage
- # pyramid = create_image_pyramid('path_to_image.jpg', 800)
- def image_matching(query_img, target_img, image_dims=[640*2], scale_factors=[0.33,0.66,1], visualize=True, k_thresh=None, m_thresh=None, write=False):

-     image1, inp1, scales1 = read_image(target_img, device, [640*2], 0, True)
-     query_pyramid = create_image_pyramid2(query_img, image_dims[0], scale_factors)
-
-     all_valid = []
-     all_inliers = []
-     all_return_imgs = []
-     max_matches_img = None
-     max_matches = -1
-
-     for idx, query_level in enumerate(query_pyramid):
-         temp_file_path = "temp_level_{}.png".format(idx)
-         cv2.imwrite(temp_file_path, query_level)
-
-         image0, inp0, scales0 = read_image(temp_file_path, device, [640*2], 0, True)
-
-         if image0 is None or image1 is None:
-             print('Problem reading image pair: {} {}'.format(query_img, target_img))
-         else:
-             # Matching
-             pred = matching({'image0': inp0, 'image1': inp1})
-             pred = {k: v[0] for k, v in pred.items()}
-             kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
-             matches, conf = pred['matches0'], pred['matching_scores0']
-
-             valid = matches > -1
-             mkpts0 = kpts0[valid]
-             mkpts1 = kpts1[matches[valid]]
-             mconf = conf[valid]
-             #color = cm.jet(mconf)[:len(mkpts0)] # Ensure consistent size
-             color = cm.jet(mconf.detach().numpy())[:len(mkpts0)]
-
-             all_valid.append(np.sum( valid.tolist() ))
-
-             # Convert torch tensors to numpy arrays.
-             mkpts0_np = mkpts0.cpu().numpy()
-             mkpts1_np = mkpts1.cpu().numpy()
-
-             try:
-                 # Use RANSAC to find the homography matrix.
-                 H, inliers = cv2.findHomography(mkpts0_np, mkpts1_np, cv2.RANSAC, 5.0)
-             except:
-                 H = 0
-                 inliers = 0
-                 print ("Not enough points for homography")
-             # Convert inliers from shape (N, 1) to shape (N,) and count them.
-             num_inliers = np.sum(inliers)
-
-             all_inliers.append(num_inliers)
-
-             # Visualization
-             text = [
-                 'Engagify Image Matching',
-                 'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)),
-                 'Scaling Factor: {}'.format( scale_factors[idx]),
-                 'Matches: {}'.format(len(mkpts0)),
-                 'Inliers: {}'.format(num_inliers),
-             ]
-
-
-             k_thresh = matching.superpoint.config['keypoint_threshold']
-             m_thresh = matching.superglue.config['match_threshold']
-
-             small_text = [
-                 'Keypoint Threshold: {:.4f}'.format(k_thresh),
-                 'Match Threshold: {:.2f}'.format(m_thresh),
-             ]
-
-             visualized_img = None  # To store the visualized image
-
-             if visualize:
-                 ret_img = unified_matching_plot2(
-                     image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, text, 'Test_Level_{}'.format(idx), True, False, True, 'Matches_Level_{}'.format(idx), small_text)
-                 all_return_imgs.append(ret_img)
-             # Storing image with most matches
-             #if len(mkpts0) > max_matches:
-             #    max_matches = len(mkpts0)
-             #    max_matches_img = 'Matches_Level_{}'.format(idx)
-
-     avg_valid = np.sum(all_valid) / len(scale_factors)
-     avg_inliers = np.sum(all_inliers) / len(scale_factors)
-
-     # Convert the image with the most matches to base64 encoded format
-     # with open(max_matches_img, "rb") as image_file:
-     #     encoded_string = base64.b64encode(image_file.read()).decode()
-
-     return {'valid':all_valid, 'inliers':all_inliers, 'visualized_image':all_return_imgs} #, encoded_string
-
- # Usage:
- #results = image_matching('Samples/Poster/poster_event_small_22.jpg', 'Samples/Images/16.jpeg', visualize=True)
- #print (results)
-
- def image_matching_no_pyramid(query_img, target_img, visualize=True, write=False):
-
      image1, inp1, scales1 = read_image(target_img, device, [640*2], 0, True)
      image0, inp0, scales0 = read_image(query_img, device, [640*2], 0, True)
-
      if image0 is None or image1 is None:
-         print('Problem reading image pair: {} {}'.format(query_img, target_img))
          return None
-
-     # Matching
      pred = matching({'image0': inp0, 'image1': inp1})
      pred = {k: v[0] for k, v in pred.items()}
      kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
@@ -416,194 +102,88 @@ def image_matching_no_pyramid(query_img, target_img, visualize=True, write=False
      mkpts0 = kpts0[valid]
      mkpts1 = kpts1[matches[valid]]
      mconf = conf[valid]
-     #color = cm.jet(mconf)[:len(mkpts0)] # Ensure consistent size
-     color = cm.jet(mconf.detach().numpy())[:len(mkpts0)]
-
      valid_count = np.sum(valid.tolist())

-     # Convert torch tensors to numpy arrays.
      mkpts0_np = mkpts0.cpu().numpy()
      mkpts1_np = mkpts1.cpu().numpy()

-     try:
-         # Use RANSAC to find the homography matrix.
          H, inliers = cv2.findHomography(mkpts0_np, mkpts1_np, cv2.RANSAC, 5.0)
      except:
-         H = 0
          inliers = 0
-         print("Not enough points for homography")
-
-     # Convert inliers from shape (N, 1) to shape (N,) and count them.
-     num_inliers = np.sum(inliers)
-
-     # Visualization
-     text = [
-         'Engagify Image Matching',
-         'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)),
-         'Matches: {}'.format(len(mkpts0)),
-         'Inliers: {}'.format(num_inliers),
-     ]

-     k_thresh = matching.superpoint.config['keypoint_threshold']
-     m_thresh = matching.superglue.config['match_threshold']
-
-     small_text = [
-         'Keypoint Threshold: {:.4f}'.format(k_thresh),
-         'Match Threshold: {:.2f}'.format(m_thresh),
-     ]

-     visualized_img = None  # To store the visualized image
-
      if visualize:
          visualized_img = unified_matching_plot2(
-             image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, text, 'Test_Match', True, False, True, 'Matches', small_text)
-
      return {
-         'valid': [valid_count],
          'inliers': [num_inliers],
          'visualized_image': [visualized_img]
      }

- # Usage:
- #results = image_matching_no_pyramid('Samples/Poster/poster_event_small_22.jpg', 'Samples/Images/16.jpeg', visualize=True)
-
- # Load the SuperPoint and SuperGlue models.
- device = 'cuda' if torch.cuda.is_available() and not opt.force_cpu else 'cpu'
- print('Running inference on device \"{}\"'.format(device))
- config = {
-     'superpoint': {
-         'nms_radius': 4,
-         'keypoint_threshold': 0.005,
-         'max_keypoints': 1024
-     },
-     'superglue': {
-         'weights': 'outdoor',
-         'sinkhorn_iterations': 20,
-         'match_threshold': 0.2,
-     }
- }
- matching = Matching(config).eval().to(device)
-
- from PIL import Image
-
- def stitch_images(images):
-     """Stitches a list of images vertically."""
-     if not images:
-         # Return a placeholder image if the images list is empty
-         return Image.new('RGB', (100, 100), color='gray')
-
-     max_width = max([img.width for img in images])
-     total_height = sum(img.height for img in images)
-
-     composite = Image.new('RGB', (max_width, total_height))
-
-     y_offset = 0
-     for img in images:
-         composite.paste(img, (0, y_offset))
-         y_offset += img.height
-
-     return composite
-
- def check_object_in_image3(query_image, target_image, threshold=50, scale_factor=[0.33,0.66,1]):
-     decision_on = []
-     # Convert cv2 images to PIL images and add them to a list
      images_to_return = []

-     cropped_images, bbox_image = detect_and_crop2(target_image_path=target_image,
-                                                   query_image_path=query_image,
-                                                   model=model,
-                                                   processor=processor,
-                                                   mixin=mixin,
-                                                   device=device,
-                                                   visualize=False)
-
      temp_files = [save_array_to_temp_image(i) for i in cropped_images]
      crop_results = [image_matching_no_pyramid(query_image, i, visualize=True) for i in temp_files]

      cropped_visuals = []
      cropped_inliers = []
      for result in crop_results:
-         # Add visualized images to the temporary list
          for img in result['visualized_image']:
              cropped_visuals.append(Image.fromarray(img))
          for inliers_ in result['inliers']:
              cropped_inliers.append(inliers_)
-     # Stitch the cropped visuals into one image
      images_to_return.append(stitch_images(cropped_visuals))
-
-     pyramid_results = image_matching(query_image, target_image, visualize=True, scale_factors=scale_factor)
-
-     pyramid_visuals = [Image.fromarray(img) for img in pyramid_results['visualized_image']]
-     # Stitch the pyramid visuals into one image
-     images_to_return.append(stitch_images(pyramid_visuals))
-
-     # Check inliers and determine if the object is present
-     print (cropped_inliers)
-     is_present = any(value > threshold for value in cropped_inliers)
-     if is_present == True:
-         decision_on.append('Object Detection')
-     is_present = any(value > threshold for value in pyramid_results["inliers"])
-     if is_present == True:
-         decision_on.append('Pyramid Max Point')
-     if is_present == False:
-         decision_on.append("Neither, It Failed All Tests")
-
-     # Return results as a dictionary
      return {
          'is_present': is_present,
-         'images': images_to_return,
-         'scale factors': scale_factor,
-         'object detection inliers': cropped_inliers,
-         'pyramid_inliers' : pyramid_results["inliers"],
-         'bbox_image':bbox_image,
-         'decision_on':decision_on,
-
      }

- # Example call:
- #result = check_object_in_image3('Samples/Poster/poster_event_small.jpg', 'Samples/Images/True_Image_3423234.jpeg', 50)
- # Accessing the results:
- #print(result['is_present']) # prints True/False
- #print(result['images']) # is a list of 2 stitched images.


- import gradio as gr
- import cv2
- from PIL import Image
-
- def gradio_interface(query_image_path, target_image_path, threshold):
-     result = check_object_in_image3(query_image_path, target_image_path, threshold)
-     # Depending on how many images are in the list, you can return them like this:
-     return result['bbox_image'], result['images'][0], result['object detection inliers'], result['scale factors'], result['pyramid_inliers'], result['images'][1], str(result['is_present']), result['decision_on']
-
-
- # Define the Gradio interface
- interface = gr.Interface(
-     fn=gradio_interface, # function to be called on button press
      inputs=[
-         gr.components.Image(label="Query Image (Drop the Image you want to detect here)", type="filepath"),
-         gr.components.Image(label="Target Image (Drop the Image youd like to search here)", type="filepath"),
-         gr.components.Slider(minimum=0, maximum=200, value=50, step=5, label="Enter the Inlier Threshold"),
-     ],
      outputs=[
-         gr.components.Image(label='Filtered Regions of Interest (Candidates)'),
-         gr.components.Image(label="Cropped Visuals from Image Guided Object Detection "),
-         gr.components.Text(label='Inliers detected for Image Guided Object Detection '),
-         gr.components.Text(label='Scale Factors Used for Pyramid (Results below, In Order)'),
-         gr.components.Text(label='Inliers detected for Pyramid Search (In Order)'),
-         gr.components.Image(label="Pyramid Visuals"),
-         gr.components.Textbox(label="Object Present?"),
-         gr.components.Textbox(label="Decision Taken Based on?"),
      ],
-     theme=gr.themes.Monochrome(),
-     title="'Image Specific Image Recognition + Matching Tool",
-     description="[Author: Ibrahim Hasani] \n "
-                 " This tool leverages Transformer, Deep Learning, and Traditional Computer Vision techniques to determine if a specified object "
-                 "(given by the query image) is present within a target image. \n"
-                 "1. Image-Guided Object Detection where we detect potential regions of interest. (Owl-Vit-Google). \n"
-                 "2. Pyramid Search that looks at various scales of the target image. Results provide "
-                 "visual representations of the matching process and a final verdict on the object's presence.\n"
-                 "3. SuperPoint (MagicLeap) + SuperGlue + Homography to extract inliers, which are thresholded for decision making."
  )

- interface.launch()

+ import gradio as gr
  import torch
+ import numpy as np
  from transformers import OwlViTProcessor, OwlViTForObjectDetection
+ from torchvision import transforms
+ from PIL import Image, ImageDraw  # ImageDraw is used below to draw detection boxes
+ import cv2
+ import matplotlib.cm as cm  # used to color the keypoint matches in the visualization
+ import torch.nn.functional as F
+ import tempfile
+ import os
  from SuperGluePretrainedNetwork.models.matching import Matching
+ from SuperGluePretrainedNetwork.models.utils import read_image

+ # Load models
+ # The device has to be defined before the model is moved to it.
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ mixin = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
+ processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
+ model = mixin.to(device)
+
+ matching = Matching({
+     'superpoint': {'nms_radius': 4, 'keypoint_threshold': 0.005, 'max_keypoints': 1024},
+     'superglue': {'weights': 'outdoor', 'sinkhorn_iterations': 20, 'match_threshold': 0.2}
+ }).eval().to(device)
+
+ # Utility functions
+ def preprocess_image(image):
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+     return transform(image).unsqueeze(0)

+ def save_array_to_temp_image(arr):
+     rgb_arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
+     img = Image.fromarray(rgb_arr)
+     temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
+     temp_file_name = temp_file.name
+     temp_file.close()
+     img.save(temp_file_name)
+     return temp_file_name

+ def stitch_images(images):
+     if not images:
+         return Image.new('RGB', (100, 100), color='gray')

+     max_width = max([img.width for img in images])
+     total_height = sum(img.height for img in images)

+     composite = Image.new('RGB', (max_width, total_height))

+     y_offset = 0
+     for img in images:
+         composite.paste(img, (0, y_offset))
+         y_offset += img.height

+     return composite

+ # Main functions
+ def detect_and_crop(target_image, query_image, threshold=0.5, nms_threshold=0.3):
+     target_sizes = torch.Tensor([target_image.size[::-1]])
+     inputs = processor(images=target_image, query_images=query_image, return_tensors="pt").to(device)
      with torch.no_grad():
          outputs = model.image_guided_detection(**inputs)
+
+     img = cv2.cvtColor(np.array(target_image), cv2.COLOR_BGR2RGB)
      outputs.logits = outputs.logits.cpu()
+     outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
+
      results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes)
      boxes, scores = results[0]["boxes"], results[0]["scores"]

+     if len(boxes) == 0:
+         return [], None
+
      filtered_boxes = []
+     for box in boxes:
          x1, y1, x2, y2 = [int(i) for i in box.tolist()]
          cropped_img = img[y1:y2, x1:x2]
          if cropped_img.size != 0:
+             filtered_boxes.append(cropped_img)

+     draw = ImageDraw.Draw(target_image)
+     for box in boxes:
+         draw.rectangle(box.tolist(), outline="red", width=3)

+     return filtered_boxes, target_image

+ def image_matching_no_pyramid(query_img, target_img, visualize=True):
      image1, inp1, scales1 = read_image(target_img, device, [640*2], 0, True)
      image0, inp0, scales0 = read_image(query_img, device, [640*2], 0, True)
+
      if image0 is None or image1 is None:
          return None
+
      pred = matching({'image0': inp0, 'image1': inp1})
      pred = {k: v[0] for k, v in pred.items()}
      kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
      mkpts0 = kpts0[valid]
      mkpts1 = kpts1[matches[valid]]
      mconf = conf[valid]
+     color = cm.jet(mconf.cpu())[:len(mkpts0)]
+
      valid_count = np.sum(valid.tolist())

      mkpts0_np = mkpts0.cpu().numpy()
      mkpts1_np = mkpts1.cpu().numpy()

+     try:
          H, inliers = cv2.findHomography(mkpts0_np, mkpts1_np, cv2.RANSAC, 5.0)
      except:
          inliers = 0

+     num_inliers = np.sum(inliers)

+     # Note: unified_matching_plot2 is no longer defined or imported in this file,
+     # so visualize=True will fail unless it is re-added or imported from elsewhere.
      if visualize:
          visualized_img = unified_matching_plot2(
+             image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, ['Matches'], True, False, True, 'Matches', [])
+     else:
+         visualized_img = None
+
      return {
+         'valid': [valid_count],
          'inliers': [num_inliers],
          'visualized_image': [visualized_img]
      }

+ def check_object_in_image(query_image, target_image, threshold=50, scale_factor=[0.33, 0.66, 1]):
      images_to_return = []
+     cropped_images, bbox_image = detect_and_crop(target_image, query_image)

      temp_files = [save_array_to_temp_image(i) for i in cropped_images]
      crop_results = [image_matching_no_pyramid(query_image, i, visualize=True) for i in temp_files]

      cropped_visuals = []
      cropped_inliers = []
      for result in crop_results:
          for img in result['visualized_image']:
              cropped_visuals.append(Image.fromarray(img))
          for inliers_ in result['inliers']:
              cropped_inliers.append(inliers_)
+
      images_to_return.append(stitch_images(cropped_visuals))
+
+     is_present = any(value >= threshold for value in cropped_inliers)
+
      return {
          'is_present': is_present,
+         'images': images_to_return,
+         'object detection inliers': [int(i) for i in cropped_inliers],
+         'bbox_image': bbox_image,
      }

+ def interface(poster_source, media_source, threshold, scale_factor):
+     result1 = check_object_in_image(poster_source, media_source, threshold, scale_factor)
+     if result1['is_present']:
+         return result1

+     result2 = check_object_in_image(poster_source, media_source, threshold, scale_factor)
+     return result2 if result2['is_present'] else result1

+ iface = gr.Interface(
+     fn=interface,
      inputs=[
+         gr.Image(type="pil", label="Upload a Query Image (Poster)"),
+         gr.Image(type="pil", label="Upload a Target Image (Media)"),
+         gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Threshold"),
+         gr.CheckboxGroup(choices=[0.33, 0.66, 1.0], value=[0.33, 0.66, 1.0], label="Scale Factors")
+     ],
      outputs=[
+         gr.JSON(label="Result")
      ],
+     title="Object Detection in Image",
+     description="""
+     **Instructions:**
+
+     1. **Upload a Query Image (Poster)**: Select an image file that contains the object you want to detect.
+     2. **Upload a Target Image (Media)**: Select an image file where you want to detect the object.
+     3. **Set Threshold**: Adjust the slider to set the threshold for object detection.
+     4. **Set Scale Factors**: Select the scale factors for image pyramid.
+     5. **View Results**: The result will show whether the object is present in the image along with additional details.
+     """
  )

+ if __name__ == "__main__":
+     iface.launch()
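
For reference, the rewritten app.py exposes check_object_in_image() as the core entry point; the Gradio interface function is only a thin wrapper around it. A minimal sketch of calling it directly from Python, assuming the file is importable as a module (with SuperGluePretrainedNetwork on the path) and using hypothetical image paths:

    # Hypothetical example; "query.jpg" and "target.jpg" are placeholder paths.
    from PIL import Image
    from app import check_object_in_image  # loads the OWL-ViT and SuperGlue models at import time

    query = Image.open("query.jpg").convert("RGB")    # the poster / object to look for
    target = Image.open("target.jpg").convert("RGB")  # the scene to search in

    result = check_object_in_image(query, target, threshold=50)
    print(result["is_present"], result["object detection inliers"])

The result dictionary mirrors what the Gradio JSON output shows: a boolean verdict, the inlier counts per detected region, the stitched match visualizations, and the target image with detection boxes drawn on it.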