'''
    This file is used to test UNet and GestureNet inference.
'''
import os, shutil, sys
import urllib.request
import argparse
import random           # Needed by the random flip augmentation below
import imageio
import math
import cv2
from PIL import Image
import collections
import numpy as np
import torch
from pathlib import Path
from omegaconf import OmegaConf
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from diffusers import (
    AutoencoderKLTemporalDecoder,
    DDPMScheduler,
)
from diffusers.utils import check_min_version, is_wandb_available, load_image, export_to_video
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, PretrainedConfig

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from train_code.train_svd import import_pretrained_text_encoder
from data_loader.video_dataset import tokenize_captions
from data_loader.video_this_that_dataset import get_thisthat_sam
from svd.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from svd.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline
from svd.temporal_controlnet import ControlNetModel
from svd.pipeline_stable_video_diffusion_controlnet import StableVideoDiffusionControlNetPipeline


# Seed
# torch.manual_seed(42)
# np.random.seed(42)


def unet_inference(vae, unet, image_encoder, text_encoder, tokenizer, config, accelerator, weight_dtype, step,
                   parent_store_folder=None, force_close_flip=False, use_ambiguous_prompt=False):
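    '''
        Run validation inference with the UNet (vision + language) model on every instance folder under
        config["validation_img_folder"], and store the generated frames (PNG + GIF) per instance.
        Returns a list with the generated frames of each processed instance.
    '''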
    # Init
    validation_source_folder = config["validation_img_folder"]

    # Init the pipeline
    pipeline = StableVideoDiffusionPipeline.from_pretrained(
        config["pretrained_model_name_or_path"],
        vae = accelerator.unwrap_model(vae),
        image_encoder = accelerator.unwrap_model(image_encoder),
        unet = accelerator.unwrap_model(unet),
        revision = None,            # Set None directly now
        torch_dtype = weight_dtype,
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # Process all images in the folder
    frames_collection = []
    for image_name in sorted(os.listdir(validation_source_folder)):
        if accelerator.is_main_process:
            if parent_store_folder is None:
                validation_store_folder = os.path.join(config["validation_store_folder"] + "_" + config["scheduler"], "step_" + str(step), image_name)
            else:
                validation_store_folder = os.path.join(parent_store_folder, image_name)

            if os.path.exists(validation_store_folder):
                shutil.rmtree(validation_store_folder)
            os.makedirs(validation_store_folder)

            image_path = os.path.join(validation_source_folder, image_name, 'im_0.jpg')
            ref_image = load_image(image_path)
            ref_image = ref_image.resize((config["width"], config["height"]))

            # Decide the motion score in SVD (we mostly use a fixed value now)
            if config["motion_bucket_id"] is None:
                raise NotImplementedError("We need a fixed motion_bucket_id in the config")
            else:
                reflected_motion_bucket_id = config["motion_bucket_id"]
            # print("Inference Motion Bucket ID is ", reflected_motion_bucket_id)

            # Prepare the text prompt
            if config["use_text"]:
                # Read the prompt file (only the first line is used)
                file_path = os.path.join(validation_source_folder, image_name, "lang.txt")
                with open(file_path, 'r') as file:
                    prompt = file.readlines()[0]
                if use_ambiguous_prompt:
                    prompt = prompt.split(" ")[0] + " this to there"
                    print("We are creating an ambiguous prompt, which is: ", prompt)
            else:
                prompt = ""

            # Use the same tokenization process as in the dataset preparation stage
            tokenized_prompt = tokenize_captions(prompt, tokenizer, config, is_train=False).unsqueeze(0).to(accelerator.device)     # unsqueeze to expand the batch dim

            # Store the prompt for the sanity check
            with open(os.path.join(validation_store_folder, "lang_cond.txt"), "a") as f:
                f.write(prompt)

            # Flip the image by chance (we must check whether the prompt contains an object position word such as "left" or "right")
            flip = False
            if not force_close_flip:    # force_close_flip is True at testing time; otherwise, we cannot evaluate under the same standard
                if random.random() < config["flip_aug_prob"]:
                    if config["use_text"]:
                        if prompt.find("left") == -1 and prompt.find("right") == -1:    # Must not contain a position word such as "left" or "right" ("up" and "down" are fine)
                            flip = True
                    else:
                        flip = True
            if flip:
                print("Use flip in validation!")
                ref_image = ref_image.transpose(Image.FLIP_LEFT_RIGHT)

            # Call the model for inference
            with torch.autocast("cuda"):
                frames = pipeline(
                    ref_image,
                    tokenized_prompt,
                    config["use_text"],
                    text_encoder,
                    height = config["height"],
                    width = config["width"],
                    num_frames = config["video_seq_length"],
                    num_inference_steps = config["num_inference_steps"],
                    decode_chunk_size = 8,
                    motion_bucket_id = reflected_motion_bucket_id,
                    fps = 7,
                    noise_aug_strength = config["inference_noise_aug_strength"],
                ).frames[0]

            # Store the frames
            for idx, frame in enumerate(frames):
                frame.save(os.path.join(validation_store_folder, str(idx) + ".png"))
            imageio.mimsave(os.path.join(validation_store_folder, 'combined.gif'), frames)      # GIF storage quality is not high; we recommend checking the PNG images instead
            frames_collection.append(frames)

    # Cleaning process
    del pipeline
    torch.cuda.empty_cache()

    return frames_collection    # Return the result based on the need


def gesturenet_inference(vae, unet, controlnet, image_encoder, text_encoder, tokenizer, config, accelerator, weight_dtype, step,
                         parent_store_folder=None, force_close_flip=False, use_ambiguous_prompt=False):
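    '''
        Run validation inference with the GestureNet (vision + gesture + language) model on every instance folder
        under config["validation_img_folder"], using the this-that gesture conditioning, and store the generated
        frames (PNG + GIF) per instance. Returns a list with the generated frames of each processed instance.
    '''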
    # Init
    validation_source_folder = config["validation_img_folder"]

    # Init the pipeline
    pipeline = StableVideoDiffusionControlNetPipeline.from_pretrained(
        config["pretrained_model_name_or_path"],    # Still based on the regular SVD config
        vae = vae,
        image_encoder = image_encoder,
        unet = unet,
        revision = None,            # Set None directly now
        torch_dtype = weight_dtype,
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # Process all images in the folder
    frames_collection = []
    for image_name in sorted(os.listdir(validation_source_folder)):
        if accelerator.is_main_process:
            if parent_store_folder is None:
                validation_store_folder = os.path.join(config["validation_store_folder"] + "_" + config["scheduler"], "step_" + str(step), image_name)
            else:
                validation_store_folder = os.path.join(parent_store_folder, image_name)

            if os.path.exists(validation_store_folder):
                shutil.rmtree(validation_store_folder)
            os.makedirs(validation_store_folder)

            image_path = os.path.join(validation_source_folder, image_name, 'im_0.jpg')
            ref_image = load_image(image_path)      # [0, 255] range
            ref_image = ref_image.resize((config["width"], config["height"]))

            # Prepare the text prompt
            if config["use_text"]:
                # Read the prompt file (only the first line is used)
                file_path = os.path.join(validation_source_folder, image_name, "lang.txt")
                with open(file_path, 'r') as file:
                    prompt = file.readlines()[0]
                if use_ambiguous_prompt:
                    prompt = prompt.split(" ")[0] + " this to there"
                    print("We are creating an ambiguous prompt, which is: ", prompt)
            else:
                prompt = ""

            # Use the same tokenization process as in the dataset preparation stage
            tokenized_prompt = tokenize_captions(prompt, tokenizer, config, is_train=False).unsqueeze(0).to(accelerator.device)     # unsqueeze to expand the batch dim

            # Store the prompt for the sanity check
            with open(os.path.join(validation_store_folder, "lang_cond.txt"), "a") as f:
                f.write(prompt)

            # Flip the image by chance (we must check whether the prompt contains an object position word such as "left" or "right")
            flip = False
            if not force_close_flip:    # force_close_flip is True at testing time; otherwise, we cannot evaluate under the same standard
                if random.random() < config["flip_aug_prob"]:
                    if config["use_text"]:
                        if prompt.find("left") == -1 and prompt.find("right") == -1:    # Must not contain a position word such as "left" or "right" ("up" and "down" are fine)
                            flip = True
                    else:
                        flip = True
            if flip:
                print("Use flip in validation!")
                ref_image = ref_image.transpose(Image.FLIP_LEFT_RIGHT)

            # Prepare the gesture condition image
            if config["data_loader_type"] == "thisthat":
                condition_img, reflected_motion_bucket_id, controlnet_image_index, coordinate_values = get_thisthat_sam(
                    config,
                    os.path.join(validation_source_folder, image_name),
                    flip = flip,
                    store_dir = validation_store_folder,
                    verbose = True,
                )
            else:
                raise NotImplementedError("We don't support such a data loader type")

            # Call the pipeline
            with torch.autocast("cuda"):
                frames = pipeline(
                    image = ref_image,
                    condition_img = condition_img,      # numpy array in [0, 1] range
                    controlnet = accelerator.unwrap_model(controlnet),
                    prompt = tokenized_prompt,
                    use_text = config["use_text"],
                    text_encoder = text_encoder,
                    height = config["height"],
                    width = config["width"],
                    num_frames = config["video_seq_length"],
                    decode_chunk_size = 8,
                    motion_bucket_id = reflected_motion_bucket_id,
                    # controlnet_image_index = controlnet_image_index,
                    # coordinate_values = coordinate_values,
                    num_inference_steps = config["num_inference_steps"],
                    max_guidance_scale = config["inference_max_guidance_scale"],
                    fps = 7,
                    use_instructpix2pix = config["use_instructpix2pix"],
                    noise_aug_strength = config["inference_noise_aug_strength"],
                    controlnet_conditioning_scale = config["outer_conditioning_scale"],
                    inner_conditioning_scale = config["inner_conditioning_scale"],
                    guess_mode = config["inference_guess_mode"],    # False in inference
                    image_guidance_scale = config["image_guidance_scale"],
                ).frames[0]

            # Store the frames
            for idx, frame in enumerate(frames):
                frame.save(os.path.join(validation_store_folder, str(idx) + ".png"))
            imageio.mimsave(os.path.join(validation_store_folder, 'combined.gif'), frames, duration=0.05)
            frames_collection.append(frames)

    # Cleaning process
    del pipeline
    torch.cuda.empty_cache()

    return frames_collection    # Return the result based on the need


def execute_inference(huggingface_pretrained_path, model_type, validation_path, parent_store_folder, use_ambiguous_prompt):
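    '''
        Load the YAML config and pretrained weights (from a local folder or from the Hugging Face Hub),
        build the UNet / GestureNet models, and run inference on every instance under validation_path,
        storing the results under parent_store_folder.
    '''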
    # Check the path
    if os.path.exists(parent_store_folder):
        shutil.rmtree(parent_store_folder)
    os.makedirs(parent_store_folder)

    # Read the YAML setting file (very important: it provides the needed hyperparameters)
    if not os.path.exists(huggingface_pretrained_path):
        # The path is a Hugging Face Hub repo id, so download the config from the Hub
        yaml_download_path = hf_hub_download(repo_id=huggingface_pretrained_path, subfolder="unet", filename="train_image2video.yaml")
        if model_type == "GestureNet":
            yaml_download_path = hf_hub_download(repo_id=huggingface_pretrained_path, subfolder="gesturenet", filename="train_image2video_gesturenet.yaml")
    else:   # If the path is a local path, we can concatenate it here
        yaml_download_path = os.path.join(huggingface_pretrained_path, "unet", "train_image2video.yaml")
        if model_type == "GestureNet":
            yaml_download_path = os.path.join(huggingface_pretrained_path, "gesturenet", "train_image2video_gesturenet.yaml")

    # Load the config
    assert os.path.exists(yaml_download_path)
    base_config = OmegaConf.load(yaml_download_path)

    # Other settings
    base_config["validation_img_folder"] = validation_path


    ################################################ Prepare vae, unet, image_encoder Same as before #################################################################
    accelerator = Accelerator(
        gradient_accumulation_steps = base_config["gradient_accumulation_steps"],
        mixed_precision = base_config["mixed_precision"],
        log_with = base_config["report_to"],
        project_config = ProjectConfiguration(project_dir=base_config["output_dir"], logging_dir=Path(base_config["output_dir"], base_config["logging_name"])),
    )

    feature_extractor = CLIPImageProcessor.from_pretrained(
        base_config["pretrained_model_name_or_path"], subfolder="feature_extractor", revision=None
    )   # This instance has no weights; it only carries the preprocessing settings
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        base_config["pretrained_model_name_or_path"], subfolder="image_encoder", revision=None, variant="fp16"
    )
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        base_config["pretrained_model_name_or_path"], subfolder="vae", revision=None, variant="fp16"
    )
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        huggingface_pretrained_path,
        subfolder = "unet",
        low_cpu_mem_usage = True,
        # variant = "fp16",
    )
    print("device we have is ", accelerator.device)

    # For text ..............................................
    tokenizer = AutoTokenizer.from_pretrained(
        base_config["pretrained_tokenizer_name_or_path"],
        subfolder = "tokenizer",
        revision = None,
        use_fast = False,
    )
    # CLIP text encoder
    text_encoder_cls = import_pretrained_text_encoder(base_config["pretrained_tokenizer_name_or_path"], revision=None)
    text_encoder = text_encoder_cls.from_pretrained(base_config["pretrained_tokenizer_name_or_path"], subfolder="text_encoder", revision=None, variant=None)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Freeze all models (no training here)
    vae.requires_grad_(False)
    image_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Move vae + image_encoder + text_encoder to the GPU and cast to weight_dtype
    vae.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # For GestureNet
    if model_type == "GestureNet":
        unet.to(accelerator.device, dtype=weight_dtype)     # Casting the unet is not needed for UNet-only inference; it is only needed for the GestureNet (ControlNet) path

        # Load the GestureNet (a ControlNet built from the UNet)
        gesturenet = ControlNetModel.from_pretrained(
            huggingface_pretrained_path,
            subfolder = "gesturenet",
            low_cpu_mem_usage = True,
            variant = None,
        )
        gesturenet.requires_grad_(False)
        gesturenet.to(accelerator.device)
    ##############################################################################################################################################################


    ############################################################### Execution #####################################################################################
    # Prepare the iterative calling
    if model_type == "UNet":
        generated_frames = unet_inference(
            vae, unet, image_encoder, text_encoder, tokenizer,
            base_config, accelerator, weight_dtype, step="",
            parent_store_folder=parent_store_folder, force_close_flip=True,
            use_ambiguous_prompt=use_ambiguous_prompt,
        )
    elif model_type == "GestureNet":
        generated_frames = gesturenet_inference(
            vae, unet, gesturenet, image_encoder, text_encoder, tokenizer,
            base_config, accelerator, weight_dtype, step="",
            parent_store_folder=parent_store_folder, force_close_flip=True,
            use_ambiguous_prompt=use_ambiguous_prompt,
        )
    else:
        raise NotImplementedError("model_type is not one of the predefined choices we provide!")
    ################################################################################################################################################################
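

# Example invocation (the script filename below is a placeholder; replace it with this file's actual name):
#   python inference.py --model_type GestureNet \
#       --huggingface_pretrained_path HikariDawn/This-and-That-1.1 \
#       --validation_path __assets__/Bridge_example/ \
#       --parent_store_folder generated_results/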
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Simple inference script for UNet and GestureNet.")
    parser.add_argument(
        "--model_type",
        type=str,
        default="GestureNet",
        help="\"UNet\" for VL (vision language) / \"GestureNet\" for VGL (vision gesture language)",
    )
    parser.add_argument(
        "--huggingface_pretrained_path",
        type=str,
        default="HikariDawn/This-and-That-1.1",
        help="Hugging Face repo id, or the local folder that contains the unet/gesturenet subfolders.",
    )
    parser.add_argument(
        "--validation_path",
        type=str,
        default="__assets__/Bridge_example/",
        help="Sample dataset path; defaults to the Bridge example.",
    )
    parser.add_argument(
        "--parent_store_folder",
        type=str,
        default="generated_results/",
        help="Path to the folder where the results are stored.",
    )
    parser.add_argument(
        "--use_ambiguous_prompt",
        action="store_true",
        help="Whether we will use the action verb + \"this to there\" ambiguous prompt combo.",
    )
    args = parser.parse_args()

    # File settings
    model_type = args.model_type
    huggingface_pretrained_path = args.huggingface_pretrained_path
    # validation_path needs to have a subfolder for each instance.
    # Each instance requires "im_0.jpg" for the first image, data.txt for the gesture positions, and lang.txt for the language prompt.
    validation_path = args.validation_path
    parent_store_folder = args.parent_store_folder
    use_ambiguous_prompt = args.use_ambiguous_prompt

    # Execution
    execute_inference(huggingface_pretrained_path, model_type, validation_path, parent_store_folder, use_ambiguous_prompt)

    print("All finished!!!")