import argparse
import warnings
from pathlib import Path

import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler, StableDiffusionControlNetImg2ImgPipeline
from torch import Tensor
from torchvision.io.video import read_video, write_video
from torchvision.models.optical_flow import Raft_Large_Weights, raft_large
from torchvision.transforms.functional import resize
from torchvision.utils import flow_to_image
from tqdm import trange

# Preprocessing transform expected by the RAFT optical flow model.
raft_transform = Raft_Large_Weights.DEFAULT.transforms()


@torch.inference_mode()
def stylize_video(
    input_video: Tensor,
    prompt: str,
    strength: float = 0.7,
    num_steps: int = 20,
    guidance_scale: float = 7.5,
    controlnet_scale: float = 1.0,
    batch_size: int = 4,
    height: int = 512,
    width: int = 512,
    device: str = "cuda",
) -> Tensor:
    """
    Stylize a video with temporal coherence (less flickering!) using HuggingFace's Stable Diffusion ControlNet
    pipeline.

    Args:
        input_video (Tensor): Input video tensor of shape (T, C, H, W) in the range [0, 1].
        prompt (str): Text prompt to condition the diffusion process.
        strength (float, optional): How heavily stylization affects the image.
        num_steps (int, optional): Number of diffusion steps (trade-off between quality and speed).
        guidance_scale (float, optional): Scale of classifier-free guidance (how closely to adhere to the text prompt).
        controlnet_scale (float, optional): Scale of the ControlNet conditioning (strength of temporal coherence).
        batch_size (int, optional): Number of frames to diffuse at once (faster but more memory intensive).
        height (int, optional): Height of the output video.
        width (int, optional): Width of the output video.
        device (str, optional): Device to run the stylization process on.

    Returns:
        Tensor: Output video tensor of shape (T, H, W, C) in the range [0, 1].
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # silence annoying TypedStorage warnings

        pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=ControlNetModel.from_pretrained("wav/TemporalNet2", torch_dtype=torch.float16),
            safety_checker=None,
            torch_dtype=torch.float16,
        ).to(device)
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_xformers_memory_efficient_attention()
        pipe.set_progress_bar_config(disable=True)

        raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=True).eval().to(device)

    output_video = []
    # Start at frame 1: each batch is conditioned on the preceding frames via optical flow.
    for i in trange(1, len(input_video), batch_size, desc="Diffusing...", unit="frame", unit_scale=batch_size):
        prev = resize(input_video[i - 1 : i - 1 + batch_size], (height, width), antialias=True).to(device)
        curr = resize(input_video[i : i + batch_size], (height, width), antialias=True).to(device)
        prev = prev[: curr.shape[0]]  # make sure prev and curr have the same batch size (for the last batch)

        # Estimate optical flow from the previous to the current frames and render it as an RGB image in [0, 1].
        flow_img = flow_to_image(raft.forward(*raft_transform(prev, curr))[-1]).div(255)

        # TemporalNet2 conditions on the previous frame and its flow image, concatenated into a 6-channel input.
        control_img = torch.cat((prev, flow_img), dim=1)

        output, _ = pipe(
            prompt=[prompt] * curr.shape[0],
            image=curr,
            control_image=control_img,
            height=height,
            width=width,
            strength=strength,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_scale,
            output_type="pt",
            return_dict=False,
        )

        output_video.append(output.permute(0, 2, 3, 1).cpu())  # (B, C, H, W) -> (B, H, W, C) for write_video

    return torch.cat(output_video)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(usage=stylize_video.__doc__)
    parser.add_argument("-i", "--in-file", type=str, required=True)
    parser.add_argument("-p", "--prompt", type=str, required=True)
    parser.add_argument("-o", "--out-file", type=str, default=None)
    parser.add_argument("-s", "--strength", type=float, default=0.7)
    parser.add_argument("-S", "--num-steps", type=int, default=20)
    parser.add_argument("-g", "--guidance-scale", type=float, default=7.5)
    parser.add_argument("-c", "--controlnet-scale", type=float, default=1.0)
    parser.add_argument("-b", "--batch-size", type=int, default=4)
    parser.add_argument("-H", "--height", type=int, default=512)
    parser.add_argument("-W", "--width", type=int, default=512)
    parser.add_argument("-d", "--device", type=str, default="cuda")
    args = parser.parse_args()

    input_video, _, info = read_video(args.in_file, pts_unit="sec", output_format="TCHW")
    input_video = input_video.div(255)

    output_video = stylize_video(
        input_video=input_video,
        prompt=args.prompt,
        strength=args.strength,
        num_steps=args.num_steps,
        guidance_scale=args.guidance_scale,
        controlnet_scale=args.controlnet_scale,
        height=args.height,
        width=args.width,
        device=args.device,
        batch_size=args.batch_size,
    )

    out_file = f"{Path(args.in_file).stem} | {args.prompt}.mp4" if args.out_file is None else args.out_file
    write_video(out_file, output_video.mul(255), fps=info["video_fps"])