import os
import time

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from gradio_webrtc import WebRTC
from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
from twilio.rest import Client

from draw_boxes import draw_bounding_boxes
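
# Load the RT-DETR detector and its image processor once at startup.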
image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd").to("cuda")
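
# Twilio credentials, when set, are exchanged for short-lived ICE server
# credentials so the WebRTC stream can relay through a TURN server.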
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
    client = Client(account_sid, auth_token)
    token = client.tokens.create()
    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }
else:
    rtc_configuration = None
print("RTC_CONFIGURATION", rtc_configuration)
SUBSAMPLE = 2

@spaces.GPU
def stream_object_detection(video, conf_threshold):
    cap = cv2.VideoCapture(video)
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) // 2
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) // 2

    batch = []
    n_frames = 0
    iterating = True

    while iterating:
        iterating, frame = cap.read()
        if not iterating:
            break
        # Halve the resolution so inference can keep up with the stream.
        frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if n_frames % SUBSAMPLE == 0:
            batch.append(frame)

        if len(batch) == fps:
            inputs = image_processor(images=batch, return_tensors="pt").to("cuda")

            print(f"starting batch of size {len(batch)}")
            start = time.time()
            with torch.no_grad():
                outputs = model(**inputs)
            end = time.time()
            print("time taken for inference", end - start)

            start = time.time()
            # Map the predicted boxes back onto the (halved) frame size.
            boxes = image_processor.post_process_object_detection(
                outputs,
                target_sizes=torch.tensor([(height, width)] * len(batch)),
                threshold=conf_threshold,
            )
            for array, box in zip(batch, boxes):
                pil_image = draw_bounding_boxes(
                    Image.fromarray(array), box, model, conf_threshold
                )
                frame = np.array(pil_image)
                # Convert RGB back to BGR before yielding to the stream.
                frame = frame[:, :, ::-1].copy()
                yield frame
            batch = []
            end = time.time()
            print("time taken for processing boxes", end - start)
        n_frames += 1

    cap.release()
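
# UI: video input and confidence slider on the left, the live annotated
# WebRTC stream on the right.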
with gr.Blocks() as app:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        Video Object Detection with RT-DETR (Powered by WebRTC ⚡️)
        </h1>
        """
    )
    gr.HTML(
        """
        <h3 style='text-align: center'>
        <a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://huggingface.co/PekingU/rtdetr_r101vd_coco_o365' target='_blank'>model</a>
        </h3>
        """
    )
    with gr.Row():
        with gr.Column():
            video = gr.Video(label="Video Source")
            conf_threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.30,
            )
        with gr.Column():
            output = WebRTC(
                label="WebRTC Stream",
                rtc_configuration=rtc_configuration,
                mode="receive",
                modality="video",
            )
            detect = gr.Button("Detect", variant="primary")

    # Stream frames yielded by the generator to the WebRTC component whenever
    # the Detect button is clicked.
    output.stream(
        fn=stream_object_detection,
        inputs=[video, conf_threshold],
        outputs=[output],
        trigger=detect.click,
    )
    gr.Examples(examples=["video_example.mp4"], inputs=[video])

if __name__ == '__main__':
    app.launch()