File size: 4,442 Bytes
b3f324b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import math

import decord
from torch.nn import functional as F
import torch


# File-name suffixes treated as images; matching is case-sensitive, so both
# the lower- and upper-case spellings are listed explicitly.
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``."""
    # str.endswith accepts a tuple of candidate suffixes and checks them all
    # in one call; the tuple is built per call so later edits to the list
    # are still honoured.
    return filename.endswith(tuple(IMG_EXTENSIONS))

class DecordInit(object):
    """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader.

    Instances are callables: ``DecordInit(num_threads)(filename)`` returns a
    ``decord.VideoReader`` opened on ``filename``.
    """

    def __init__(self, num_threads=1):
        """Store the reader settings.

        Args:
            num_threads (int): Forwarded to ``decord.VideoReader`` as its
                ``num_threads`` argument. Defaults to 1.
        """
        self.num_threads = num_threads
        # Context handed to every reader created by this instance.
        self.ctx = decord.cpu(0)

    def __call__(self, filename):
        """Open ``filename`` with Decord and return the reader.

        Args:
            filename: Video source passed straight through to
                ``decord.VideoReader``.

        Returns:
            The ``decord.VideoReader`` created for ``filename``.
        """
        reader = decord.VideoReader(filename,
                                    ctx=self.ctx,
                                    num_threads=self.num_threads)
        return reader

    def __repr__(self):
        # Only report attributes that __init__ actually sets. The previous
        # version also formatted ``self.sr``, which is never assigned, so
        # calling repr() always raised AttributeError.
        repr_str = (f'{self.__class__.__name__}('
                    f'num_threads={self.num_threads})')
        return repr_str

def pad_to_multiple(number, ds_stride):
    """Round ``number`` up to the nearest multiple of ``ds_stride``.

    Args:
        number: Value to round up.
        ds_stride: Stride whose multiple the result must be.

    Returns:
        ``number`` unchanged when it is already a multiple of ``ds_stride``,
        otherwise ``number`` plus the amount missing to reach the next one.
    """
    overshoot = number % ds_stride
    return number if overshoot == 0 else number + (ds_stride - overshoot)

class Collate:
    """Callable that pads a batch of tubes to one fixed size and builds a patch mask.

    The target size is derived from ``num_frames`` and ``max_image_size``,
    each rounded up to a multiple of the combined autoencoder/patch stride,
    so every batch comes out with the same padded extent.
    """

    def __init__(self, args):
        """Copy the size and stride settings from ``args``.

        Args:
            args: Object exposing ``max_image_size``, ``ae_stride``,
                ``ae_stride_t``, ``patch_size``, ``patch_size_t`` and
                ``num_frames`` attributes.
        """
        self.max_image_size = args.max_image_size
        self.ae_stride = args.ae_stride
        self.ae_stride_t = args.ae_stride_t
        self.patch_size = args.patch_size
        self.patch_size_t = args.patch_size_t
        self.num_frames = args.num_frames

    def __call__(self, batch):
        """Pad and stack the tubes of ``batch`` and build the matching mask.

        Args:
            batch: Sequence of samples, each either ``(tube, label)`` or
                ``(tube, input_ids, cond_mask)``. Dimensions 1, 2 and 3 of
                each tube are treated as time, height and width
                (presumably dimension 0 is channels -- TODO confirm).

        Returns:
            ``(tubes, labels, attention_mask)`` for two-field samples, or
            ``(tubes, attention_mask, input_ids, cond_mask)`` for three-field
            samples.

        Raises:
            NotImplementedError: If samples have neither two nor three fields.
        """
        fields = tuple(zip(*batch))
        n_fields = len(fields)
        if n_fields == 2:
            batch_tubes, labels = fields
            labels = torch.as_tensor(labels).to(torch.long)
        elif n_fields == 3:
            batch_tubes, input_ids, cond_mask = fields
            input_ids = torch.stack(input_ids).squeeze(1)
            cond_mask = torch.stack(cond_mask).squeeze(1)
        else:
            raise NotImplementedError

        # One patch covers this many input pixels / frames.
        stride_hw = self.ae_stride * self.patch_size
        stride_t = self.ae_stride_t * self.patch_size_t

        # Remember the un-padded shapes; the mask below is derived from them.
        tube_shapes = [tube.shape for tube in batch_tubes]

        # Fixed target extent, rounded up to a whole number of patches.
        target_t = pad_to_multiple(self.num_frames, stride_t)
        target_h = pad_to_multiple(self.max_image_size, stride_hw)
        target_w = pad_to_multiple(self.max_image_size, stride_hw)

        # F.pad takes (left, right) pairs starting from the LAST dimension,
        # hence the width, height, time ordering. Padding is appended at the
        # end of each axis only.
        # NOTE(review): a tube larger than the target would produce negative
        # pad amounts here -- confirm upstream guarantees it cannot happen.
        padded_tubes = [
            F.pad(tube,
                  (0, target_w - tube.shape[3],
                   0, target_h - tube.shape[2],
                   0, target_t - tube.shape[1]),
                  value=0)
            for tube in batch_tubes
        ]
        pad_batch_tubes = torch.stack(padded_tubes, dim=0)

        # Patch-grid size of the padded tube: first the autoencoder stride,
        # then the patch size.
        grid_t = (target_t // self.ae_stride_t) // self.patch_size_t
        grid_h = (target_h // self.ae_stride) // self.patch_size
        grid_w = (target_w // self.ae_stride) // self.patch_size

        # Per sample: ones over the patches that overlap real content
        # (rounded up), zeros over the padded remainder of the grid.
        masks = []
        for shape in tube_shapes:
            valid = [int(math.ceil(shape[1] / stride_t)),
                     int(math.ceil(shape[2] / stride_hw)),
                     int(math.ceil(shape[3] / stride_hw))]
            masks.append(F.pad(torch.ones(valid),
                               (0, grid_w - valid[2],
                                0, grid_h - valid[1],
                                0, grid_t - valid[0]),
                               value=0))
        attention_mask = torch.stack(masks)

        if n_fields == 2:
            return pad_batch_tubes, labels, attention_mask
        # Only the three-field case can reach this point.
        return pad_batch_tubes, attention_mask, input_ids, cond_mask