Spaces:
Runtime error
Runtime error
File size: 4,655 Bytes
a001281 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import os, io, csv, math, random
import numpy as np
from einops import rearrange
from decord import VideoReader
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from PIA.utils.util import zero_rank_print, detect_edges
import cv2
def _hsv_edge_channels(frame, use_edge):
    """Split one RGB frame into its H, S, V planes, plus an edge map if asked."""
    channels = list(
        cv2.split(cv2.cvtColor(frame.astype(np.float32), cv2.COLOR_RGB2HSV)))
    if use_edge:
        # The edge detector is fed the V plane (last HSV channel) cast to 8 bit.
        channels.append(detect_edges(lum=channels[-1].astype(np.uint8)))
    return channels


def get_score(video_data,
              cond_frame_idx,
              weight=(1.0, 1.0, 1.0, 1.0),
              use_edge=True):
    """Score how much every frame differs from a reference (condition) frame.

    Each frame is converted from RGB to HSV with OpenCV; when ``use_edge`` is
    true, the output of ``detect_edges`` on the V plane is appended as a
    fourth channel. For every frame, the absolute per-pixel difference to the
    reference frame is summed per channel, weighted, and normalised by the
    frame area and by the sum of the weights that were applied.

    Args:
        video_data: ``np.ndarray`` indexed as ``(frame, height, width,
            channel)``. Height and width are read from axes 1 and 2, and each
            ``video_data[i]`` is handed to ``cv2.cvtColor`` as an RGB image.
        cond_frame_idx: Index (along axis 0) of the reference frame.
        weight: Per-channel weights in the order H, S, V, edge. Only as many
            weights as there are channels are used, so the four-element
            default also works with ``use_edge=False``.
        use_edge: Whether to add the edge map as an extra channel.

    Returns:
        A list with one score per frame, in frame order. The entry for the
        reference frame itself is 0.

    Raises:
        IndexError: If ``cond_frame_idx`` is outside the frame axis.
    """
    h, w = video_data.shape[1], video_data.shape[2]
    cond_channels = _hsv_edge_channels(video_data[cond_frame_idx], use_edge)

    # Pair weights with the channels that actually exist. Previously the loop
    # ran over len(weight), which raised IndexError for use_edge=False with
    # the default four weights.
    used_weights = list(weight)[:len(cond_channels)]
    weight_total = sum(used_weights)

    score_sum = []
    for frame in video_data:
        channels = _hsv_edge_channels(frame, use_edge)
        # NOTE(review): if detect_edges returns an unsigned integer array, the
        # subtraction below wraps around instead of going negative -- confirm
        # its dtype and cast to a signed/float type if so.
        weighted_diff = [
            np.sum(np.abs(cur - ref)) * wt
            for cur, ref, wt in zip(channels, cond_channels, used_weights)
        ]
        score_sum.append(sum(weighted_diff) / (h * w) / weight_total)
    return score_sum
class WebVid10M(Dataset):
    """Video/caption dataset driven by a CSV annotation file.

    Every CSV row must provide a ``videoid`` and a ``name`` column; the video
    is loaded from ``<video_folder>/<videoid>.mp4`` with decord. Items are
    dictionaries holding the transformed frames, the caption, a randomly
    chosen condition-frame index and the per-frame scores that ``get_score``
    computes relative to that condition frame.
    """

    def __init__(
            self,
            csv_path, video_folder,
            sample_size=256, sample_stride=4, sample_n_frames=16,
            is_image=False,
    ):
        """Load the annotations and build the pixel transform pipeline.

        Args:
            csv_path: Path of the CSV annotation file (read with
                ``csv.DictReader``).
            video_folder: Directory containing the ``<videoid>.mp4`` files.
            sample_size: Output size; an int is expanded to a square
                ``(size, size)``, any other value is converted with
                ``tuple()``.
            sample_stride: Frame step between consecutive sampled frames.
            sample_n_frames: Number of frames per clip.
            is_image: If true, each item is a single random frame instead of
                a clip.
        """
        zero_rank_print(f"loading annotations from {csv_path} ...")
        # newline='' lets the csv module do its own newline handling, so
        # quoted fields that contain line breaks are parsed correctly.
        with open(csv_path, 'r', newline='') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
        else:
            sample_size = tuple(sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            # Maps values from [0, 1] to [-1, 1].
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def _sample_frame_indices(self, video_length):
        """Pick ``sample_n_frames`` evenly spaced indices from a random window.

        The window spans ``(sample_n_frames - 1) * sample_stride + 1`` frames,
        or the whole video when it is shorter than that; in the latter case
        indices may repeat.
        """
        clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
        start_idx = random.randint(0, video_length - clip_length)
        return np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)

    def get_batch(self, idx):
        """Decode the frames for annotation row ``idx``.

        Returns:
            ``(pixel_values, name, cond_frames, videoid)`` where
            ``pixel_values`` is a float tensor scaled to [0, 1] -- laid out
            ``(f, c, h, w)`` for clips and ``(c, h, w)`` when ``is_image`` is
            set -- and ``cond_frames`` is a valid index into the frame axis
            (always 0 for a single image).
        """
        video_dict = self.dataset[idx]
        videoid, name = video_dict['videoid'], video_dict['name']

        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = VideoReader(video_dir)
        video_length = len(video_reader)

        if self.is_image:
            # A single random frame is enough in image mode.
            frame_indices = [random.randint(0, video_length - 1)]
            cond_frames = 0
        else:
            # Read the sampled clip. Previously these indices were computed
            # but a single random frame was decoded instead, so cond_frames
            # could point past the end of the returned one-frame tensor.
            frame_indices = self._sample_frame_indices(video_length)
            cond_frames = random.randint(0, self.sample_n_frames - 1)

        pixel_values_np = video_reader.get_batch(frame_indices).asnumpy()
        del video_reader

        # f h w c -> f c h w
        pixel_values = torch.from_numpy(pixel_values_np).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, name, cond_frames, videoid

    def __len__(self):
        """Return the number of annotation rows."""
        return self.length

    def __getitem__(self, idx):
        """Return the sample for ``idx``, falling back to a random other row.

        Returns:
            dict with keys ``pixel_values`` (transformed frames), ``text``
            (caption), ``score`` (list returned by ``get_score``),
            ``cond_frames`` (condition-frame index) and ``vid`` (video id).
        """
        while True:
            try:
                video, name, cond_frames, videoid = self.get_batch(idx)
                break
            except Exception:
                # Deliberate best effort: a row whose video cannot be loaded
                # is silently replaced by a randomly chosen other row.
                idx = random.randint(0, self.length - 1)

        video = self.pixel_transforms(video)

        # Scoring works on (f, h, w, c) frames in the 0..255 range. A single
        # image is (c, h, w), so give it a frame axis first; "/ 2 + 0.5"
        # undoes the Normalize step above.
        frames = video if video.dim() == 4 else video.unsqueeze(0)
        video_ = frames.clone().permute(0, 2, 3, 1).numpy() / 2 + 0.5
        video_ = video_ * 255
        score = get_score(video_, cond_frame_idx=cond_frames)
        del video_

        sample = dict(pixel_values=video, text=name, score=score, cond_frames=cond_frames, vid=videoid)
        return sample
|