'''
    This file runs inference to test the UNet and GestureNet models.
'''
import os, random, shutil, sys
import urllib.request
import argparse
import imageio
import math
import cv2
from PIL import Image
import collections
import numpy as np
import torch
from pathlib import Path
from omegaconf import OmegaConf
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from diffusers import (
AutoencoderKLTemporalDecoder,
DDPMScheduler,
)
from diffusers.utils import check_min_version, is_wandb_available, load_image, export_to_video
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, PretrainedConfig
# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from train_code.train_svd import import_pretrained_text_encoder
from data_loader.video_dataset import tokenize_captions
from data_loader.video_this_that_dataset import get_thisthat_sam
from svd.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from svd.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline
from svd.temporal_controlnet import ControlNetModel
from svd.pipeline_stable_video_diffusion_controlnet import StableVideoDiffusionControlNetPipeline
# Seed
# torch.manual_seed(42)
# np.random.seed(42)
def unet_inference(vae, unet, image_encoder, text_encoder, tokenizer, config, accelerator, weight_dtype, step,
parent_store_folder = None, force_close_flip = False, use_ambiguous_prompt=False):
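    '''
        Run UNet (vision-language) inference on every sample folder under config["validation_img_folder"].
        For each sample, the reference image (im_0.jpg) and optional text prompt (lang.txt) are fed to the
        StableVideoDiffusionPipeline; the generated frames are stored as PNGs plus a combined GIF under
        parent_store_folder (or the default validation_store_folder from the config) and returned as a list.
    '''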
# Init
validation_source_folder = config["validation_img_folder"]
# Init the pipeline
pipeline = StableVideoDiffusionPipeline.from_pretrained(
config["pretrained_model_name_or_path"],
vae = accelerator.unwrap_model(vae),
image_encoder = accelerator.unwrap_model(image_encoder),
unet = accelerator.unwrap_model(unet),
revision = None, # Set None directly now
torch_dtype = weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
    # Process all images in the folder
frames_collection = []
for image_name in sorted(os.listdir(validation_source_folder)):
if accelerator.is_main_process:
if parent_store_folder is None:
validation_store_folder = os.path.join(config["validation_store_folder"] + "_" + config["scheduler"], "step_" + str(step), image_name)
else:
validation_store_folder = os.path.join(parent_store_folder, image_name)
if os.path.exists(validation_store_folder):
shutil.rmtree(validation_store_folder)
os.makedirs(validation_store_folder)
image_path = os.path.join(validation_source_folder, image_name, 'im_0.jpg')
ref_image = load_image(image_path)
ref_image = ref_image.resize((config["width"], config["height"]))
            # Decide the motion score in SVD (we mostly use a fixed value now)
if config["motion_bucket_id"] is None:
raise NotImplementedError("We need a fixed motion_bucket_id in the config")
else:
reflected_motion_bucket_id = config["motion_bucket_id"]
# print("Inference Motion Bucket ID is ", reflected_motion_bucket_id)
# Prepare text prompt
if config["use_text"]:
# Read the file
                file_path = os.path.join(validation_source_folder, image_name, "lang.txt")
                with open(file_path, 'r') as file:
                    prompt = file.readlines()[0]    # Only read the first line
                if use_ambiguous_prompt:
                    prompt = prompt.split(" ")[0] + " this to there"
                    print("We are creating the ambiguous prompt: ", prompt)
else:
prompt = ""
# Use the same tokenize process as the dataset preparation stage
tokenized_prompt = tokenize_captions(prompt, tokenizer, config, is_train=False).unsqueeze(0).to(accelerator.device) # Use unsqueeze to expand dim
            # Store the prompt for the sanity check
            with open(os.path.join(validation_store_folder, "lang_cond.txt"), "a") as f:
                f.write(prompt)
            # Randomly flip the image (we first need to check whether the prompt contains object-position words such as [left|right])
flip = False
            if not force_close_flip:    # force_close_flip is True at test time; otherwise flipped results cannot be compared under the same standard
                if random.random() < config["flip_aug_prob"]:
                    if config["use_text"]:
                        if prompt.find("left") == -1 and prompt.find("right") == -1:    # The prompt cannot contain position words such as left/right (up and down are ok)
                            flip = True
                    else:
                        flip = True
if flip:
print("Use flip in validation!")
ref_image = ref_image.transpose(Image.FLIP_LEFT_RIGHT)
# Call the model for inference
with torch.autocast("cuda"):
frames = pipeline(
ref_image,
tokenized_prompt,
config["use_text"],
text_encoder,
height = config["height"],
width = config["width"],
num_frames = config["video_seq_length"],
num_inference_steps = config["num_inference_steps"],
decode_chunk_size = 8,
motion_bucket_id = reflected_motion_bucket_id,
fps = 7,
noise_aug_strength = config["inference_noise_aug_strength"],
).frames[0]
# Store the frames
for idx, frame in enumerate(frames):
frame.save(os.path.join(validation_store_folder, str(idx)+".png"))
            imageio.mimsave(os.path.join(validation_store_folder, 'combined.gif'), frames)    # GIF quality is limited, so we recommend checking the PNG frames instead
frames_collection.append(frames)
# Cleaning process
del pipeline
torch.cuda.empty_cache()
    return frames_collection    # Return the generated results as needed
def gesturenet_inference(vae, unet, controlnet, image_encoder, text_encoder, tokenizer, config, accelerator, weight_dtype, step,
parent_store_folder=None, force_close_flip=False, use_ambiguous_prompt=False):
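    '''
        Run GestureNet (vision-gesture-language) inference. Besides the reference image and the optional
        text prompt, each sample provides gesture conditions (loaded via get_thisthat_sam) that are fed to
        the ControlNet-style GestureNet through StableVideoDiffusionControlNetPipeline. Frames are stored
        as PNGs plus a combined GIF and returned as a list, one entry per sample.
    '''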
# Init
validation_source_folder = config["validation_img_folder"]
# Init the pipeline
pipeline = StableVideoDiffusionControlNetPipeline.from_pretrained(
config["pretrained_model_name_or_path"], # Still based on regular SVD config
vae = vae,
image_encoder = image_encoder,
unet = unet,
revision = None, # Set None directly now
torch_dtype = weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
    # Process all images in the folder
frames_collection = []
for image_name in sorted(os.listdir(validation_source_folder)):
if accelerator.is_main_process:
if parent_store_folder is None:
validation_store_folder = os.path.join(config["validation_store_folder"] + "_" + config["scheduler"], "step_" + str(step), image_name)
else:
validation_store_folder = os.path.join(parent_store_folder, image_name)
if os.path.exists(validation_store_folder):
shutil.rmtree(validation_store_folder)
os.makedirs(validation_store_folder)
image_path = os.path.join(validation_source_folder, image_name, 'im_0.jpg')
ref_image = load_image(image_path) # [0, 255] Range
ref_image = ref_image.resize((config["width"], config["height"]))
# Prepare text prompt
if config["use_text"]:
# Read the file
                file_path = os.path.join(validation_source_folder, image_name, "lang.txt")
                with open(file_path, 'r') as file:
                    prompt = file.readlines()[0]    # Only read the first line
                if use_ambiguous_prompt:
                    prompt = prompt.split(" ")[0] + " this to there"
                    print("We are creating the ambiguous prompt: ", prompt)
else:
prompt = ""
# Use the same tokenize process as the dataset preparation stage
tokenized_prompt = tokenize_captions(prompt, tokenizer, config, is_train=False).unsqueeze(0).to(accelerator.device) # Use unsqueeze to expand dim
            # Store the prompt for the sanity check
            with open(os.path.join(validation_store_folder, "lang_cond.txt"), "a") as f:
                f.write(prompt)
            # Randomly flip the image (we first need to check whether the prompt contains object-position words such as [left|right])
flip = False
            if not force_close_flip:    # force_close_flip is True at test time; otherwise flipped results cannot be compared under the same standard
                if random.random() < config["flip_aug_prob"]:
                    if config["use_text"]:
                        if prompt.find("left") == -1 and prompt.find("right") == -1:    # The prompt cannot contain position words such as left/right (up and down are ok)
                            flip = True
                    else:
                        flip = True
if flip:
print("Use flip in validation!")
ref_image = ref_image.transpose(Image.FLIP_LEFT_RIGHT)
if config["data_loader_type"] == "thisthat":
condition_img, reflected_motion_bucket_id, controlnet_image_index, coordinate_values = get_thisthat_sam(config,
os.path.join(validation_source_folder, image_name),
flip = flip,
store_dir = validation_store_folder,
verbose = True)
else:
raise NotImplementedError("We don't support such data loader type")
# Call the pipeline
with torch.autocast("cuda"):
frames = pipeline(
image = ref_image,
condition_img = condition_img, # numpy [0,1] range
controlnet = accelerator.unwrap_model(controlnet),
prompt = tokenized_prompt,
use_text = config["use_text"],
text_encoder = text_encoder,
height = config["height"],
width = config["width"],
num_frames = config["video_seq_length"],
decode_chunk_size = 8,
motion_bucket_id = reflected_motion_bucket_id,
# controlnet_image_index = controlnet_image_index,
# coordinate_values = coordinate_values,
num_inference_steps = config["num_inference_steps"],
max_guidance_scale = config["inference_max_guidance_scale"],
fps = 7,
use_instructpix2pix = config["use_instructpix2pix"],
noise_aug_strength = config["inference_noise_aug_strength"],
controlnet_conditioning_scale = config["outer_conditioning_scale"],
inner_conditioning_scale = config["inner_conditioning_scale"],
guess_mode = config["inference_guess_mode"], # False in inference
image_guidance_scale = config["image_guidance_scale"],
).frames[0]
for idx, frame in enumerate(frames):
frame.save(os.path.join(validation_store_folder, str(idx)+".png"))
imageio.mimsave(os.path.join(validation_store_folder, 'combined.gif'), frames, duration=0.05)
frames_collection.append(frames)
# Cleaning process
del pipeline
torch.cuda.empty_cache()
    return frames_collection    # Return the generated results as needed
def execute_inference(huggingface_pretrained_path, model_type, validation_path, parent_store_folder, use_ambiguous_prompt):
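    '''
        Load the config and all pretrained components (VAE, UNet, image encoder, text encoder, tokenizer,
        and optionally the GestureNet), then run unet_inference or gesturenet_inference on validation_path,
        storing the results under parent_store_folder.
    '''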
# Check path
if os.path.exists(parent_store_folder):
shutil.rmtree(parent_store_folder)
os.makedirs(parent_store_folder)
    # Read the yaml setting file (very important: it holds the hyperparameters needed)
if not os.path.exists(huggingface_pretrained_path):
yaml_download_path = hf_hub_download(repo_id=huggingface_pretrained_path, subfolder="unet", filename="train_image2video.yaml")
if model_type == "GestureNet":
yaml_download_path = hf_hub_download(repo_id=huggingface_pretrained_path, subfolder="gesturenet", filename="train_image2video_gesturenet.yaml")
else: # If the path is a local path we can concatenate it here
yaml_download_path = os.path.join(huggingface_pretrained_path, "unet", "train_image2video.yaml")
if model_type == "GestureNet":
yaml_download_path = os.path.join(huggingface_pretrained_path, "gesturenet", "train_image2video_gesturenet.yaml")
# Load the config
assert(os.path.exists(yaml_download_path))
base_config = OmegaConf.load(yaml_download_path)
    # Override the validation image folder with the user-provided path
    base_config["validation_img_folder"] = validation_path
################################################ Prepare vae, unet, image_encoder Same as before #################################################################
accelerator = Accelerator(
gradient_accumulation_steps = base_config["gradient_accumulation_steps"],
mixed_precision = base_config["mixed_precision"],
log_with = base_config["report_to"],
project_config = ProjectConfiguration(project_dir=base_config["output_dir"], logging_dir=Path(base_config["output_dir"], base_config["logging_name"])),
)
feature_extractor = CLIPImageProcessor.from_pretrained(
base_config["pretrained_model_name_or_path"], subfolder="feature_extractor", revision=None
    )   # This instance has no weights; it is just a preprocessing settings file
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
base_config["pretrained_model_name_or_path"], subfolder="image_encoder", revision=None, variant="fp16"
)
vae = AutoencoderKLTemporalDecoder.from_pretrained(
base_config["pretrained_model_name_or_path"], subfolder="vae", revision=None, variant="fp16"
)
unet = UNetSpatioTemporalConditionModel.from_pretrained(
huggingface_pretrained_path,
subfolder = "unet",
low_cpu_mem_usage = True,
# variant = "fp16",
)
print("device we have is ", accelerator.device)
# For text ..............................................
tokenizer = AutoTokenizer.from_pretrained(
base_config["pretrained_tokenizer_name_or_path"],
subfolder = "tokenizer",
revision = None,
use_fast = False,
)
# Clip Text Encoder
text_encoder_cls = import_pretrained_text_encoder(base_config["pretrained_tokenizer_name_or_path"], revision=None)
text_encoder = text_encoder_cls.from_pretrained(base_config["pretrained_tokenizer_name_or_path"], subfolder = "text_encoder", revision = None, variant = None)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move vae + image_encoder to gpu and cast to weight_dtype
vae.requires_grad_(False)
image_encoder.requires_grad_(False)
unet.requires_grad_(False) # Will switch back at the end
text_encoder.requires_grad_(False)
# Move to accelerator
vae.to(accelerator.device, dtype=weight_dtype)
image_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# For GestureNet
if model_type == "GestureNet":
unet.to(accelerator.device, dtype=weight_dtype) # There is no need to cast unet in unet training, only needed in controlnet one
# Handle the Controlnet first from UNet
gesturenet = ControlNetModel.from_pretrained(
huggingface_pretrained_path,
subfolder = "gesturenet",
low_cpu_mem_usage = True,
variant = None,
)
gesturenet.requires_grad_(False)
gesturenet.to(accelerator.device)
##############################################################################################################################################################
############################################################### Execution #####################################################################################
# Prepare the iterative calling
if model_type == "UNet":
generated_frames = unet_inference(
vae, unet, image_encoder, text_encoder, tokenizer,
base_config, accelerator, weight_dtype, step="",
parent_store_folder=parent_store_folder, force_close_flip = True,
use_ambiguous_prompt = use_ambiguous_prompt,
)
elif model_type == "GestureNet":
generated_frames = gesturenet_inference(
vae, unet, gesturenet, image_encoder, text_encoder, tokenizer,
base_config, accelerator, weight_dtype, step="",
parent_store_folder=parent_store_folder, force_close_flip = True,
use_ambiguous_prompt = use_ambiguous_prompt,
)
else:
raise NotImplementedError("model_type is no the predefined choices we provide!")
################################################################################################################################################################
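# Example invocation (assuming the script is run from the repository root so that the local
# train_code/, data_loader/, and svd/ imports resolve; the values below are the argparse defaults):
#   python test_code/inference.py --model_type GestureNet \
#       --huggingface_pretrained_path HikariDawn/This-and-That-1.1 \
#       --validation_path __assets__/Bridge_example/ \
#       --parent_store_folder generated_results/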
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
parser.add_argument(
"--model_type",
type=str,
default="GestureNet",
help="\"UNet\" for VL (vision language) / \"GestureNet\" for VGL (vision gesture language)",
)
parser.add_argument(
"--huggingface_pretrained_path",
type=str,
default="HikariDawn/This-and-That-1.1",
help="Path to the unet folder path.",
)
parser.add_argument(
"--validation_path",
type=str,
default="__assets__/Bridge_example/",
help="Sample dataset path, default to the Bridge example.",
)
parser.add_argument(
"--parent_store_folder",
type=str,
default="generated_results/",
help="Path to the store result folder.",
)
    parser.add_argument(
        "--use_ambiguous_prompt",
        type=lambda x: str(x).lower() in ("true", "1", "yes"),  # Parse the string into a real bool; a plain str would make "False" truthy
        default=False,
        help="Whether we will use the action verb + \"this to there\" ambiguous prompt combo.",
    )
args = parser.parse_args()
# File Setting
model_type = args.model_type
huggingface_pretrained_path = args.huggingface_pretrained_path
    # validation_path needs to have a subfolder for each instance.
    # Each instance requires "im_0.jpg" for the first image, data.txt for the gesture position, and lang.txt for the language prompt.
validation_path = args.validation_path
parent_store_folder = args.parent_store_folder
use_ambiguous_prompt = args.use_ambiguous_prompt
# Execution
execute_inference(huggingface_pretrained_path, model_type, validation_path, parent_store_folder, use_ambiguous_prompt)
print("All finished!!!")