"""Gradio demo for DiffIR2VR: zero-shot video super-resolution and denoising.

The module builds a two-tab UI (super-resolution / denoise). Both tabs call
`DiffBIR_restore`, which maps the selected example video to a directory of
pre-extracted frames and runs the matching inference loop from
`utils.batch_inference`.
"""
import argparse
import os
from itertools import islice

import cv2
import torch
import spaces
import imageio
import numpy as np
import gradio as gr

# Replace torch.jit.script with a pass-through that returns its argument
# unchanged. Kept ahead of the project import below, presumably so that code
# imported from `utils` is not TorchScript-compiled -- TODO confirm.
torch.jit.script = lambda f: f

from utils.batch_inference import (
    BSRInferenceLoop, BIDInferenceLoop
)

# import subprocess
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Base names of the bundled example clips ('dog-gooses' is intentionally
# left out of the example gallery, as in the original list).
_EXAMPLE_CLIPS = (
    'bus',
    'koala',
    'flamingo',
    'rhino',
    'elephant',
    'sheep',
    'dog-agility',
)

# File-name suffix used by each task's example clips.
_EXAMPLE_SUFFIX = {
    "dn": "",
    "sr": "_sr",
}


def get_example(task):
    """Return the example rows for ``gr.Examples`` for the given task.

    Args:
        task: ``"dn"`` (denoise) or ``"sr"`` (super-resolution).

    Returns:
        A list of single-element lists, each holding one example video path
        such as ``['examples/bus.mp4']`` or ``['examples/bus_sr.mp4']``.

    Raises:
        KeyError: if ``task`` is not one of the known task names.
    """
    suffix = _EXAMPLE_SUFFIX[task]
    return [[f'examples/{clip}{suffix}.mp4'] for clip in _EXAMPLE_CLIPS]


def update_prompt(input_video):
    """Return the default prompt for the given video path.

    NOTE(review): ``set_default_prompt`` is neither defined nor imported in
    this module as visible here, so calling this function would raise
    NameError. Its only wiring (``input_video.change``) is currently
    disabled in the UI code.
    """
    video_name = os.path.basename(input_video)
    return set_default_prompt(video_name)


# Map videos to corresponding images
# (each value is a one-element list holding the frames directory).
video_to_image = {
    'bus.mp4': ['examples_frames/bus'],
    'koala.mp4': ['examples_frames/koala'],
    'dog-gooses.mp4': ['examples_frames/dog-gooses'],
    'flamingo.mp4': ['examples_frames/flamingo'],
    'rhino.mp4': ['examples_frames/rhino'],
    'elephant.mp4': ['examples_frames/elephant'],
    'sheep.mp4': ['examples_frames/sheep'],
    'dog-agility.mp4': ['examples_frames/dog-agility'],
    'bus_sr.mp4': ['examples_frames/bus_sr'],
    'koala_sr.mp4': ['examples_frames/koala_sr'],
    # NOTE(review): this directory uses an underscore ('dog_gooses_sr')
    # unlike every other entry -- confirm it matches the folder on disk.
    'dog-gooses_sr.mp4': ['examples_frames/dog_gooses_sr'],
    'flamingo_sr.mp4': ['examples_frames/flamingo_sr'],
    'rhino_sr.mp4': ['examples_frames/rhino_sr'],
    'elephant_sr.mp4': ['examples_frames/elephant_sr'],
    'sheep_sr.mp4': ['examples_frames/sheep_sr'],
    'dog-agility_sr.mp4': ['examples_frames/dog-agility_sr'],
}


def images_to_video(image_list, output_path, fps=10, max_frames=20):
    """Encode the leading frames of ``image_list`` into an H.264 video file.

    Args:
        image_list: iterable of images; each item is passed to ``np.array``
            and cast to ``uint8`` (the original comment says PIL images).
        output_path: destination path handed to ``imageio.get_writer``.
        fps: frame rate of the written video.
        max_frames: maximum number of leading frames to write. Defaults to
            20, the limit that used to be hard-coded; ``None`` writes all.
    """
    # Only convert the frames that will actually be written.
    frames = [
        np.array(img).astype(np.uint8)
        for img in islice(image_list, max_frames)
    ]

    # Create video writer
    writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
    try:
        for frame in frames:
            writer.append_data(frame)
    finally:
        # Close the writer even if appending a frame fails.
        writer.close()


# Inference loop class used for each supported task name.
_INFERENCE_LOOPS = {
    "sr": BSRInferenceLoop,
    "dn": BIDInferenceLoop,
}


@spaces.GPU(duration=120)
def DiffBIR_restore(input_video, prompt, sr_ratio, n_frames, n_steps, guidance_scale, seed, n_prompt, task):
    """Run the restoration loop for one of the bundled example videos.

    The video's file name is looked up in ``video_to_image`` to find its
    frames directory; the inference loop selected by ``task`` is then run on
    an ``argparse.Namespace`` filled from the UI values.

    Args:
        input_video: path of the selected video, or ``None`` if none is set.
        prompt: positive prompt (stored as ``args.pos_prompt``).
        sr_ratio: upscale factor (stored as ``args.upscale``).
        n_frames: stored as ``args.n_frames``.
        n_steps: stored as ``args.steps``.
        guidance_scale: stored as ``args.cfg_scale``.
        seed: stored as ``args.seed``.
        n_prompt: negative prompt (stored as ``args.neg_prompt``).
        task: ``"sr"`` or ``"dn"``.

    Returns:
        Whatever the inference loop's ``run()`` returns (the UI shows it as
        the restored video), or ``None`` when no video is given or the video
        is not one of the known examples.

    Raises:
        ValueError: if ``task`` is not a supported task name.
    """
    # Pressing the button with no video selected passes None here.
    if not input_video:
        return None

    video_name = os.path.basename(input_video)
    frames_entry = video_to_image.get(video_name)
    if frames_entry is None:
        # Only the bundled examples have pre-extracted frames.
        return None
    frames_path = frames_entry[0]

    # Fail early with a clear message instead of an unbound-variable error
    # after all the setup below.
    inference_loop = _INFERENCE_LOOPS.get(task)
    if inference_loop is None:
        raise ValueError(
            f"Unsupported task {task!r}; expected one of {sorted(_INFERENCE_LOOPS)}"
        )

    print(f"[INFO] input_video: {input_video}")
    print(f"[INFO] Frames path: {frames_path}")

    args = argparse.Namespace()

    # args.task = True, choices=["sr", "dn", "fr", "fr_bg"]
    args.task = task
    args.upscale = sr_ratio

    ### sampling parameters
    args.steps = n_steps
    args.better_start = True
    args.tiled = False
    args.tile_size = 512
    args.tile_stride = 256
    args.pos_prompt = prompt
    args.neg_prompt = n_prompt
    args.cfg_scale = guidance_scale

    ### input parameters
    args.input = frames_path
    args.n_samples = 1
    args.batch_size = 10
    args.final_size = (480, 854)
    args.config = "configs/inference/my_cldm.yaml"

    ### guidance parameters
    args.guidance = False
    args.g_loss = "w_mse"
    args.g_scale = 0.0
    args.g_start = 1001
    args.g_stop = -1
    args.g_space = "latent"
    args.g_repeat = 1

    ### output parameters
    args.output = " "

    ### common parameters
    args.seed = seed
    # NOTE(review): hard-coded; the module-level `device` is not used here.
    args.device = "cuda"
    args.n_frames = n_frames

    ### latent control parameters
    args.warp_period = [0, 0.1]
    args.merge_period = [0, 0]
    args.ToMe_period = [0, 1]
    args.merge_ratio = [0.6, 0]

    try:
        restored_vid_path = inference_loop(args).run()
    finally:
        # Release cached GPU memory whether or not inference succeeded.
        torch.cuda.empty_cache()

    return restored_vid_path


########
# demo #
########


intro = """

DiffIR2VR - Zero-Shot Video Restoration

[Project page] [arXiv]
Note that this page is a limited demo of DiffIR2VR. For more configurations, please visit our GitHub page. The code will be released soon!
"""


def _build_restore_tab(tab_label, task_name):
    """Build one restoration tab and wire its button to ``DiffBIR_restore``.

    Must be called inside an open ``gr.Blocks`` context. The "sr" tab shows
    an SR-ratio slider; any other task gets a hidden ratio fixed at 1.
    """
    is_sr = task_name == "sr"

    with gr.Tab(label=tab_label):
        with gr.Row():
            input_video = gr.Video(label="Input Video")
            output_video = gr.Video(label="Restored Video", interactive=False)

        with gr.Row():
            run_button = gr.Button("Restore your video !", visible=True)

        with gr.Accordion('Advanced options', open=False):
            prompt = gr.Textbox(
                label="Prompt",
                max_lines=1,
                placeholder="describe your video content"
                # value="bear, Van Gogh Style"
            )
            if is_sr:
                sr_ratio = gr.Slider(label='SR ratio', minimum=1, maximum=16, value=4, step=1)
            n_frames = gr.Slider(label='Frames', minimum=1, maximum=60, value=10, step=1)
            n_steps = gr.Slider(label='Steps', minimum=1, maximum=100, value=10, step=1)
            guidance_scale = gr.Slider(label='Guidance Scale', minimum=0.1, maximum=30.0, value=4.0, step=0.1)
            seed = gr.Slider(label='Seed', minimum=-1, maximum=1000, step=1, randomize=True)
            n_prompt = gr.Textbox(
                label='Negative Prompt',
                value="low quality, blurry, low-resolution, noisy, unsharp, weird textures"
            )
            # Hidden field that carries the task name into the callback.
            task = gr.Textbox(value=task_name, visible=False)
            if not is_sr:
                # No upscaling outside the SR tab: pass a fixed ratio of 1.
                sr_ratio = gr.Number(value=1, visible=False)

        # Prompt auto-fill on upload is currently disabled:
        # input_video.change(
        #     fn = update_prompt,
        #     inputs = [input_video],
        #     outputs = [prompt],
        #     queue = False)

        run_button.click(
            fn=DiffBIR_restore,
            inputs=[
                input_video,
                prompt,
                sr_ratio,
                n_frames,
                n_steps,
                guidance_scale,
                seed,
                n_prompt,
                task,
            ],
            outputs=[output_video],
        )

        gr.Examples(
            examples=get_example(task_name),
            label='Examples',
            inputs=[input_video],
            outputs=[output_video],
            examples_per_page=7,
        )


with gr.Blocks(css="style.css") as demo:

    gr.HTML(intro)

    _build_restore_tab("Super-resolution with DiffBIR", "sr")
    _build_restore_tab("Denoise with DiffBIR", "dn")

demo.queue()

demo.launch()