File size: 4,631 Bytes
c58ca4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from dataclasses import dataclass
from fractions import Fraction
from pathlib import Path
from typing import Optional

import av
import numpy as np
import torch
from av import AudioFrame


@dataclass
class VideoInfo:
    """Decoded video data bundled with its basic timing metadata.

    ``height`` and ``width`` are read from the first entry of ``all_frames``,
    so they are only usable when ``all_frames`` is a non-empty list.
    """
    duration_sec: float
    fps: Fraction
    # NOTE(review): presumably the frames resampled for the two downstream
    # consumers ("clip" and "sync") -- confirm against the code that builds this.
    clip_frames: torch.Tensor
    sync_frames: torch.Tensor
    # May be None; the height/width properties below index into it.
    all_frames: Optional[list[np.ndarray]]

    @property
    def height(self):
        """Size of the first frame in ``all_frames`` along its leading axis."""
        first_frame = self.all_frames[0]
        return first_frame.shape[0]

    @property
    def width(self):
        """Size of the first frame in ``all_frames`` along its second axis."""
        first_frame = self.all_frames[0]
        return first_frame.shape[1]


def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
                need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
    """Decode a time window of a video and resample it at several frame rates.

    Frames of the first video stream whose timestamp lies in
    ``[start_sec, end_sec]`` are converted to RGB arrays. For every rate in
    ``list_of_fps`` a sampling clock advances in steps of ``1 / fps``; whenever
    a decoded frame's timestamp reaches or passes that clock, the frame is
    emitted for that rate. A frame is emitted several times in a row when the
    clock ticks more than once before the next decoded frame.

    Args:
        video_path: Video file to open with PyAV.
        list_of_fps: Target sampling rates in frames per second.
        start_sec: Frames with an earlier timestamp are skipped. The sampling
            clocks start here (or at 0 when ``start_sec`` is negative).
        end_sec: Decoding stops at the first frame with a later timestamp.
        need_all_frames: When true, every decoded frame inside the window is
            additionally collected without resampling.

    Returns:
        ``(output_frames, all_frames, fps)`` where ``output_frames`` holds one
        array per entry of ``list_of_fps`` (sampled frames stacked along a new
        leading axis), ``all_frames`` is the list of every in-window frame
        (empty unless ``need_all_frames``), and ``fps`` is the stream's
        ``guessed_rate``.

    Raises:
        ValueError: Raised by ``np.stack`` when no frame was sampled for one of
            the requested rates (e.g. the window contains no frames).
    """
    # Start each sampling clock at the beginning of the requested window.
    # Starting at 0 for a later window would emit the first in-window frame
    # once for every tick that elapsed before start_sec.
    first_sample_time = max(float(start_sec), 0.0)
    next_frame_time_for_each_fps = [first_sample_time for _ in list_of_fps]
    time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
    output_frames = [[] for _ in list_of_fps]
    all_frames = []

    with av.open(video_path) as container:
        stream = container.streams.video[0]
        fps = stream.guessed_rate
        stream.thread_type = 'AUTO'

        reached_end = False
        for packet in container.demux(stream):
            for frame in packet.decode():
                frame_time = frame.time
                if frame_time < start_sec:
                    continue
                if frame_time > end_sec:
                    reached_end = True
                    break

                # Convert lazily: the RGB array is only needed if the frame is
                # kept for all_frames or selected by at least one clock.
                frame_np = None
                if need_all_frames:
                    frame_np = frame.to_ndarray(format='rgb24')
                    all_frames.append(frame_np)

                for i, time_delta in enumerate(time_delta_for_each_fps):
                    while frame_time >= next_frame_time_for_each_fps[i]:
                        if frame_np is None:
                            frame_np = frame.to_ndarray(format='rgb24')

                        output_frames[i].append(frame_np)
                        next_frame_time_for_each_fps[i] += time_delta

            # Leave the demux loop too; otherwise the remainder of the file
            # would still be demuxed and decoded only to be discarded.
            if reached_end:
                break

    output_frames = [np.stack(frames) for frames in output_frames]
    return output_frames, all_frames, fps


def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
                        sampling_rate: int):
    """Encode the frames of ``video_info`` as H.264 and mux them with AAC audio.

    Args:
        video_info: Source of the frames (``all_frames``), the frame rate
            (``fps``) and the output resolution (``width`` / ``height``).
            ``all_frames`` must be a non-empty list.
        output_path: Destination file, opened for writing with PyAV.
        audio: Audio samples, encoded as a single mono frame in packed float
            format. NOTE(review): presumably shaped ``(1, num_samples)`` --
            confirm against the caller.
        sampling_rate: Sample rate of ``audio`` in Hz; also used as the rate of
            the AAC stream.
    """
    # The context manager closes the container even if encoding or muxing
    # raises, so the output file handle is never leaked.
    with av.open(output_path, 'w') as container:
        output_video_stream = container.add_stream('h264', video_info.fps)
        # Integer literal: the codec bit rate is an integer property.
        output_video_stream.codec_context.bit_rate = 10_000_000  # 10 Mbps
        output_video_stream.width = video_info.width
        output_video_stream.height = video_info.height
        output_video_stream.pix_fmt = 'yuv420p'

        output_audio_stream = container.add_stream('aac', sampling_rate)

        # encode video
        for image in video_info.all_frames:
            video_frame = av.VideoFrame.from_ndarray(image)
            container.mux(output_video_stream.encode(video_frame))

        # Calling encode() without a frame flushes the encoder's buffered packets.
        for packet in output_video_stream.encode():
            container.mux(packet)

        # Convert the float tensor to a numpy array. detach()/cpu() are no-ops
        # for a plain CPU tensor but allow tensors that live on an accelerator
        # or are attached to an autograd graph.
        audio_np = audio.detach().cpu().numpy().astype(np.float32)
        audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
        audio_frame.sample_rate = sampling_rate

        for packet in output_audio_stream.encode(audio_frame):
            container.mux(packet)

        # Flush the audio encoder.
        for packet in output_audio_stream.encode():
            container.mux(packet)


def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
    """Copy the video stream of ``video_path`` unchanged and add an AAC audio track.

    NOTE: I don't think we can get the exact video duration right without re-encoding
    so we are not using this but keeping it here for reference

    Args:
        video_path: Input file whose first video stream is copied packet by
            packet (no re-encoding).
        audio: Audio samples, encoded as a single mono frame in packed float
            format. NOTE(review): presumably shaped ``(1, num_samples)`` --
            confirm against the caller.
        output_path: Destination file, opened for writing with PyAV.
        sampling_rate: Sample rate of ``audio`` in Hz; also used as the rate of
            the AAC stream.
    """
    # Both containers are closed on exit, including when muxing raises.
    with av.open(video_path) as video, av.open(output_path, 'w') as output:
        input_video_stream = video.streams.video[0]
        output_video_stream = output.add_stream(template=input_video_stream)
        output_audio_stream = output.add_stream('aac', sampling_rate)

        for packet in video.demux(input_video_stream):
            # We need to skip the "flushing" packets that `demux` generates.
            if packet.dts is None:
                continue
            # We need to assign the packet to the new stream.
            packet.stream = output_video_stream
            output.mux(packet)

        # Convert the float tensor to a numpy array. detach()/cpu() are no-ops
        # for a plain CPU tensor but allow tensors that live on an accelerator
        # or are attached to an autograd graph.
        audio_np = audio.detach().cpu().numpy().astype(np.float32)
        audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
        audio_frame.sample_rate = sampling_rate

        for packet in output_audio_stream.encode(audio_frame):
            output.mux(packet)

        # Calling encode() without a frame flushes the encoder's buffered packets.
        for packet in output_audio_stream.encode():
            output.mux(packet)