Don't resize output image

#2
by ceyda - opened
Files changed (1)
  1. app.py +7 -4
app.py CHANGED
@@ -19,19 +19,22 @@ processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
 def query_image(img, text_queries, score_threshold):
     text_queries = text_queries
     text_queries = text_queries.split(",")
-    inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)
+
+    target_sizes = torch.Tensor([img.shape[:2]])
+    img_input = cv2.resize(img, (768, 768), interpolation = cv2.INTER_AREA)
+
+    inputs = processor(text=text_queries, images=img_input, return_tensors="pt").to(device)
 
     with torch.no_grad():
         outputs = model(**inputs)
 
 
-    target_sizes = torch.Tensor([[768, 768]])
+
     outputs.logits = outputs.logits.cpu()
     outputs.pred_boxes = outputs.pred_boxes.cpu()
     results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
     boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
 
-    img = cv2.resize(img, (768, 768), interpolation = cv2.INTER_AREA)
     font = cv2.FONT_HERSHEY_SIMPLEX
 
     for box, score, label in zip(boxes, scores, labels):
@@ -61,7 +64,7 @@ can also use the score threshold slider to set a threshold to filter out low pro
 """
 demo = gr.Interface(
     query_image,
-    inputs=[gr.Image(shape=(768, 768)), "text", gr.Slider(0, 1, value=0.1)],
+    inputs=[gr.Image(), "text", gr.Slider(0, 1, value=0.1)],
     outputs="image",
     title="Zero-Shot Object Detection with OWL-ViT",
     description=description,
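
For context, a minimal standalone sketch of the idea behind this change: only the model input is resized to 768x768, while post_process receives the original (height, width) as target_sizes, so the returned boxes are already in the original image's pixel coordinates and can be drawn on the full-resolution image. The checkpoint and the 0.1 threshold match the Space; the file name "example.jpg" and the query "a cat" are placeholder assumptions.

# Minimal sketch, not part of app.py: "example.jpg" and the "a cat" query are
# placeholder assumptions; the checkpoint and 0.1 threshold match the Space.
import cv2
import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

img = cv2.imread("example.jpg")                # original resolution, kept for the output
target_sizes = torch.Tensor([img.shape[:2]])   # (height, width) of the original image
img_input = cv2.resize(img, (768, 768), interpolation=cv2.INTER_AREA)  # model input only

inputs = processor(text=["a cat"], images=img_input, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# post_process rescales the normalized predictions to target_sizes, so these
# box coordinates are already in the original image's pixel space.
results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
boxes, scores = results[0]["boxes"], results[0]["scores"]

for box, score in zip(boxes, scores):
    if score >= 0.1:
        x1, y1, x2, y2 = [int(v) for v in box.tolist()]
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.imwrite("example_with_boxes.jpg", img)     # full-resolution output, no resizing

Note that in more recent transformers releases, post_process is deprecated in favour of post_process_object_detection, which takes the same target_sizes argument plus a score threshold.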