import os
from random import shuffle

import cv2
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.io import read_video
from torch.utils.data import Dataset
from PIL import Image
# Channel-wise normalization statistics for ImageNet- and CLIP-pretrained backbones.
MEAN = {
    "imagenet": [0.485, 0.456, 0.406],
    "clip": [0.48145466, 0.4578275, 0.40821073],
}
STD = {
    "imagenet": [0.229, 0.224, 0.225],
    "clip": [0.26862954, 0.26130258, 0.27577711],
}
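
# Illustrative helper (not part of the original code): a minimal sketch of how
# MEAN/STD could feed a torchvision preprocessing pipeline. The 224x224 size is
# an assumption matching common CLIP/ImageNet backbones, not something this
# file specifies.
def example_build_transform(stat="clip", size=224):
    return transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN[stat], std=STD[stat]),
    ])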


def recursively_read(rootdir, must_contain, exts=["mp4", "avi"]):
    # Walk `rootdir` and collect files whose extension is in `exts` and whose
    # full path contains the substring `must_contain`.
    out = []
    for r, d, f in os.walk(rootdir):
        for file in f:
            ext = file.split('.')[-1].lower()  # last dot, so names like "a.b.mp4" are handled
            if (ext in exts) and (must_contain in os.path.join(r, file)):
                out.append(os.path.join(r, file))
    return out
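
# Usage sketch (illustrative path):
#   recursively_read("data/videos", "fake")
#   # -> every .mp4/.avi path under data/videos whose path contains "fake"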


def get_list(path, must_contain=''):
    return recursively_read(path, must_contain)


def uniform_capture_video_frames(path, num_frames=16):
    # Sample `num_frames` frames spread uniformly across the video at `path`.
    capture = cv2.VideoCapture(path)
    total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_interval = max(1, total_frames // num_frames)  # avoid modulo-by-zero on short videos
    frames = []
    frame_count = 0
    while len(frames) < num_frames:
        ret, frame = capture.read()
        if not ret:
            break
        if frame_count % frame_interval == 0:
            # Convert OpenCV's BGR frame to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Convert the numpy array to a PIL image
            pil_image = Image.fromarray(frame_rgb)
            frames.append(pil_image)
        frame_count += 1
    capture.release()
    return frames
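
# Example call (illustrative; "clip.mp4" is a hypothetical path):
#   pil_frames = uniform_capture_video_frames("clip.mp4", num_frames=16)
#   # -> a list of up to 16 PIL.Image frames, fewer if the video is shorter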


class RealFakeDataset(Dataset):
    def __init__(self, opt, clip_model=None, transform=None, num_frames=16):
        self.opt = opt
        assert opt.data_label in ["train", "val"]
        self.data_label = opt.data_label
        self.num_frames = num_frames
        real_list = get_list(opt.real_list_path)
        fake_list = get_list(opt.fake_list_path)
        # Label real clips 0 and fake clips 1.
        self.labels_dict = {}
        for i in real_list:
            self.labels_dict[i] = 0
        for i in fake_list:
            self.labels_dict[i] = 1
        self.total_list = real_list + fake_list
        shuffle(self.total_list)
        self.transform = transform

    def __len__(self):
        return len(self.total_list)

    def __getitem__(self, idx):
        video_path = self.total_list[idx]
        label = self.labels_dict[video_path]
        # Alternative OpenCV-based frame extraction, kept for reference:
        # video_frames = uniform_capture_video_frames(video_path, num_frames=32)
        frames, _, _ = read_video(str(video_path), pts_unit='sec')  # (T, H, W, C) uint8
        frames = frames[:self.num_frames]
        frames = frames.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        if self.transform is not None:
            frames = torch.stack([self.transform(TF.to_pil_image(frame)) for frame in frames])
        return frames, label
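

# Minimal smoke-test sketch (illustrative, not part of the original code).
# `SimpleNamespace` stands in for the real option object; real_list_path,
# fake_list_path, and data_label are the fields the dataset actually reads,
# but the directory paths below are hypothetical placeholders.
if __name__ == "__main__":
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    opt = SimpleNamespace(
        real_list_path="data/real",  # hypothetical directory of real videos
        fake_list_path="data/fake",  # hypothetical directory of fake videos
        data_label="train",
    )
    dataset = RealFakeDataset(opt, transform=example_build_transform("clip"))
    # Note: clips shorter than num_frames yield fewer frames and will not
    # batch with full-length clips; batch_size=1 sidesteps that here.
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    clips, labels = next(iter(loader))
    print(clips.shape, labels)  # e.g. torch.Size([1, 16, 3, 224, 224]), tensor([0])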