IbrahimHasani commited on
Commit
59ec2c1
1 Parent(s): f64af32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -3,20 +3,24 @@ import torch
3
  import numpy as np
4
  from transformers import OwlViTProcessor, OwlViTForObjectDetection
5
  from torchvision import transforms
6
- from PIL import Image
7
  import cv2
8
  import torch.nn.functional as F
9
  import tempfile
10
  import os
11
  from SuperGluePretrainedNetwork.models.matching import Matching
12
  from SuperGluePretrainedNetwork.models.utils import read_image
 
 
 
 
 
13
 
14
  # Load models
15
  mixin = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
16
  processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
17
  model = mixin.to(device)
18
 
19
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  matching = Matching({
21
  'superpoint': {'nms_radius': 4, 'keypoint_threshold': 0.005, 'max_keypoints': 1024},
22
  'superglue': {'weights': 'outdoor', 'sinkhorn_iterations': 20, 'match_threshold': 0.2}
 
3
  import numpy as np
4
  from transformers import OwlViTProcessor, OwlViTForObjectDetection
5
  from torchvision import transforms
6
+ from PIL import Image, ImageDraw
7
  import cv2
8
  import torch.nn.functional as F
9
  import tempfile
10
  import os
11
  from SuperGluePretrainedNetwork.models.matching import Matching
12
  from SuperGluePretrainedNetwork.models.utils import read_image
13
+ import matplotlib.pyplot as plt
14
+ import matplotlib.cm as cm
15
+
16
+ # Set device
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
19
  # Load models
20
  mixin = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
21
  processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
22
  model = mixin.to(device)
23
 
 
24
  matching = Matching({
25
  'superpoint': {'nms_radius': 4, 'keypoint_threshold': 0.005, 'max_keypoints': 1024},
26
  'superglue': {'weights': 'outdoor', 'sinkhorn_iterations': 20, 'match_threshold': 0.2}