# owl-tracking / app.py
# Author: merve (Hugging Face staff)
# Commit c7b56d5 "Update app.py" (verified), 3.83 kB
from transformers import Owlv2Processor, Owlv2ForObjectDetection
from typing import List
import os
import numpy as np
import supervision as sv
import uuid
import torch
from tqdm import tqdm
import gradio as gr
import torch
import numpy as np
from PIL import Image
import spaces
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single checkpoint id shared by the processor and the model so they cannot drift apart.
MODEL_ID = "google/owlv2-base-patch16-ensemble"
processor = Owlv2Processor.from_pretrained(MODEL_ID)
model = Owlv2ForObjectDetection.from_pretrained(MODEL_ID).to(device)

# Newer supervision releases renamed BoundingBoxAnnotator to BoxAnnotator and
# eventually dropped the old name. Prefer the old name while it exists (older
# releases ship a legacy BoxAnnotator with different defaults) and fall back to
# BoxAnnotator otherwise, so the app keeps importing across versions.
BOUNDING_BOX_ANNOTATOR = getattr(sv, "BoundingBoxAnnotator", sv.BoxAnnotator)()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
def calculate_end_frame_index(source_video_path, max_seconds=2):
    """Return how many leading frames of a video should be processed.

    Processing is capped at ``max_seconds`` of footage (2 seconds by default)
    and never exceeds the frame count reported for the video.

    Args:
        source_video_path: Path to the video file.
        max_seconds: Length, in seconds, of the prefix to process.

    Returns:
        The number of frames to process, always an ``int`` so callers can use
        it directly as a frame count.
    """
    video_info = sv.VideoInfo.from_video_path(source_video_path)
    # int() keeps the result usable as a count even if fps or max_seconds is
    # not an integer.
    return min(video_info.total_frames, int(video_info.fps * max_seconds))
def annotate_image(
    input_image,
    detections,
    labels
) -> np.ndarray:
    """Draw masks, bounding boxes and text captions for *detections*.

    Args:
        input_image: Frame to draw on.
        detections: ``sv.Detections`` for this frame.
        labels: One caption per detection, in the same order as *detections*.

    Returns:
        The annotated frame.
    """
    canvas = input_image
    # Layering order: masks first, then boxes on top of them.
    for annotator in (MASK_ANNOTATOR, BOUNDING_BOX_ANNOTATOR):
        canvas = annotator.annotate(canvas, detections)
    # Captions are drawn last, above everything else.
    return LABEL_ANNOTATOR.annotate(canvas, detections, labels=labels)
@spaces.GPU
def process_video(
    input_video,
    labels,
    progress=gr.Progress(track_tqdm=True)
):
    """Detect the given labels in the first seconds of a video and save an annotated copy.

    Args:
        input_video: Path to the uploaded video file (value of a ``gr.Video``).
        labels: Comma-separated candidate labels, e.g. ``"dog,cat"``.
            Surrounding whitespace is stripped and empty entries are ignored.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            tqdm bar below in the UI.

    Returns:
        Path of the annotated ``.mp4`` written under ``./outputs``.

    Raises:
        gr.Error: If no video was supplied or no non-empty label was given.
    """
    if not input_video:
        raise gr.Error("Please upload a video.")
    # "dog, cat" must query "cat", not " cat"; drop empty entries like "dog,".
    label_list = [
        label.strip() for label in (labels or "").split(",") if label.strip()
    ]
    if not label_list:
        raise gr.Error("Please enter at least one label.")

    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(
        source_path=input_video,
        end=total
    )

    output_dir = "./outputs"
    # The sink cannot write into a directory that does not exist yet.
    os.makedirs(output_dir, exist_ok=True)
    result_file_path = os.path.join(output_dir, f"{uuid.uuid4()}.mp4")

    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        # Iterate the generator itself: the reported frame count is only an
        # estimate, and pulling exactly `total` frames with next() raises
        # StopIteration when the file holds fewer.
        for frame in tqdm(frame_generator, total=total, desc="Processing video.."):
            # The frame generator decodes with OpenCV (BGR channel order);
            # OWLv2 expects RGB input, so flip the channels for the model only.
            rgb_frame = np.ascontiguousarray(frame[:, :, ::-1])
            results = query(rgb_frame, label_list)
            detections = sv.Detections.from_transformers(results[0])
            # Each predicted label id indexes into the text queries we passed.
            final_labels = [
                label_list[label_id] for label_id in results[0]["labels"].tolist()
            ]
            annotated = annotate_image(
                input_image=frame,
                detections=detections,
                labels=final_labels,
            )
            sink.write_frame(annotated)
    return result_file_path
def query(image, texts, threshold=0.3):
    """Run OWLv2 on a single image and return the post-processed detections.

    Args:
        image: Image array indexed as ``image.shape[:2] == (height, width)``.
        texts: List of text queries; predicted label ids index into this list.
        threshold: Minimum score a detection must reach to be kept
            (0.3 by default).

    Returns:
        The list produced by ``processor.post_process_object_detection``, one
        entry per input image, with box coordinates in ``image``'s pixel space.
    """
    inputs = processor(text=texts, images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # OWLv2's image processor pads every image to a square (on the bottom and
    # right) before resizing, so predicted boxes are normalised to that padded
    # square. Scaling them by the original (height, width) stretches them on
    # non-square frames; scaling by (side, side) maps them back onto the
    # original, unpadded pixels.
    height, width = image.shape[:2]
    side = max(height, width)
    target_sizes = torch.Tensor([[side, side]])
    return processor.post_process_object_detection(
        outputs=outputs, threshold=threshold, target_sizes=target_sizes
    )
# Intro text rendered at the top of the page, one Markdown component per entry.
_INTRO_MARKDOWN = (
    "## Zero-shot Object Tracking with OWLv2 🦉",
    "This is a demo for zero-shot object tracking using [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) model by Google.",
    "Simply upload a video and enter the candidate labels, or try the example below. 👇",
)

with gr.Blocks() as demo:
    for intro_text in _INTRO_MARKDOWN:
        gr.Markdown(intro_text)

    with gr.Tab(label="Video"):
        # Top row: the uploaded clip next to the annotated result.
        with gr.Row():
            input_video = gr.Video(label="Input Video")
            output_video = gr.Video(label="Output Video")

        # Second row: free-text labels and the run button.
        with gr.Row():
            candidate_labels = gr.Textbox(
                label="Labels",
                placeholder="Labels separated by a comma",
            )
            submit = gr.Button()

        # Example rows and the button both drive process_video with the same
        # (video, labels) inputs and render into the same output player.
        gr.Examples(
            fn=process_video,
            examples=[["./cats.mp4", "dog,cat"]],
            inputs=[input_video, candidate_labels],
            outputs=output_video,
        )
        submit.click(
            fn=process_video,
            inputs=[input_video, candidate_labels],
            outputs=output_video,
        )

demo.launch(debug=False, show_error=True)