import argparse
import warnings
from pathlib import Path

import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler, StableDiffusionControlNetImg2ImgPipeline
from torch import Tensor
from torchvision.io.video import read_video, write_video
from torchvision.models.optical_flow import Raft_Large_Weights, raft_large
from torchvision.transforms.functional import resize
from torchvision.utils import flow_to_image
from tqdm import trange

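# Preprocessing transform that converts frame pairs to the input format the RAFT optical flow model expects.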
raft_transform = Raft_Large_Weights.DEFAULT.transforms()


@torch.inference_mode()
def stylize_video(
    input_video: Tensor,
    prompt: str,
    strength: float = 0.7,
    num_steps: int = 20,
    guidance_scale: float = 7.5,
    controlnet_scale: float = 1.0,
    batch_size: int = 4,
    height: int = 512,
    width: int = 512,
    device: str = "cuda",
) -> Tensor:
    """
    Stylize a video with temporal coherence (less flickering!) using HuggingFace's Stable Diffusion ControlNet pipeline.

    Args:
        input_video (Tensor): Input video tensor of shape (T, C, H, W) in range [0, 1].
        prompt (str): Text prompt to condition the diffusion process.
        strength (float, optional): How heavily the stylization affects each frame.
        num_steps (int, optional): Number of diffusion steps (trade-off between quality and speed).
        guidance_scale (float, optional): Scale of classifier-free guidance (how closely to adhere to the text prompt).
        controlnet_scale (float, optional): Scale of the ControlNet conditioning (strength of temporal coherence).
        batch_size (int, optional): Number of frames to diffuse at once (faster but more memory intensive).
        height (int, optional): Height of the output video.
        width (int, optional): Width of the output video.
        device (str, optional): Device to run the stylization process on.

    Returns:
        Tensor: Output video tensor of shape (T, H, W, C) in range [0, 1].
    """

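    # Load the diffusion pipeline and flow model while silencing noisy library warnings.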
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

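        # Img2img ControlNet pipeline; TemporalNet2 conditions each frame on the previous frame and its optical flow.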
        pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=ControlNetModel.from_pretrained("wav/TemporalNet2", torch_dtype=torch.float16),
            safety_checker=None,
            torch_dtype=torch.float16,
        ).to(device)
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_xformers_memory_efficient_attention()
        pipe.set_progress_bar_config(disable=True)

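        # RAFT estimates dense optical flow between consecutive frames.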
        raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=True).eval().to(device)

    output_video = []
    for i in trange(1, len(input_video), batch_size, desc="Diffusing...", unit="frame", unit_scale=batch_size):
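        # Previous and current frame batches, offset by one frame; flow is estimated from prev to curr.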
        prev = resize(input_video[i - 1 : i - 1 + batch_size], (height, width), antialias=True).to(device)
        curr = resize(input_video[i : i + batch_size], (height, width), antialias=True).to(device)
        prev = prev[: curr.shape[0]]  # the final batch may be shorter than batch_size

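        # Control image for TemporalNet2: the previous frame concatenated channel-wise with an RGB flow visualization.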
        flow_img = flow_to_image(raft(*raft_transform(prev, curr))[-1]).div(255)
        control_img = torch.cat((prev, flow_img), dim=1)

        output, _ = pipe(
            prompt=[prompt] * curr.shape[0],
            image=curr,
            control_image=control_img,
            height=height,
            width=width,
            strength=strength,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_scale,
            output_type="pt",
            return_dict=False,
        )

        output_video.append(output.permute(0, 2, 3, 1).cpu())  # (B, C, H, W) -> (B, H, W, C)

    return torch.cat(output_video)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=stylize_video.__doc__)
    parser.add_argument("-i", "--in-file", type=str, required=True)
    parser.add_argument("-p", "--prompt", type=str, required=True)
    parser.add_argument("-o", "--out-file", type=str, default=None)
    parser.add_argument("-s", "--strength", type=float, default=0.7)
    parser.add_argument("-S", "--num-steps", type=int, default=20)
    parser.add_argument("-g", "--guidance-scale", type=float, default=7.5)
    parser.add_argument("-c", "--controlnet-scale", type=float, default=1.0)
    parser.add_argument("-b", "--batch-size", type=int, default=4)
    parser.add_argument("-H", "--height", type=int, default=512)
    parser.add_argument("-W", "--width", type=int, default=512)
    parser.add_argument("-d", "--device", type=str, default="cuda")
    args = parser.parse_args()

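    # Read the video as a (T, C, H, W) uint8 tensor and scale it to [0, 1].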
    input_video, _, info = read_video(args.in_file, pts_unit="sec", output_format="TCHW")
    input_video = input_video.div(255)

    output_video = stylize_video(
        input_video=input_video,
        prompt=args.prompt,
        strength=args.strength,
        num_steps=args.num_steps,
        guidance_scale=args.guidance_scale,
        controlnet_scale=args.controlnet_scale,
        height=args.height,
        width=args.width,
        device=args.device,
        batch_size=args.batch_size,
    )

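    # Encode the stylized frames back to an mp4 at the source frame rate (write_video expects uint8 THWC frames).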
    out_file = f"{Path(args.in_file).stem} | {args.prompt}.mp4" if args.out_file is None else args.out_file
    write_video(out_file, output_video.mul(255).to(torch.uint8), fps=info["video_fps"])