from glob import glob

import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset

from opensora.utils.dataset_utils import DecordInit


class ExtractVideo2Feature(Dataset):
    """Loads whole videos (every frame) for feature extraction."""

    def __init__(self, args, transform):
        self.data_path = args.data_path
        self.transform = transform
        self.v_decoder = DecordInit()
        # `data_path` is treated as a glob pattern; sort for a deterministic order.
        self.samples = sorted(glob(self.data_path))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        video_path = self.samples[idx]
        video = self.decord_read(video_path)
        video = self.transform(video)  # (T, C, H, W) -> (T, C, H, W)
        return video, video_path

    def tv_read(self, path):
        # Alternative reader based on torchvision; returns every frame as (T, C, H, W).
        vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        return vframes

    def decord_read(self, path):
        decord_vr = self.v_decoder(path)
        total_frames = len(decord_vr)
        frame_indices = list(range(total_frames))  # take every frame
        video_data = decord_vr.get_batch(frame_indices).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        return video_data



class ExtractImage2Feature(Dataset):
    """Loads single images as one-frame videos for feature extraction."""

    def __init__(self, args, transform):
        self.data_path = args.data_path
        self.transform = transform
        # `data_path` is treated as a glob pattern; sort for a deterministic order.
        self.data_all = sorted(glob(self.data_path))

    def __len__(self):
        return len(self.data_all)

    def __getitem__(self, index):
        path = self.data_all[index]
        # Force RGB so grayscale/RGBA files also yield an (H, W, 3) array.
        image = Image.open(path).convert('RGB')
        video_frame = torch.from_numpy(np.array(image, dtype=np.uint8)).unsqueeze(0)  # (1, H, W, C)
        video_frame = video_frame.permute(0, 3, 1, 2)  # (1, H, W, C) -> (1, C, H, W)
        video_frame = self.transform(video_frame)  # (T, C, H, W) with T == 1
        return video_frame, path
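

# Hedged usage sketch: 'videos/*.mp4' and 'images/*.png' are hypothetical glob
# patterns, and batch_size=1 is used for videos because raw clips differ in
# length. This only shows how the two datasets above are typically driven.
if __name__ == '__main__':
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    transform = build_example_transform()

    video_args = SimpleNamespace(data_path='videos/*.mp4')  # hypothetical pattern
    video_loader = DataLoader(ExtractVideo2Feature(video_args, transform),
                              batch_size=1, num_workers=2)
    for video, video_path in video_loader:
        print(video_path[0], tuple(video.shape))  # (1, T, C, H, W)
        break

    image_args = SimpleNamespace(data_path='images/*.png')  # hypothetical pattern
    image_loader = DataLoader(ExtractImage2Feature(image_args, transform),
                              batch_size=8, num_workers=2)
    for frames, paths in image_loader:
        print(paths[0], tuple(frames.shape))  # (8, 1, C, H, W)
        break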