hydit22 / inference.py
heziiiii's picture
Upload inference.py
88c74ea verified
raw
history blame
19.4 kB
import random
import time
from pathlib import Path
import numpy as np
import torch
# For reproducibility
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True
from diffusers import schedulers
from diffusers.models import AutoencoderKL
from loguru import logger
from transformers import BertModel, BertTokenizer
from transformers.modeling_utils import logger as tf_logger
from .constants import SAMPLER_FACTORY, NEGATIVE_PROMPT, TRT_MAX_WIDTH, TRT_MAX_HEIGHT, TRT_MAX_BATCH_SIZE
from .diffusion.pipeline import StableDiffusionPipeline
from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
from .modules.text_encoder import MT5Embedder
from .utils.tools import set_seeds
from peft import LoraConfig
class Resolution:
def __init__(self, width, height):
self.width = width
self.height = height
def __str__(self):
return f'{self.height}x{self.width}'
class ResolutionGroup:
def __init__(self):
self.data = [
Resolution(1024, 1024), # 1:1
Resolution(1280, 1280), # 1:1
Resolution(1024, 768), # 4:3
Resolution(1152, 864), # 4:3
Resolution(1280, 960), # 4:3
Resolution(768, 1024), # 3:4
Resolution(864, 1152), # 3:4
Resolution(960, 1280), # 3:4
Resolution(1280, 768), # 16:9
Resolution(768, 1280), # 9:16
]
self.supported_sizes = set([(r.width, r.height) for r in self.data])
def is_valid(self, width, height):
return (width, height) in self.supported_sizes
STANDARD_RATIO = np.array([
1.0, # 1:1
4.0 / 3.0, # 4:3
3.0 / 4.0, # 3:4
16.0 / 9.0, # 16:9
9.0 / 16.0, # 9:16
])
STANDARD_SHAPE = [
[(1024, 1024), (1280, 1280)], # 1:1
[(1280, 960)], # 4:3
[(960, 1280)], # 3:4
[(1280, 768)], # 16:9
[(768, 1280)], # 9:16
]
STANDARD_AREA = [
np.array([w * h for w, h in shapes])
for shapes in STANDARD_SHAPE
]
def get_standard_shape(target_width, target_height):
"""
Map image size to standard size.
"""
target_ratio = target_width / target_height
closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
return width, height
def _to_tuple(val):
if isinstance(val, (list, tuple)):
if len(val) == 1:
val = [val[0], val[0]]
elif len(val) == 2:
val = tuple(val)
else:
raise ValueError(f"Invalid value: {val}")
elif isinstance(val, (int, float)):
val = (val, val)
else:
raise ValueError(f"Invalid value: {val}")
return val
def get_pipeline(args, vae, text_encoder, tokenizer, model, device, rank,
embedder_t5, infer_mode, sampler=None):
"""
Get scheduler and pipeline for sampling. The sampler and pipeline are both
based on diffusers and make some modifications.
Returns
-------
pipeline: StableDiffusionPipeline
sampler_name: str
"""
sampler = sampler or args.sampler
# Load sampler from factory
kwargs = SAMPLER_FACTORY[sampler]['kwargs']
scheduler = SAMPLER_FACTORY[sampler]['scheduler']
# Update sampler according to the arguments
kwargs['beta_schedule'] = args.noise_schedule
kwargs['beta_start'] = args.beta_start
kwargs['beta_end'] = args.beta_end
kwargs['prediction_type'] = args.predict_type
# Build scheduler according to the sampler.
scheduler_class = getattr(schedulers, scheduler)
scheduler = scheduler_class(**kwargs)
# Set timesteps for inference steps.
scheduler.set_timesteps(args.infer_steps, device)
# Only enable progress bar for rank 0
progress_bar_config = {} if rank == 0 else {'disable': True}
pipeline = StableDiffusionPipeline(vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=model,
scheduler=scheduler,
feature_extractor=None,
safety_checker=None,
requires_safety_checker=False,
progress_bar_config=progress_bar_config,
embedder_t5=embedder_t5,
infer_mode=infer_mode,
)
pipeline = pipeline.to(device)
return pipeline, sampler
class End2End(object):
def __init__(self, args, models_root_path):
self.args = args
# Check arguments
t2i_root_path = Path(models_root_path) / "t2i"
self.root = t2i_root_path
logger.info(f"Got text-to-image model root path: {t2i_root_path}")
# Set device and disable gradient
self.device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)
# Disable BertModel logging checkpoint info
tf_logger.setLevel('ERROR')
# ========================================================================
logger.info(f"Loading CLIP Text Encoder...")
text_encoder_path = self.root / "clip_text_encoder"
self.clip_text_encoder = BertModel.from_pretrained(str(text_encoder_path), False, revision=None).to(self.device)
logger.info(f"Loading CLIP Text Encoder finished")
# ========================================================================
logger.info(f"Loading CLIP Tokenizer...")
tokenizer_path = self.root / "tokenizer"
self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
logger.info(f"Loading CLIP Tokenizer finished")
# ========================================================================
logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...")
t5_text_encoder_path = self.root / 'mt5'
embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
self.embedder_t5 = embedder_t5
logger.info(f"Loading t5_text_encoder and t5_tokenizer finished")
# ========================================================================
logger.info(f"Loading VAE...")
vae_path = self.root / "sdxl-vae-fp16-fix"
self.vae = AutoencoderKL.from_pretrained(str(vae_path)).to(self.device)
logger.info(f"Loading VAE finished")
# ========================================================================
# Create model structure and load the checkpoint
logger.info(f"Building HunYuan-DiT model...")
model_config = HUNYUAN_DIT_CONFIG[self.args.model]
self.patch_size = model_config['patch_size']
self.head_size = model_config['hidden_size'] // model_config['num_heads']
self.resolutions, self.freqs_cis_img = self.standard_shapes() # Used for TensorRT models
self.image_size = _to_tuple(self.args.image_size)
latent_size = (self.image_size[0] // 8, self.image_size[1] // 8)
self.infer_mode = self.args.infer_mode
if self.infer_mode in ['fa', 'torch']:
# # for trained pt
# model_path = Path("/home1/qbs/my_program1/HunyuanDiT/log_EXP/024-dit_g2_full_1024p/checkpoints/0100000.pt/mp_rank_00_model_states.pt")
# if not model_path.exists():
# raise ValueError(f"model_path not exists: {model_path}")
# # Build model structure
# self.model = HunYuanDiT(self.args,
# input_size=latent_size,
# **model_config,
# log_fn=logger.info,
# ).half().to(self.device) # Force to use fp16
# # Load model checkpoint
# logger.info(f"Loading torch model {model_path}...")
# state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
# self.model.load_state_dict(state_dict["module"])
# for ema trained pt
model_path = Path("/home1/qbs/my_program1/HunyuanDiT/log_EXP/027-dit_g2_full_1024p/checkpoints/latest.pt/mp_rank_00_model_states.pt")
if not model_path.exists():
raise ValueError(f"model_path not exists: {model_path}")
# Build model structure
self.model = HunYuanDiT(self.args,
input_size=latent_size,
**model_config,
log_fn=logger.info,
).half().to(self.device) # Force to use fp16
# Load model checkpoint
logger.info(f"Loading torch model {model_path}...")
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
self.model.load_state_dict(state_dict["ema"])
# #original
# model_dir = self.root / "model"
# model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt"
# if not model_path.exists():
# raise ValueError(f"model_path not exists: {model_path}")
# # Build model structure
# self.model = HunYuanDiT(self.args,
# input_size=latent_size,
# **model_config,
# log_fn=logger.info,
# ).half().to(self.device) # Force to use fp16
# # Load model checkpoint
# logger.info(f"Loading torch model {model_path}...")
# state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
# self.model.load_state_dict(state_dict)
lora_ckpt = args.lora_ckpt
if lora_ckpt is not None and lora_ckpt != "":
logger.info(f"Loading Lora checkpoint {lora_ckpt}...")
self.model.load_adapter(lora_ckpt)
self.model.merge_and_unload()
self.model.eval()
logger.info(f"Loading torch model finished")
elif self.infer_mode == 'trt':
from .modules.trt.hcf_model import TRTModel
trt_dir = self.root / "model_trt"
engine_dir = trt_dir / "engine"
plugin_path = trt_dir / "fmha_plugins/9.2_plugin_cuda11/fMHAPlugin.so"
model_name = "model_onnx"
logger.info(f"Loading TensorRT model {engine_dir}/{model_name}...")
self.model = TRTModel(model_name=model_name,
engine_dir=str(engine_dir),
image_height=TRT_MAX_HEIGHT,
image_width=TRT_MAX_WIDTH,
text_maxlen=args.text_len,
embedding_dim=args.text_states_dim,
plugin_path=str(plugin_path),
max_batch_size=TRT_MAX_BATCH_SIZE,
)
logger.info(f"Loading TensorRT model finished")
else:
raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
# ========================================================================
# Build inference pipeline. We use a customized StableDiffusionPipeline.
logger.info(f"Loading inference pipeline...")
self.pipeline, self.sampler = self.load_sampler()
logger.info(f'Loading pipeline finished')
# ========================================================================
self.default_negative_prompt = NEGATIVE_PROMPT
logger.info("==================================================")
logger.info(f" Model is ready. ")
logger.info("==================================================")
def load_sampler(self, sampler=None):
pipeline, sampler = get_pipeline(self.args,
self.vae,
self.clip_text_encoder,
self.tokenizer,
self.model,
device=self.device,
rank=0,
embedder_t5=self.embedder_t5,
infer_mode=self.infer_mode,
sampler=sampler,
)
return pipeline, sampler
def calc_rope(self, height, width):
th = height // 8 // self.patch_size
tw = width // 8 // self.patch_size
base_size = 512 // 8 // self.patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
rope = get_2d_rotary_pos_embed(self.head_size, *sub_args)
return rope
def standard_shapes(self):
resolutions = ResolutionGroup()
freqs_cis_img = {}
for reso in resolutions.data:
freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width)
return resolutions, freqs_cis_img
def predict(self,
user_prompt,
height=1024,
width=1024,
seed=None,
enhanced_prompt=None,
negative_prompt=None,
infer_steps=100,
guidance_scale=6,
batch_size=1,
src_size_cond=(1024, 1024),
sampler=None,
):
# ========================================================================
# Arguments: seed
# ========================================================================
if seed is None:
seed = random.randint(0, 1_000_000)
if not isinstance(seed, int):
raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
generator = set_seeds(seed, device=self.device)
# ========================================================================
# Arguments: target_width, target_height
# ========================================================================
if width <= 0 or height <= 0:
raise ValueError(f"`height` and `width` must be positive integers, got height={height}, width={width}")
logger.info(f"Input (height, width) = ({height}, {width})")
if self.infer_mode in ['fa', 'torch']:
# We must force height and width to align to 16 and to be an integer.
target_height = int((height // 16) * 16)
target_width = int((width // 16) * 16)
logger.info(f"Align to 16: (height, width) = ({target_height}, {target_width})")
elif self.infer_mode == 'trt':
target_width, target_height = get_standard_shape(width, height)
logger.info(f"Align to standard shape: (height, width) = ({target_height}, {target_width})")
else:
raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
# ========================================================================
# Arguments: prompt, new_prompt, negative_prompt
# ========================================================================
if not isinstance(user_prompt, str):
raise TypeError(f"`user_prompt` must be a string, but got {type(user_prompt)}")
user_prompt = user_prompt.strip()
prompt = user_prompt
if enhanced_prompt is not None:
if not isinstance(enhanced_prompt, str):
raise TypeError(f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}")
enhanced_prompt = enhanced_prompt.strip()
prompt = enhanced_prompt
# negative prompt
if negative_prompt is None or negative_prompt == '':
negative_prompt = self.default_negative_prompt
if not isinstance(negative_prompt, str):
raise TypeError(f"`negative_prompt` must be a string, but got {type(negative_prompt)}")
# ========================================================================
# Arguments: style. (A fixed argument. Don't Change it.)
# ========================================================================
style = torch.as_tensor([0, 0] * batch_size, device=self.device)
# ========================================================================
# Inner arguments: image_meta_size (Please refer to SDXL.)
# ========================================================================
if isinstance(src_size_cond, int):
src_size_cond = [src_size_cond, src_size_cond]
if not isinstance(src_size_cond, (list, tuple)):
raise TypeError(f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}")
if len(src_size_cond) != 2:
raise ValueError(f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)}")
size_cond = list(src_size_cond) + [target_width, target_height, 0, 0]
image_meta_size = torch.as_tensor([size_cond] * 2 * batch_size, device=self.device)
# ========================================================================
start_time = time.time()
logger.debug(f"""
prompt: {user_prompt}
enhanced prompt: {enhanced_prompt}
seed: {seed}
(height, width): {(target_height, target_width)}
negative_prompt: {negative_prompt}
batch_size: {batch_size}
guidance_scale: {guidance_scale}
infer_steps: {infer_steps}
image_meta_size: {size_cond}
""")
reso = f'{target_height}x{target_width}'
if reso in self.freqs_cis_img:
freqs_cis_img = self.freqs_cis_img[reso]
else:
freqs_cis_img = self.calc_rope(target_height, target_width)
if sampler is not None and sampler != self.sampler:
self.pipeline, self.sampler = self.load_sampler(sampler)
samples = self.pipeline(
height=target_height,
width=target_width,
prompt=prompt,
negative_prompt=negative_prompt,
num_images_per_prompt=batch_size,
guidance_scale=guidance_scale,
num_inference_steps=infer_steps,
image_meta_size=image_meta_size,
style=style,
return_dict=False,
generator=generator,
freqs_cis_img=freqs_cis_img,
use_fp16=self.args.use_fp16,
learn_sigma=self.args.learn_sigma,
)[0]
gen_time = time.time() - start_time
logger.debug(f"Success, time: {gen_time}")
return {
'images': samples,
'seed': seed,
}