# Source page header (web-viewer residue): ybbwcwaps / AI Video,
# revision 3cc4a06 (raw | history | blame), 3.29 kB.
import cv2
import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.io import read_video
from torch.utils.data import Dataset
from random import random, choice, shuffle
from io import BytesIO
from PIL import Image
from PIL import ImageFile
from scipy.ndimage.filters import gaussian_filter
import pickle
import os
# Per-channel normalization statistics, keyed by preprocessing convention
# ("imagenet" or "clip"). Each entry holds three values, one per channel
# (presumably RGB order — confirm against the transforms that consume them).
MEAN = dict(
    imagenet=[0.485, 0.456, 0.406],
    clip=[0.48145466, 0.4578275, 0.40821073],
)
STD = dict(
    imagenet=[0.229, 0.224, 0.225],
    clip=[0.26862954, 0.26130258, 0.27577711],
)
def recursively_read(rootdir, must_contain, exts=("mp4", "avi")):
    """Recursively collect file paths under *rootdir* with one of *exts*.

    Args:
        rootdir: Directory tree to walk with ``os.walk``.
        must_contain: Substring the joined path must contain; ``''`` matches
            every path.
        exts: Accepted file extensions, without the leading dot. Matching is
            case-sensitive and uses the text after the LAST dot of the name.

    Returns:
        List of ``os.path.join(dirpath, filename)`` paths, in ``os.walk`` order.
    """
    out = []
    for dirpath, _dirnames, filenames in os.walk(rootdir):
        for filename in filenames:
            # Use the text after the last dot: "clip.v2.mp4" -> "mp4" and
            # "clip.mp4.part" -> "part". Names with no dot (e.g. "README")
            # are skipped; the old split('.')[1] raised IndexError on them.
            _, dot, ext = filename.rpartition(".")
            if not dot or ext not in exts:
                continue
            full_path = os.path.join(dirpath, filename)
            if must_contain in full_path:
                out.append(full_path)
    return out
def get_list(path, must_contain=''):
    """Return all video paths under *path* whose full path contains *must_contain*.

    Convenience wrapper that delegates to ``recursively_read`` with its
    default extension list.
    """
    return recursively_read(path, must_contain)
def uniform_capture_video_frames(path, num_frames=16):
    """Sample up to *num_frames* frames from a video at a fixed frame stride.

    Frames are decoded sequentially with OpenCV, and every ``stride``-th frame
    is kept, where ``stride = max(1, frame_count // num_frames)``.

    Args:
        path: Video file path passed to ``cv2.VideoCapture``.
        num_frames: Maximum number of frames to return. Non-positive values
            return an empty list.

    Returns:
        A list of at most *num_frames* ``PIL.Image.Image`` objects in RGB
        order. It may be shorter if ``capture.read()`` stops returning frames
        early.
    """
    if num_frames <= 0:
        return []
    capture = cv2.VideoCapture(path)
    try:
        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        # The reported frame count can be smaller than num_frames (short
        # clips), or 0/negative when unavailable. Clamp so the modulo below
        # never divides by zero.
        frame_interval = max(1, total_frames // num_frames)
        frames = []
        frame_index = 0
        while len(frames) < num_frames:
            ok, frame = capture.read()
            if not ok:
                break
            if frame_index % frame_interval == 0:
                # OpenCV decodes to BGR; convert to RGB, then wrap the numpy
                # array as a PIL image.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
            frame_index += 1
        return frames
    finally:
        # Release the decoder even if reading or conversion raises.
        capture.release()
class RealFakeDataset(Dataset):
    """Binary real/fake video dataset.

    Video paths are gathered with ``get_list`` from ``opt.real_list_path``
    (label 0) and ``opt.fake_list_path`` (label 1). The combined list is
    shuffled once at construction time.

    Args:
        opt: Options object. Reads ``data_label`` ("train" or "val"),
            ``real_list_path`` and ``fake_list_path``.
        clip_model: Stored as ``self.clip_model``; not used by this class.
        transform: Optional callable applied to each frame as a PIL image.
        num_frames: Maximum number of leading frames returned per video.

    Raises:
        ValueError: if ``opt.data_label`` is not "train" or "val".
    """

    def __init__(self, opt, clip_model=None, transform=None, num_frames=16):
        if opt.data_label not in ("train", "val"):
            raise ValueError(
                f"opt.data_label must be 'train' or 'val', got {opt.data_label!r}"
            )
        self.opt = opt
        self.data_label = opt.data_label
        self.num_frames = num_frames
        self.clip_model = clip_model
        self.transform = transform

        real_list = get_list(opt.real_list_path)
        fake_list = get_list(opt.fake_list_path)

        # Label 0 = real, 1 = fake. The fake labels are written second, so a
        # path found under both roots ends up labelled 1.
        self.labels_dict = {path: 0 for path in real_list}
        self.labels_dict.update({path: 1 for path in fake_list})

        self.total_list = real_list + fake_list
        shuffle(self.total_list)

    def __len__(self):
        """Return the number of videos (real + fake)."""
        return len(self.total_list)

    @staticmethod
    def _frames_to_clip(frames, num_frames, transform=None):
        """Keep the first *num_frames* frames and convert them to channel-first.

        Args:
            frames: Decoded video tensor laid out as (T, H, W, C).
            num_frames: Maximum number of leading frames to keep.
            transform: Optional callable applied to each frame as a PIL image.
                It must return tensors of identical shape so they can be
                stacked.

        Returns:
            The (T, C, H, W) frames when *transform* is None; otherwise
            ``torch.stack`` of the per-frame transform outputs.
        """
        # (T, H, W, C) -> (T, C, H, W); to_pil_image expects channel-first tensors.
        clip = frames[:num_frames].permute(0, 3, 1, 2)
        if transform is None:
            return clip
        return torch.stack([transform(TF.to_pil_image(frame)) for frame in clip])

    def __getitem__(self, idx):
        """Decode one video and return ``(clip, label)``.

        ``clip`` is a tensor whose first dimension indexes frames. It holds
        up to ``self.num_frames`` frames: the stacked outputs of
        ``self.transform`` when one is set, otherwise the raw decoded frames
        in (T, C, H, W) layout. ``label`` is 0 for real and 1 for fake.

        Raises:
            RuntimeError: if no video frames were decoded from the file.
        """
        video_path = self.total_list[idx]
        label = self.labels_dict[video_path]
        # NOTE(review): read_video decodes the whole file before truncation to
        # num_frames; consider passing end_pts for long videos.
        frames, _, _ = read_video(str(video_path), pts_unit='sec')
        if frames.shape[0] == 0:
            raise RuntimeError(f"no video frames decoded from {video_path!r}")
        # NOTE(review): clips shorter than num_frames give a smaller T, which
        # the default DataLoader collation cannot batch with longer clips.
        return self._frames_to_clip(frames, self.num_frames, self.transform), label