import functools
import os

import av
import cv2
import numpy as np
import torch
import gradio as gr
from transformers import AutoProcessor, TvpForVideoGrounding


def pyav_decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps):
    '''
    Convert the video from its original fps to the target_fps and decode the video with PyAV decoder.

    Args:
        container (container): pyav container.
        sampling_rate (int): frame sampling rate (interval between two sampled frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): index of the clip to select when the video is uniformly split
            into num_clips clips. It only scales the start offset computed below; no
            random temporal sampling is implemented here (a negative value simply
            moves the window start to or before the first frame).
        num_clips (int): overall number of clips to uniformly sample from the given video.
        target_fps (int): the input video may have different fps, convert it to the
            target video fps before frame sampling.
    Returns:
        frames (list): decoded PyAV frames inside the selected window, ordered by pts.
        fps (float): the number of frames per second of the video.
    Raises:
        ValueError: if the container has no video stream or its frame rate is unknown.
    '''
    if not container.streams.video:
        raise ValueError("No video stream found in the container.")
    video = container.streams.video[0]
    # average_rate can be missing for some containers; fall back to PyAV's guessed_rate.
    rate = video.average_rate or video.guessed_rate
    if not rate:
        raise ValueError("Could not determine the frame rate of the video stream.")
    fps = float(rate)

    # Number of source frames spanned by num_frames samples taken at target_fps.
    clip_size = sampling_rate * num_frames / target_fps * fps
    # NOTE(review): common upstream samplers use the stream's total frame count (not
    # num_frames) for delta and timebase; kept unchanged so the sampled frames match
    # the reference TVP preprocessing this demo was built on — confirm before changing.
    delta = max(num_frames - clip_size, 0)
    start_idx = delta * clip_idx / num_clips
    end_idx = start_idx + clip_size - 1

    if video.duration is None:
        # Stream duration unknown: decode from pts 0 through to the end of the stream.
        video_start_pts, video_end_pts = 0, float("inf")
    else:
        timebase = video.duration / num_frames
        video_start_pts = int(start_idx * timebase)
        video_end_pts = int(end_idx * timebase)

    # Seek (backward, keyframes only) to 1024 pts before the window start.
    seek_offset = max(video_start_pts - 1024, 0)
    container.seek(seek_offset, any_frame=False, backward=True, stream=video)

    frames_by_pts = {}
    for frame in container.decode(video=0):
        # Frames without a timestamp cannot be placed in the window; skip them.
        if frame.pts is None or frame.pts < video_start_pts:
            continue
        frames_by_pts[frame.pts] = frame
        if frame.pts > video_end_pts:
            break
    frames = [frames_by_pts[pts] for pts in sorted(frames_by_pts)]
    return frames, fps


def decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps):
    '''
    Decode the video and perform temporal sampling.

    Args:
        container (container): pyav container.
        sampling_rate (int): frame sampling rate (interval between two sampled frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): clip index forwarded to pyav_decode (see its docstring).
        num_clips (int): overall number of clips to uniformly sample from the given video.
        target_fps (int): the input video may have different fps, convert it to the
            target video fps before frame sampling.
    Returns:
        frames (np.ndarray): num_frames RGB frames stacked and transposed with
            (0, 3, 1, 2), i.e. channels-first: (num_frames, 3, height, width).
    Raises:
        ValueError: if clip_idx is below -2 or no frame could be decoded.
    '''
    # NOTE(review): the -2 lower bound is kept from the original; the documented
    # convention only mentions -1 — confirm the intended minimum.
    if clip_idx < -2:
        raise ValueError("Not a valid clip_idx {}".format(clip_idx))
    frames, fps = pyav_decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps)
    if not frames:
        raise ValueError("No frames could be decoded from the video.")
    clip_size = sampling_rate * num_frames / target_fps * fps
    # Evenly spaced indices over the clip; indices past the decoded frames are clamped,
    # so short videos repeat their last frame.
    index = np.linspace(0, clip_size - 1, num_frames)
    index = np.clip(index, 0, len(frames) - 1).astype(np.int64)
    frames = np.array([frames[idx].to_rgb().to_ndarray() for idx in index])
    return frames.transpose(0, 3, 1, 2)


def get_video_duration(filename):
    '''
    Return the video duration in seconds according to OpenCV.

    Args:
        filename (str | dict): video path, or a payload accepted by _extract_video_filepath.
    Returns:
        float: frame count divided by fps, or -1 if the file cannot be opened or
        OpenCV reports a non-positive frame rate.
    '''
    cap = cv2.VideoCapture(_extract_video_filepath(filename))
    try:
        if not cap.isOpened():
            return -1
        rate = cap.get(cv2.CAP_PROP_FPS)
        frame_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        if rate <= 0:
            return -1
        return frame_num / rate
    finally:
        # Always free the native capture handle, including on early return.
        cap.release()


def _extract_video_filepath(video_filename):
    '''
    Return a filesystem path from either a plain path or a dict payload.

    Dict payloads are read as video_filename['video']['path'] (presumably the shape
    Gradio passes for example inputs when preprocess=False — confirm).
    '''
    if isinstance(video_filename, dict):
        return video_filename['video']['path']
    return video_filename


def _container_duration_seconds(container):
    '''Return the container-level duration in seconds, or -1 if PyAV does not report one.'''
    if container.duration is None:
        return -1
    return container.duration / av.time_base


# Bounded cache: at most two (checkpoint, device) pairs stay loaded in memory.
@functools.lru_cache(maxsize=2)
def _load_model_and_processor(model_checkpoint, device):
    '''Load the TVP model (moved to `device`) and its processor once per checkpoint/device.'''
    print(f"Loading model: {model_checkpoint}")
    model = TvpForVideoGrounding.from_pretrained(model_checkpoint).to(device)
    processor = AutoProcessor.from_pretrained(model_checkpoint)
    return model, processor


def predict_durations(model_checkpoint, text, video_filename, device="cpu"):
    '''
    Predict when the event described by `text` happens in a video.

    Args:
        model_checkpoint (str): checkpoint name passed to from_pretrained.
        text (str): natural-language description of the event.
        video_filename (str | dict): video path, or a payload accepted by
            _extract_video_filepath.
        device (str | torch.device): device the model and inputs are placed on.
    Returns:
        str: "start: <start>s, end: <end>s" from the processor's grounding post-processing.
    Raises:
        gradio.Error: if the text or video is missing, or the video duration cannot
            be determined.
    '''
    if not text or not text.strip():
        raise gr.Error("Please enter a description of the event.")
    if video_filename is None:
        raise gr.Error("Please select a video.")

    model, processor = _load_model_and_processor(model_checkpoint, device)

    print(f"Loading video: {video_filename}")
    filepath = _extract_video_filepath(video_filename)
    # Context manager closes the PyAV container once frames are decoded.
    with av.open(filepath, metadata_errors="ignore") as container:
        raw_sampled_frames = decode(
            container=container,
            sampling_rate=1,
            num_frames=model.config.num_frames,
            clip_idx=0,
            num_clips=1,
            target_fps=3,
        )
        fallback_duration = _container_duration_seconds(container)

    print("Processing video and text")
    model_inputs = processor(
        text=[text], videos=list(raw_sampled_frames), return_tensors="pt", max_text_length=100
    ).to(device)

    print("Running inference")
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(**model_inputs)

    duration = get_video_duration(video_filename)
    if duration <= 0:
        # OpenCV could not report a usable duration; use PyAV's container duration.
        duration = fallback_duration
    if duration <= 0:
        raise gr.Error("Could not determine the duration of the video.")
    start, end = processor.post_process_video_grounding(output.logits, duration)
    return f"start: {start}s, end: {end}s"


# NOTE(review): HF_TOKEN is not referenced anywhere else in this file.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# NOTE(review): DEVICE is not passed to predict_durations, so the UI runs inference
# on that function's default device ("cpu").
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
MODELS = ["Intel/tvp-base", "Intel/tvp-base-ANet"]
EXAMPLES = [
    ["Intel/tvp-base", "a person is sitting on a bed.", "./examples/bed.mp4", ],
    ["Intel/tvp-base", "a person eats some food.", "./examples/food.mp4", ],
    ["Intel/tvp-base", "a person reads a book.", "./examples/book.mp4", ],
]
title = "Video Grounding with TVP"
DESCRIPTION = """# Video Grounding with TVP"""

with gr.Blocks(title=title) as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown(
        """
        Video Grounding is the task of localizing a moment in a video that best matches a natural language description.
        For example, given the video of a person sitting on a bed, the model should be able to predict the start and end time of the video that best matches the description "a person is sitting on a bed".
        Enter a description of an event in the video and select a video to see the predicted start and end time.
        """
    )
    with gr.Row():
        model_checkpoint = gr.Dropdown(MODELS, label="Model", value=MODELS[0], type="value")
    with gr.Row(equal_height=True):
        # NOTE(review): recent Gradio versions expect an integer `scale` — confirm layout.
        with gr.Column(scale=0.5):
            video_in = gr.Video(label="Video File", elem_id="video_in")
        with gr.Column():
            text_in = gr.Textbox(label="Text", placeholder="Description of event in the video", interactive=True)
            text_out = gr.Textbox(label="Prediction", placeholder="Predicted start and end time")
            time_button = gr.Button("Get start and end time")
    time_button.click(predict_durations, inputs=[model_checkpoint, text_in, video_in], outputs=[text_out])
    examples = gr.Examples(
        examples=EXAMPLES,
        fn=predict_durations,
        inputs=[model_checkpoint, text_in, video_in],
        outputs=[text_out],
        cache_examples=True,
        preprocess=False,
    )

if __name__ == "__main__":
    demo.launch(debug=True)