|
from enum import Enum |
|
import gc |
|
import numpy as np |
|
|
|
import torch |
|
import decord |
|
from diffusers import StableDiffusionInstructPix2PixPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UNet2DConditionModel |
|
from diffusers.schedulers import EulerAncestralDiscreteScheduler, DDIMScheduler |
|
from text_to_video.text_to_video_pipeline import TextToVideoPipeline |
|
|
|
import utils |
|
import gradio_utils |
|
|
|
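# Have decord return decoded video frames as torch tensors.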
decord.bridge.set_bridge('torch') |
|
|
|
|
|
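# Supported applications; each maps to a diffusers pipeline class in Model.pipe_dict below.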
class ModelType(Enum): |
|
    Pix2Pix_Video = 1

    Text2Video = 2

    ControlNetCanny = 3

    ControlNetCannyDB = 4

    ControlNetPose = 5
|
|
|
|
|
class Model: |
|
def __init__(self, device, dtype, **kwargs): |
|
self.device = device |
|
self.dtype = dtype |
|
self.generator = torch.Generator(device=device) |
|
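        # Map each application to the pipeline class used to run it.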
self.pipe_dict = { |
|
ModelType.Pix2Pix_Video: StableDiffusionInstructPix2PixPipeline, |
|
ModelType.Text2Video: TextToVideoPipeline, |
|
ModelType.ControlNetCanny: StableDiffusionControlNetPipeline, |
|
ModelType.ControlNetCannyDB: StableDiffusionControlNetPipeline, |
|
ModelType.ControlNetPose: StableDiffusionControlNetPipeline, |
|
} |
|
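        # Cross-frame attention processors; unet_chunk_size is the number of duplicated
        # batches per frame (2 for classifier-free guidance, 3 for InstructPix2Pix's
        # combined text and image guidance).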
self.controlnet_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=2) |
|
self.pix2pix_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=3) |
|
self.text2video_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=2) |
|
|
|
self.pipe = None |
|
self.model_type = None |
|
|
|
self.states = {} |
|
|
|
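    # Release the current pipeline (if any) and load the requested one from the Hub.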
def set_model(self, model_type: ModelType, model_id: str, **kwargs): |
|
if self.pipe is not None: |
|
del self.pipe |
|
torch.cuda.empty_cache() |
|
gc.collect() |
|
safety_checker = kwargs.pop('safety_checker', None) |
|
self.pipe = self.pipe_dict[model_type].from_pretrained(model_id, safety_checker=safety_checker, **kwargs).to(self.device).to(self.dtype) |
|
self.model_type = model_type |
|
|
|
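    # Run the loaded pipeline on the subset of frames selected by frame_ids.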
def inference_chunk(self, frame_ids, **kwargs): |
|
if self.pipe is None: |
|
return |
|
image = kwargs.pop('image') |
|
prompt = np.array(kwargs.pop('prompt')) |
|
negative_prompt = np.array(kwargs.pop('negative_prompt', '')) |
|
latents = None |
|
if 'latents' in kwargs: |
|
latents = kwargs.pop('latents')[frame_ids] |
|
return self.pipe(image=image[frame_ids], |
|
prompt=prompt[frame_ids].tolist(), |
|
negative_prompt=negative_prompt[frame_ids].tolist(), |
|
latents=latents, |
|
generator=self.generator, |
|
**kwargs) |
|
|
|
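    # Run the pipeline in one pass or in chunks of frames. In chunked mode the first frame
    # is prepended to every chunk as the cross-frame attention anchor, the generator is
    # re-seeded per chunk, and the anchor frame is dropped from each chunk's output
    # before concatenation.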
def inference(self, split_to_chunks=False, chunk_size=8, **kwargs): |
|
if self.pipe is None: |
|
return |
|
seed = kwargs.pop('seed', 0) |
|
        kwargs.pop('generator', '')  # discard any caller-supplied generator; self.generator is used instead
|
|
|
if split_to_chunks: |
|
assert 'image' in kwargs |
|
assert 'prompt' in kwargs |
|
image = kwargs.pop('image') |
|
prompt = kwargs.pop('prompt') |
|
negative_prompt = kwargs.pop('negative_prompt', '') |
|
f = image.shape[0] |
|
chunk_ids = np.arange(0, f, chunk_size - 1) |
|
result = [] |
|
for i in range(len(chunk_ids)): |
|
ch_start = chunk_ids[i] |
|
ch_end = f if i == len(chunk_ids) - 1 else chunk_ids[i + 1] |
|
frame_ids = [0] + list(range(ch_start, ch_end)) |
|
self.generator.manual_seed(seed) |
|
print(f'Processing chunk {i + 1} / {len(chunk_ids)}') |
|
result.append(self.inference_chunk(frame_ids=frame_ids, |
|
image=image, |
|
prompt=[prompt] * f, |
|
negative_prompt=[negative_prompt] * f, |
|
**kwargs).images[1:]) |
|
result = np.concatenate(result) |
|
return result |
|
else: |
|
return self.pipe(generator=self.generator, **kwargs).videos[0] |
|
|
|
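    # ControlNet (canny edges) on Stable Diffusion v1.5: edge maps extracted from the
    # input video condition the generation, with one latent repeated across all frames.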
def process_controlnet_canny(self, |
|
video_path, |
|
prompt, |
|
num_inference_steps=20, |
|
controlnet_conditioning_scale=1.0, |
|
guidance_scale=9.0, |
|
seed=42, |
|
eta=0.0, |
|
low_threshold=100, |
|
high_threshold=200, |
|
resolution=512): |
|
video_path = gradio_utils.edge_path_to_video_path(video_path) |
|
if self.model_type != ModelType.ControlNetCanny: |
|
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") |
|
self.set_model(ModelType.ControlNetCanny, model_id="runwayml/stable-diffusion-v1-5", controlnet=controlnet) |
|
self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config) |
|
self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc) |
|
self.pipe.controlnet.set_attn_processor(processor=self.controlnet_attn_proc) |
|
|
|
|
|
added_prompt = 'best quality, extremely detailed' |
|
negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' |
|
|
|
video, fps = utils.prepare_video(video_path, resolution, self.device, self.dtype, False, start_t=0, end_t=15) |
|
control = utils.pre_process_canny(video, low_threshold, high_threshold).to(self.device).to(self.dtype) |
|
f, _, h, w = video.shape |
|
self.generator.manual_seed(seed) |
|
latents = torch.randn((1, 4, h//8, w//8), dtype=self.dtype, device=self.device, generator=self.generator) |
|
latents = latents.repeat(f, 1, 1, 1) |
|
result = self.inference(image=control, |
|
prompt=prompt + ', ' + added_prompt, |
|
height=h, |
|
width=w, |
|
negative_prompt=negative_prompts, |
|
num_inference_steps=num_inference_steps, |
|
guidance_scale=guidance_scale, |
|
controlnet_conditioning_scale=controlnet_conditioning_scale, |
|
eta=eta, |
|
latents=latents, |
|
seed=seed, |
|
output_type='numpy', |
|
split_to_chunks=True, |
|
chunk_size=8, |
|
) |
|
return utils.create_video(result, fps) |
|
|
|
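    # ControlNet (OpenPose) on Stable Diffusion v1.5: pose maps from the input video
    # condition the generation; the result is returned as a GIF.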
def process_controlnet_pose(self, |
|
video_path, |
|
prompt, |
|
num_inference_steps=20, |
|
controlnet_conditioning_scale=1.0, |
|
guidance_scale=9.0, |
|
seed=42, |
|
eta=0.0, |
|
resolution=512): |
|
video_path = gradio_utils.motion_to_video_path(video_path) |
|
if self.model_type != ModelType.ControlNetPose: |
|
controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-openpose") |
|
self.set_model(ModelType.ControlNetPose, model_id="runwayml/stable-diffusion-v1-5", controlnet=controlnet) |
|
self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config) |
|
self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc) |
|
self.pipe.controlnet.set_attn_processor(processor=self.controlnet_attn_proc) |
|
|
|
added_prompt = 'best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth' |
|
        negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic'
|
|
|
video, fps = utils.prepare_video(video_path, resolution, self.device, self.dtype, False, output_fps=4) |
|
control = utils.pre_process_pose(video, apply_pose_detect=False).to(self.device).to(self.dtype) |
|
f, _, h, w = video.shape |
|
self.generator.manual_seed(seed) |
|
latents = torch.randn((1, 4, h//8, w//8), dtype=self.dtype, device=self.device, generator=self.generator) |
|
latents = latents.repeat(f, 1, 1, 1) |
|
result = self.inference(image=control, |
|
prompt=prompt + ', ' + added_prompt, |
|
height=h, |
|
width=w, |
|
negative_prompt=negative_prompts, |
|
num_inference_steps=num_inference_steps, |
|
guidance_scale=guidance_scale, |
|
controlnet_conditioning_scale=controlnet_conditioning_scale, |
|
eta=eta, |
|
latents=latents, |
|
seed=seed, |
|
output_type='numpy', |
|
split_to_chunks=True, |
|
chunk_size=8, |
|
) |
|
return utils.create_gif(result, fps) |
|
|
|
|
|
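    # Same as the canny pipeline above, but with base weights from a user-selected
    # DreamBooth model; self.states['db_path'] tracks which weights are currently loaded.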
def process_controlnet_canny_db(self, |
|
db_path, |
|
video_path, |
|
prompt, |
|
num_inference_steps=20, |
|
controlnet_conditioning_scale=1.0, |
|
guidance_scale=9.0, |
|
seed=42, |
|
eta=0.0, |
|
low_threshold=100, |
|
high_threshold=200, |
|
resolution=512): |
|
db_path = gradio_utils.get_model_from_db_selection(db_path) |
|
video_path = gradio_utils.get_video_from_canny_selection(video_path) |
|
|
|
        if self.model_type != ModelType.ControlNetCannyDB or db_path != self.states.get('db_path'):
|
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") |
|
self.set_model(ModelType.ControlNetCannyDB, model_id=db_path, controlnet=controlnet) |
|
self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config) |
|
self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc) |
|
self.pipe.controlnet.set_attn_processor(processor=self.controlnet_attn_proc) |
|
self.states['db_path'] = db_path |
|
|
|
added_prompt = 'best quality, extremely detailed' |
|
negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' |
|
|
|
video, fps = utils.prepare_video(video_path, resolution, self.device, self.dtype, False) |
|
control = utils.pre_process_canny(video, low_threshold, high_threshold).to(self.device).to(self.dtype) |
|
f, _, h, w = video.shape |
|
self.generator.manual_seed(seed) |
|
latents = torch.randn((1, 4, h//8, w//8), dtype=self.dtype, device=self.device, generator=self.generator) |
|
latents = latents.repeat(f, 1, 1, 1) |
|
result = self.inference(image=control, |
|
prompt=prompt + ', ' + added_prompt, |
|
height=h, |
|
width=w, |
|
negative_prompt=negative_prompts, |
|
num_inference_steps=num_inference_steps, |
|
guidance_scale=guidance_scale, |
|
controlnet_conditioning_scale=controlnet_conditioning_scale, |
|
eta=eta, |
|
latents=latents, |
|
seed=seed, |
|
output_type='numpy', |
|
split_to_chunks=True, |
|
chunk_size=8, |
|
) |
|
return utils.create_gif(result, fps) |
|
|
|
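    # Video Instruct-Pix2Pix: edit each frame with timbrooks/instruct-pix2pix while
    # cross-frame attention keeps the edit consistent across frames.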
def process_pix2pix(self, video, prompt, resolution=512, seed=0, start_t=0, end_t=-1, out_fps=-1): |
|
        end_t = start_t + 15  # process at most 15 seconds of video, overriding the end_t argument
|
if self.model_type != ModelType.Pix2Pix_Video: |
|
self.set_model(ModelType.Pix2Pix_Video, model_id="timbrooks/instruct-pix2pix") |
|
self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config) |
|
self.pipe.unet.set_attn_processor(processor=self.pix2pix_attn_proc) |
|
video, fps = utils.prepare_video(video, resolution, self.device, self.dtype, True, start_t, end_t, out_fps) |
|
self.generator.manual_seed(seed) |
|
result = self.inference(image=video, |
|
prompt=prompt, |
|
seed=seed, |
|
output_type='numpy', |
|
num_inference_steps=50, |
|
image_guidance_scale=1.5, |
|
split_to_chunks=True, |
|
chunk_size=8, |
|
) |
|
return utils.create_video(result, fps) |
|
|
|
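    # Zero-shot text-to-video: Stable Diffusion v1.5 with cross-frame attention and
    # motion-field warping of the latents to synthesize a short video from text alone.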
    def process_text2video(self, prompt, motion_field_strength_x=12, motion_field_strength_y=12, n_prompt="",
                           resolution=512, seed=24, num_frames=8, fps=2, t0=881, t1=941,
                           use_cf_attn=True, use_motion_field=True,
                           smooth_bg=False, smooth_bg_strength=0.4):
|
|
|
if self.model_type != ModelType.Text2Video: |
|
unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="unet") |
|
self.set_model(ModelType.Text2Video, model_id="runwayml/stable-diffusion-v1-5", unet=unet) |
|
self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config) |
|
if use_cf_attn: |
|
self.pipe.unet.set_attn_processor(processor=self.text2video_attn_proc) |
|
self.generator.manual_seed(seed) |
|
|
|
|
|
added_prompt = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting" |
|
        negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic'  # note: not passed to the pipeline; n_prompt below is used instead
|
|
|
        prompt = prompt.rstrip()

        if len(prompt) > 0 and prompt[-1] in (",", "."):

            prompt = prompt[:-1].rstrip()

        prompt = prompt + ", " + added_prompt
|
        if len(n_prompt) > 0:
|
negative_prompt = [n_prompt] |
|
else: |
|
negative_prompt = None |
|
|
|
result = self.inference(prompt=[prompt], |
|
video_length=num_frames, |
|
height=resolution, |
|
width=resolution, |
|
num_inference_steps=50, |
|
guidance_scale=7.5, |
|
guidance_stop_step=1.0, |
|
t0=t0, |
|
t1=t1, |
|
motion_field_strength_x=motion_field_strength_x, |
|
motion_field_strength_y=motion_field_strength_y, |
|
use_motion_field=use_motion_field, |
|
smooth_bg=smooth_bg, |
|
smooth_bg_strength=smooth_bg_strength, |
|
seed=seed, |
|
output_type='numpy', |
|
                                negative_prompt=negative_prompt,
|
) |
|
return utils.create_video(result, fps) |