import cv2
import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.io import read_video
from torch.utils.data import Dataset
from random import random, choice, shuffle
from io import BytesIO
from PIL import Image
from PIL import ImageFile
from scipy.ndimage import gaussian_filter
import pickle
import os
MEAN = {
    "imagenet": [0.485, 0.456, 0.406],
    "clip": [0.48145466, 0.4578275, 0.40821073],
}
STD = {
    "imagenet": [0.229, 0.224, 0.225],
    "clip": [0.26862954, 0.26130258, 0.27577711],
}
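
# Illustrative helper (an assumption, not part of the original pipeline):
# shows how the statistics above would typically feed torchvision's
# Normalize; the 224 resize/crop size is likewise an assumption.
def build_transform(stats_key="clip", size=224):
    return transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN[stats_key], std=STD[stats_key]),
    ])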
def recursively_read(rootdir, must_contain, exts=["mp4", "avi"]):
    out = []
    for r, d, f in os.walk(rootdir):
        for file in f:
            # splitext handles filenames with multiple dots, which
            # file.split('.')[1] would misread.
            ext = os.path.splitext(file)[1].lstrip('.').lower()
            if (ext in exts) and (must_contain in os.path.join(r, file)):
                out.append(os.path.join(r, file))
    return out
def get_list(path, must_contain=''):
    image_list = recursively_read(path, must_contain)
    return image_list
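
# Example usage (illustrative path and filter, not from this repo):
#   videos = get_list("datasets/videos", must_contain="train")
#   -> every .mp4/.avi under the root whose full path contains "train"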
def uniform_capture_video_frames(path, num_frames=16):
    capture = cv2.VideoCapture(path)
    total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    # Guard against clips shorter than num_frames, which would otherwise
    # make the interval zero and crash the modulo below.
    frame_interval = max(total_frames // num_frames, 1)
    frames = []
    frame_count = 0
    while len(frames) < num_frames:
        ret, frame = capture.read()
        if not ret:
            break
        if frame_count % frame_interval == 0:
            # Convert OpenCV's BGR frame to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Convert the numpy array to a PIL image
            pil_image = Image.fromarray(frame_rgb)
            frames.append(pil_image)
        frame_count += 1
    capture.release()
    return frames
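
# Example usage (illustrative; "clip.mp4" is a placeholder path):
#   pil_frames = uniform_capture_video_frames("clip.mp4", num_frames=16)
#   -> up to 16 evenly spaced RGB PIL images sampled from the clip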
class RealFakeDataset(Dataset):
    def __init__(self, opt, clip_model=None, transform=None, num_frames=16):
        self.opt = opt
        assert opt.data_label in ["train", "val"]
        self.data_label = opt.data_label
        self.num_frames = num_frames
        real_list = get_list(opt.real_list_path)
        fake_list = get_list(opt.fake_list_path)
        # Label every path: 0 for real videos, 1 for fake ones.
        self.labels_dict = {}
        for i in real_list:
            self.labels_dict[i] = 0
        for i in fake_list:
            self.labels_dict[i] = 1
        self.total_list = real_list + fake_list
        shuffle(self.total_list)
        self.transform = transform

    def __len__(self):
        return len(self.total_list)
    def __getitem__(self, idx):
        img_path = self.total_list[idx]
        label = self.labels_dict[img_path]
        # img = Image.open(img_path).convert("RGB")
        # video_frames = uniform_capture_video_frames(img_path, num_frames=32)
        # self.clip_model.to(torch.device('cuda:{}'.format(self.opt.gpu_ids[0])) if self.opt.gpu_ids else torch.device('cpu'))
        frames, _, _ = read_video(str(img_path), pts_unit='sec')
        frames = frames[:self.num_frames]
        frames = frames.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        if self.transform is not None:
            video_frames = torch.stack([self.transform(TF.to_pil_image(frame)) for frame in frames])
        else:
            # Fall back to raw frames scaled to [0, 1] so the return value
            # is defined even without a transform.
            video_frames = frames.float() / 255.0
        return video_frames, label
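
# Minimal usage sketch (illustrative): the `opt` namespace below is an
# assumption modeled only on the attributes this class actually reads,
# and the paths are placeholders. Batching assumes every clip yields the
# full num_frames frames so the collate step can stack them.
if __name__ == "__main__":
    from argparse import Namespace
    from torch.utils.data import DataLoader

    opt = Namespace(real_list_path="data/real", fake_list_path="data/fake",
                    data_label="train")
    dataset = RealFakeDataset(opt, transform=build_transform("clip"))
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    videos, labels = next(iter(loader))  # videos: (B, T, C, H, W)
    print(videos.shape, labels)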