import os
import pickle
from io import BytesIO
from random import random, choice, shuffle

import cv2
import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.io import read_video
from torch.utils.data import Dataset
from PIL import Image
from PIL import ImageFile
from scipy.ndimage.filters import gaussian_filter


# Three-element lists keyed by "imagenet" and "clip"; presumably the
# per-channel (RGB) normalisation mean / std for those models -- confirm
# against the transform that consumes them.
MEAN = {
    "imagenet": [0.485, 0.456, 0.406],
    "clip": [0.48145466, 0.4578275, 0.40821073],
}

STD = {
    "imagenet": [0.229, 0.224, 0.225],
    "clip": [0.26862954, 0.26130258, 0.27577711],
}


def recursively_read(rootdir, must_contain, exts=("mp4", "avi")):
    """Collect the paths of files under *rootdir* with an allowed extension.

    Args:
        rootdir: Directory that is walked recursively with ``os.walk``.
        must_contain: Substring the joined path must contain. The empty
            string matches every path.
        exts: Allowed extensions, without the leading dot. They are compared
            case-sensitively against the text after the last dot of the
            file name.

    Returns:
        A list of matching paths (directory joined with file name), in the
        order ``os.walk`` yields them.
    """
    out = []
    for dirpath, _dirnames, filenames in os.walk(rootdir):
        for filename in filenames:
            # Look at the text after the LAST dot: names with several dots
            # ("clip.v2.mp4") are classified by their real extension, and
            # names without any dot are skipped instead of raising
            # IndexError.
            _stem, dot, ext = filename.rpartition(".")
            if not dot or ext not in exts:
                continue
            full_path = os.path.join(dirpath, filename)
            if must_contain in full_path:
                out.append(full_path)
    return out


def get_list(path, must_contain=''):
    """Return every video path under *path* whose full path contains *must_contain*.

    Thin wrapper around ``recursively_read`` using its default extensions.
    """
    return recursively_read(path, must_contain)


def uniform_capture_video_frames(path, num_frames=16):
    """Sample up to *num_frames* evenly spaced frames from a video with OpenCV.

    Frames are read sequentially and every ``total_frames // num_frames``-th
    frame is kept until *num_frames* frames have been collected or the
    stream ends.

    Args:
        path: Path of the video file handed to ``cv2.VideoCapture``.
        num_frames: Maximum number of frames to return.

    Returns:
        A list of RGB ``PIL.Image`` objects. It holds fewer than *num_frames*
        items when the stream ends early, and is empty when the video cannot
        be read or *num_frames* is not positive.
    """
    if num_frames <= 0:
        return []

    capture = cv2.VideoCapture(path)
    try:
        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        # Clamp to 1: when the video is shorter than num_frames (or the
        # container reports no frame count) the integer division yields 0,
        # which would make the modulo below raise ZeroDivisionError.
        frame_interval = max(total_frames // num_frames, 1)

        frames = []
        frame_count = 0
        while len(frames) < num_frames:
            ret, frame = capture.read()
            if not ret:
                break
            if frame_count % frame_interval == 0:
                # Convert the OpenCV BGR image to RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # Wrap the numpy array as a PIL image.
                frames.append(Image.fromarray(frame_rgb))
            frame_count += 1
    finally:
        # Always give the capture handle back, even if decoding fails.
        capture.release()

    return frames


class RealFakeDataset(Dataset):
    """Dataset of video clips labelled real (0) or fake (1).

    Video paths are discovered under ``opt.real_list_path`` and
    ``opt.fake_list_path``; the combined list is shuffled once at
    construction time.
    """

    def __init__(self, opt, clip_model=None, transform=None, num_frames=16):
        """Build the shuffled path list and the path -> label mapping.

        Args:
            opt: Options object; ``data_label`` (must be ``"train"`` or
                ``"val"``), ``real_list_path`` and ``fake_list_path`` are
                read from it.
            clip_model: Accepted for interface compatibility; not used here.
            transform: Optional callable applied to each frame (as a PIL
                image) in ``__getitem__``.
            num_frames: Number of leading frames kept per video.
        """
        self.opt = opt
        assert opt.data_label in ["train", "val"]
        self.data_label = opt.data_label
        self.num_frames = num_frames

        real_list = get_list(os.path.join(opt.real_list_path))
        fake_list = get_list(os.path.join(opt.fake_list_path))

        # Labels: 0 for real, 1 for fake. Fake entries are written last, so
        # a path found under both roots ends up labelled fake.
        self.labels_dict = {path: 0 for path in real_list}
        self.labels_dict.update({path: 1 for path in fake_list})

        self.total_list = real_list + fake_list
        shuffle(self.total_list)
        self.transform = transform

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.total_list)

    def __getitem__(self, idx):
        """Load the first ``num_frames`` frames of video *idx* and its label.

        Returns:
            A ``(video_frames, label)`` tuple. With a transform,
            ``video_frames`` is the stack of the transformed frames along a
            new leading dimension. Without one, it is the frame tensor laid
            out as (T, C, H, W), exactly as decoded by ``read_video``.
            ``label`` is 0 (real) or 1 (fake).

        Note:
            A video with fewer than ``num_frames`` frames yields a shorter
            clip; nothing is padded here.
        """
        img_path = self.total_list[idx]
        label = self.labels_dict[img_path]

        frames, _, _ = read_video(str(img_path), pts_unit='sec')
        frames = frames[:self.num_frames]
        frames = frames.permute(0, 3, 1, 2)  # (T,H,W,C) -> (T,C,H,W)

        if self.transform is None:
            # Previously `video_frames` was never bound on this path and the
            # return statement raised UnboundLocalError.
            return frames, label

        video_frames = torch.stack(
            [self.transform(TF.to_pil_image(frame)) for frame in frames]
        )
        return video_frames, label