|
|
|
import os |
|
import glob |
|
from functools import partial |
|
from tqdm import tqdm, trange |
|
from multiprocessing import Pool |
|
from PIL import Image |
|
import cv2 |
|
import mlxu |
|
from natsort import natsorted |
|
import numpy as np |
|
import einops |
|
import torch |
|
|
|
from vqlm_demo.inference import MultiProcessInferenceModel |
|
from vqlm_demo.utils import ( |
|
is_video, random_square_crop, |
|
read_frames_from_dir, read_frames_from_video |
|
) |
|
|
|
|
|
# Command-line flags for this script; each keyword's value is the flag default.
FLAGS, _ = mlxu.define_flags_with_default(
    # Checkpoint passed to MultiProcessInferenceModel (required by main).
    checkpoint='',
    # Glob pattern selecting the inputs. The text before the first '*' is
    # also stripped from each input path when naming output images.
    input_files='',
    # If True, each input is a directory (read with read_frames_from_dir);
    # otherwise each input is a video file (read with read_frames_from_video).
    frame_input=False,
    # Optional text file listing one input path per line; when set it is
    # used instead of globbing input_files.
    read_file_list='',
    # Directory receiving the output PNGs (required; created if missing).
    output_dir='',
    # Forwarded as center_crop to the frame readers.
    center_crop=1.0,
    # Number of frames read from each input and fed to the model as context.
    n_context_frames=12,
    # Forwarded to the model as n_new_frames.
    n_new_frames=4,
    # Forwarded to the model as n_candidates; candidates are stacked
    # vertically in each output image.
    n_candidates=8,
    # Sampling temperature forwarded to the model.
    temperature=1.0,
    # Nucleus (top-p) value forwarded to the model.
    top_p=1.0,
    # Worker count for both the DataLoader and the image-saving Pool.
    n_workers=8,
    # Frame stride forwarded to the frame readers.
    stride=8,
    # Number of inputs per model call.
    batch_size=32,
    # Comma-separated CUDA device indices, e.g. '0,1'; '' passes None to
    # the model.
    torch_devices='',
    # Shuffle the input list (before max_examples truncation).
    shuffle=False,
    # Process at most this many inputs; 0 means no limit.
    max_examples=0,
)
|
|
|
|
|
def output_png_name(filename, base):
    """Turn an input path into a flat ``.png`` output filename.

    ``base`` is removed from the front of ``filename`` only when
    ``filename`` actually starts with it. Remaining ``/`` separators are
    replaced by ``_`` so the result is a single path component.

    Args:
        filename: Path of the input video or frame directory.
        base: Prefix to strip (normally the literal part of the glob
            pattern before its first ``*``).

    Returns:
        The output file name, ending in ``.png``.
    """
    if filename.startswith(base):
        relative = filename[len(base):]
    else:
        # The path does not live under base (e.g. it came from a file list),
        # so blindly slicing would cut off unrelated characters.
        relative = filename
    if not relative:
        # base was the full path (a pattern without '*'); stripping it would
        # leave an empty name and every output would collide on '.png'.
        relative = filename
    return relative.replace('/', '_') + '.png'


def save_image(args, base=None, output_dir=None):
    """Save one combined result image as a PNG.

    Args:
        args: ``(image, filename)`` pair, where ``image`` is an array
            accepted by ``PIL.Image.fromarray`` and ``filename`` is the
            source input path used to derive the output name.
        base: Prefix stripped from ``filename``. Defaults to the part of
            ``FLAGS.input_files`` before its first ``*``.
        output_dir: Destination directory. Defaults to ``FLAGS.output_dir``.
    """
    image, filename = args
    if base is None:
        base = FLAGS.input_files.split('*')[0]
    if output_dir is None:
        output_dir = FLAGS.output_dir
    Image.fromarray(image).save(
        os.path.join(output_dir, output_png_name(filename, base))
    )
|
|
|
|
|
class VideoDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(frames, path)`` pairs, one per input.

    Each item passes ``n_frames`` and ``stride`` to ``read_frames_from_dir``
    (when ``frame_input`` is set) or ``read_frames_from_video``. If the
    reader returns ``None`` for an input, a randomly chosen other input is
    tried instead, up to ``max_retries`` extra attempts.
    """

    def __init__(self, videos, frame_input=False, n_frames=8, stride=1,
                 center_crop=None, max_retries=100):
        """
        Args:
            videos: Sequence of input paths (video files or frame dirs).
            frame_input: Treat each path as a directory of frames.
            n_frames: Forwarded to the frame reader.
            stride: Forwarded to the frame reader.
            center_crop: Forwarded to the frame reader; ``None`` means use
                ``FLAGS.center_crop`` at read time.
            max_retries: Extra random inputs to try after a failed read
                before raising ``RuntimeError``.
        """
        self.videos = videos
        self.frame_input = frame_input
        self.n_frames = n_frames
        self.stride = stride
        self.center_crop = center_crop
        self.max_retries = max_retries

    def _read_frames(self, path):
        """Read frames for ``path``; the readers return ``None`` on failure."""
        reader = read_frames_from_dir if self.frame_input else read_frames_from_video
        center_crop = (
            FLAGS.center_crop if self.center_crop is None else self.center_crop
        )
        return reader(path, self.n_frames, self.stride, center_crop=center_crop)

    def __getitem__(self, index):
        # Bounded loop instead of recursion: many unreadable inputs used to
        # grow the call stack until RecursionError.
        path = self.videos[index]
        for _ in range(self.max_retries + 1):
            path = self.videos[index]
            frames = self._read_frames(path)
            if frames is not None:
                return frames, path
            index = np.random.randint(0, len(self))
        raise RuntimeError(
            f'Failed to read frames after {self.max_retries + 1} attempts '
            f'(last tried: {path!r})'
        )

    def __len__(self):
        return len(self.videos)
|
|
|
|
|
|
|
def _collect_videos():
    """Return the input paths to process, filtered and truncated per FLAGS."""
    if FLAGS.read_file_list != '':
        with open(FLAGS.read_file_list, 'r') as f:
            # Skip blank lines (e.g. a trailing newline at end of file).
            videos = [line.strip() for line in f if line.strip()]
    else:
        # glob order is arbitrary; sort so runs (and max_examples) are
        # reproducible.
        videos = natsorted(glob.glob(FLAGS.input_files))

    if FLAGS.frame_input:
        videos = [x for x in videos if os.path.isdir(x)]
    else:
        videos = [x for x in videos if is_video(x)]

    if FLAGS.shuffle:
        np.random.shuffle(videos)

    if FLAGS.max_examples > 0:
        videos = videos[:FLAGS.max_examples]
    return videos


def _build_dataloader(videos):
    """Wrap ``videos`` in a VideoDataset and a batched DataLoader."""
    dataset = VideoDataset(
        videos,
        frame_input=FLAGS.frame_input,
        n_frames=FLAGS.n_context_frames,
        stride=FLAGS.stride,
    )
    loader_kwargs = dict(
        batch_size=FLAGS.batch_size,
        shuffle=False,
        num_workers=FLAGS.n_workers,
        # NOTE(review): drop_last discards a final partial batch, so up to
        # batch_size - 1 inputs are never processed; presumably the model
        # needs full batches -- confirm.
        drop_last=True,
    )
    if FLAGS.n_workers > 0:
        # DataLoader rejects prefetch_factor when loading in the main process.
        loader_kwargs['prefetch_factor'] = 4
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)


def _combine_outputs(batch, generated, n_candidates):
    """Tile context + generated frames into one uint8 image per input.

    ``batch`` is rank 5 ('b s h w c'). Each candidate row is the context
    frames followed by that candidate's generated frames, laid out left to
    right; candidates are stacked top to bottom. Assumes ``generated`` is
    ('b n s h w c') with n == n_candidates so it concatenates on the frame
    axis -- TODO confirm against the model's output. Values are clipped to
    [0, 1] before scaling to 0..255.
    """
    output_batch = einops.repeat(
        batch, 'b s h w c -> b n s h w c', n=n_candidates,
    )
    combined = einops.rearrange(
        np.concatenate([output_batch, generated], axis=2),
        'b n s h w c -> b (n h) (s w) c',
    )
    return (np.clip(combined, 0, 1) * 255).astype(np.uint8)


def main(_):
    """Generate continuations for every input and save them as PNG grids."""
    if FLAGS.checkpoint == '' or FLAGS.output_dir == '':
        raise ValueError('--checkpoint and --output_dir are required')
    if FLAGS.read_file_list == '' and FLAGS.input_files == '':
        raise ValueError('one of --read_file_list or --input_files is required')
    os.makedirs(FLAGS.output_dir, exist_ok=True)

    dataloader = _build_dataloader(_collect_videos())

    if FLAGS.torch_devices == '':
        torch_devices = None
    else:
        torch_devices = [
            f'cuda:{x.strip()}' for x in FLAGS.torch_devices.split(',')
        ]

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint, torch_devices=torch_devices,
    )

    # Pool requires at least one process even if data loading is in-process.
    save_img_pool = Pool(max(1, FLAGS.n_workers))
    pending = []
    try:
        for batch, filenames in tqdm(dataloader, ncols=0):
            batch = batch.numpy()
            generated = np.array(model(
                batch,
                n_new_frames=FLAGS.n_new_frames,
                n_candidates=FLAGS.n_candidates,
                temperature=FLAGS.temperature,
                top_p=FLAGS.top_p,
            ))
            combined = _combine_outputs(batch, generated, FLAGS.n_candidates)
            # Save asynchronously so PNG encoding overlaps the next batch;
            # keep the handle so failures can be surfaced later.
            pending.append(save_img_pool.map_async(
                save_image, list(zip(combined, filenames)),
            ))
        # Wait for every queued save before the process exits; otherwise the
        # last images can be lost when the pool is torn down.
        save_img_pool.close()
        save_img_pool.join()
        for result in pending:
            result.get()  # re-raises any exception from save_image
    finally:
        save_img_pool.terminate()
|
|
|
|
|
# Script entry point: hand control to main via mlxu.run (main reads its
# configuration from FLAGS).
if __name__ == '__main__':
    mlxu.run(main)