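# Real-time video object detection with RT-DETR, streamed to the browser over WebRTC.
# Runs on Hugging Face Spaces (ZeroGPU hardware) via the `spaces` package.
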
import spaces
import gradio as gr
import cv2
from PIL import Image
import torch
import time
import numpy as np
from gradio_webrtc import WebRTC
import os
from twilio.rest import Client
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
from draw_boxes import draw_bounding_boxes
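
# Load the RT-DETR detector and its image processor; the model runs on the GPU.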
image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd").to("cuda")
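
# Optional Twilio credentials: when present, fetch TURN/STUN servers so the WebRTC
# stream can be relayed through NAT/firewalls; otherwise use the component's default.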
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")

if account_sid and auth_token:
    client = Client(account_sid, auth_token)
    token = client.tokens.create()
    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }
else:
    rtc_configuration = None

print("RTC_CONFIGURATION", rtc_configuration)
SUBSAMPLE = 2
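
# `spaces.GPU` allocates ZeroGPU hardware only while the decorated function runs.
# Note: SUBSAMPLE is currently unused; the frame-skipping logic below is commented out.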
@spaces.GPU
def stream_object_detection(video, conf_threshold):
    cap = cv2.VideoCapture(video)
    fps = int(cap.get(cv2.CAP_PROP_FPS))

    iterating = True
    # desired_fps = fps // SUBSAMPLE
    batch = []

    # Frames are downscaled by half below, so use the halved dimensions when rescaling boxes.
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) // 2
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) // 2
    # n_frames = 0

    while iterating:
        iterating, frame = cap.read()
        # Stop cleanly at the end of the video instead of passing None to cv2.resize.
        if not iterating or frame is None:
            break
        frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # if n_frames % SUBSAMPLE == 0:
        batch.append(frame)

        # Run inference once per second of video (fps frames at a time).
        if len(batch) == fps:
            inputs = image_processor(images=batch, return_tensors="pt").to("cuda")

            print(f"starting batch of size {len(batch)}")
            start = time.time()
            with torch.no_grad():
                outputs = model(**inputs)
            end = time.time()
            print("time taken for inference", end - start)

            start = time.time()
            boxes = image_processor.post_process_object_detection(
                outputs,
                target_sizes=torch.tensor([(height, width)] * len(batch)),
                threshold=conf_threshold)

            for array, box in zip(batch, boxes):
                pil_image = draw_bounding_boxes(Image.fromarray(array), box, model, conf_threshold)
                frame = np.array(pil_image)
                # Convert RGB back to BGR before yielding to the video stream.
                frame = frame[:, :, ::-1].copy()
                yield frame

            batch = []
            end = time.time()
            print("time taken for processing boxes", end - start)

    # Note: any trailing frames shorter than one full batch are not processed.
    cap.release()
with gr.Blocks() as app:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        Video Object Detection with RT-DETR (Powered by WebRTC ⚡️)
        </h1>
        """)
    gr.HTML(
        """
        <h3 style='text-align: center'>
        <a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://huggingface.co/PekingU/rtdetr_r101vd_coco_o365' target='_blank'>model</a>
        </h3>
        """)
    with gr.Row():
        with gr.Column():
            video = gr.Video(label="Video Source")
            conf_threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.30,
            )
        with gr.Column():
            output = WebRTC(
                label="WebRTC Stream",
                rtc_configuration=rtc_configuration,
                mode="receive",
                modality="video",
            )
            detect = gr.Button("Detect", variant="primary")

    # Stream processed frames to the WebRTC component when Detect is clicked.
    output.stream(
        fn=stream_object_detection,
        inputs=[video, conf_threshold],
        outputs=[output],
        trigger=detect.click,
    )
    gr.Examples(examples=["video_example.mp4"],
                inputs=[video])

if __name__ == '__main__':
    app.launch()