TileUpscalerV2 / app.py
gokaygokay's picture
Update app.py
749f0a2 verified
raw
history blame
19 kB
import spaces
import os
import requests
import time
import io
import torch
from PIL import Image
import cv2
import numpy as np
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, DDIMScheduler
from diffusers.models import AutoencoderKL
from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from RealESRGAN import RealESRGAN
import gradio as gr
import subprocess
from tqdm import tqdm
import shutil
import uuid
import json
import threading
# Constants
USE_TORCH_COMPILE = False
# Opt in with ENABLE_CPU_OFFLOAD=1 in the environment; unset or any other value means off.
ENABLE_CPU_OFFLOAD = os.environ.get("ENABLE_CPU_OFFLOAD", "0") == "1"

# The pipeline is CUDA-only, so stop at import time instead of failing later.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. This script requires a CUDA-capable GPU.")
device = torch.device("cuda")
print(f"Using device: {device}")

# Shared cancellation flag (replaces the old global abort_status):
# abort_job() sets it, process_video() checks it before each batch, and each new job clears it.
abort_event = threading.Event()
# Custom CSS passed to gr.Blocks: let the app container span the full page width
# with 20px padding, and stop the element with id "component-0" from clipping its
# content (auto height, visible overflow).
css = """
.gradio-container {
max-width: 100% !important;
padding: 20px !important;
}
#component-0 {
height: auto !important;
overflow: visible !important;
}
"""
def abort_job():
    """Ask the running job to stop; wired to the "Abort Job" button.

    Only sets the shared ``abort_event``. The processing loop checks that flag
    between batches, so work already in progress still finishes.

    Returns:
        A status message for the UI.
    """
    if not abort_event.is_set():
        abort_event.set()
        return "Aborting job... Processing will stop after the current frame."
    return "Job is already being aborted."
def check_ffmpeg():
    """Report whether an ``ffmpeg`` executable is available.

    Runs ``ffmpeg -version`` with its output discarded.

    Returns:
        True if the command ran and exited with status 0. False if it exited
        non-zero or the executable could not be found.
    """
    probe = ["ffmpeg", "-version"]
    try:
        subprocess.run(probe, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False
    return True
def download_file(url, folder_path, filename):
    """Download ``url`` to ``folder_path/filename`` unless that file already exists.

    The body is streamed into ``<filename>.part`` and renamed into place only
    after it has been written completely. An interrupted download therefore
    never leaves a truncated file that later runs would skip as "already
    exists". Non-200 responses are printed rather than raised, so startup stays
    best-effort as before.

    Args:
        url: Source URL.
        folder_path: Target directory; created if missing.
        filename: Name of the file inside ``folder_path``.
    """
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, filename)
    if os.path.isfile(file_path):
        print(f"File already exists: {file_path}")
        return

    partial_path = file_path + ".part"
    # The timeout bounds connect time and each read, so a stalled server raises
    # instead of hanging startup forever. The context manager releases the connection.
    with requests.get(url, stream=True, timeout=60) as response:
        if response.status_code != 200:
            print(f"Error downloading the file. Status code: {response.status_code}")
            return
        with open(partial_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                file.write(chunk)
    # Atomic rename: file_path only ever appears once the download is complete.
    os.replace(partial_path, file_path)
    print(f"File successfully downloaded and saved: {file_path}")
def download_models():
    """Fetch every asset the pipeline needs into the local ``models/`` tree.

    Each entry maps a label to ``(url, target folder, filename)``. Files that
    are already present are skipped by ``download_file``.
    """
    assets = {
        "MODEL": ("https://huggingface.co/dantea1118/juggernaut_reborn/resolve/main/juggernaut_reborn.safetensors?download=true", "models/models/Stable-diffusion", "juggernaut_reborn.safetensors"),
        "UPSCALER_X2": ("https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth?download=true", "models/upscalers/", "RealESRGAN_x2.pth"),
        "UPSCALER_X4": ("https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth?download=true", "models/upscalers/", "RealESRGAN_x4.pth"),
        "NEGATIVE_1": ("https://huggingface.co/philz1337x/embeddings/resolve/main/verybadimagenegative_v1.3.pt?download=true", "models/embeddings", "verybadimagenegative_v1.3.pt"),
        "NEGATIVE_2": ("https://huggingface.co/datasets/AddictiveFuture/sd-negative-embeddings/resolve/main/JuggernautNegative-neg.pt?download=true", "models/embeddings", "JuggernautNegative-neg.pt"),
        "LORA_1": ("https://huggingface.co/philz1337x/loras/resolve/main/SDXLrender_v2.0.safetensors?download=true", "models/Lora", "SDXLrender_v2.0.safetensors"),
        "LORA_2": ("https://huggingface.co/philz1337x/loras/resolve/main/more_details.safetensors?download=true", "models/Lora", "more_details.safetensors"),
        "CONTROLNET": ("https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth?download=true", "models/ControlNet", "control_v11f1e_sd15_tile.pth"),
        "VAE": ("https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors?download=true", "models/VAE", "vae-ft-mse-840000-ema-pruned.safetensors"),
    }
    # Labels are only descriptive; just the (url, folder, filename) triples are used.
    for source_url, target_dir, target_name in assets.values():
        download_file(source_url, target_dir, target_name)


# Runs at import time so all weights exist before ModelManager loads them.
download_models()
def timer_func(func):
    """Decorator that prints how long each successful call to ``func`` took.

    ``functools.wraps`` copies ``__name__``, ``__doc__`` and ``__wrapped__`` onto
    the wrapper. The decorated function therefore still introspects as itself.

    Args:
        func: The callable to time.

    Returns:
        The timing wrapper. It forwards all arguments and returns ``func``'s result.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"{func.__name__} took {elapsed:.2f} seconds")
        return result
    return wrapper
class ModelManager:
    """Owns the ControlNet-tile img2img pipeline and the Real-ESRGAN upscalers.

    ``load_models`` creates each model only while it is still ``None``. After
    that, ``process_image`` and ``process_image_batch`` run the upscale pass.
    """

    def __init__(self):
        # Nothing is loaded here; load_models() fills these in.
        self.pipe = None
        self.realesrgan_x2 = None
        self.realesrgan_x4 = None

    def load_models(self):
        """Create any model that has not been loaded yet.

        The diffusion pipeline is moved to ``device`` and switched to
        ``AttnProcessor2_0`` attention. When the ``ENABLE_CPU_OFFLOAD`` flag is
        set, diffusers' model CPU offload is used instead of moving the whole
        pipeline to ``device``. Both Real-ESRGAN models (x2 and x4) are loaded
        from local weights.
        """
        if self.pipe is None:
            self.pipe = self.setup_pipeline()
            if ENABLE_CPU_OFFLOAD:
                # Sub-models stay on the CPU and are moved to the GPU only while they run.
                self.pipe.enable_model_cpu_offload()
            else:
                self.pipe.to(device)
            self.pipe.unet.set_attn_processor(AttnProcessor2_0())
            self.pipe.vae.set_attn_processor(AttnProcessor2_0())
            if USE_TORCH_COMPILE:
                self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True)
        if self.realesrgan_x2 is None:
            self.realesrgan_x2 = RealESRGAN(device, scale=2)
            self.realesrgan_x2.load_weights('models/upscalers/RealESRGAN_x2.pth', download=False)
        if self.realesrgan_x4 is None:
            self.realesrgan_x4 = RealESRGAN(device, scale=4)
            self.realesrgan_x4.load_weights('models/upscalers/RealESRGAN_x4.pth', download=False)

    def setup_pipeline(self):
        """Build the fp16 ControlNet img2img pipeline from the downloaded single files.

        Steps:
        - Combine the tile ControlNet with the Juggernaut checkpoint and the
          CompVis safety checker.
        - Swap in the ft-MSE VAE.
        - Register the two negative textual-inversion embeddings.
        - Fuse two LoRAs (scale 0.5 and 1.0).
        - Switch to DDIM.
        - Enable FreeU.

        Returns:
            The configured pipeline, not yet moved to any device.
        """
        controlnet = ControlNetModel.from_single_file(
            "models/ControlNet/control_v11f1e_sd15_tile.pth", torch_dtype=torch.float16
        )
        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
        model_path = "models/models/Stable-diffusion/juggernaut_reborn.safetensors"
        pipe = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
            model_path,
            controlnet=controlnet,
            torch_dtype=torch.float16,
            use_safetensors=True,
            safety_checker=safety_checker
        )
        vae = AutoencoderKL.from_single_file(
            "models/VAE/vae-ft-mse-840000-ema-pruned.safetensors",
            torch_dtype=torch.float16
        )
        pipe.vae = vae
        # These embeddings are referenced by name in the negative prompt of process_image_batch.
        pipe.load_textual_inversion("models/embeddings/verybadimagenegative_v1.3.pt")
        pipe.load_textual_inversion("models/embeddings/JuggernautNegative-neg.pt")
        pipe.load_lora_weights("models/Lora/SDXLrender_v2.0.safetensors")
        pipe.fuse_lora(lora_scale=0.5)
        pipe.load_lora_weights("models/Lora/more_details.safetensors")
        pipe.fuse_lora(lora_scale=1.)
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)
        return pipe

    @timer_func
    def process_image_batch(self, input_images, resolution, num_inference_steps, strength, hdr, guidance_scale):
        """Run the ControlNet-tile img2img pass over a list of images.

        Each image goes through ``prepare_image``. The result is used both as
        the img2img init image and as the ControlNet control image. All images
        share one prompt, one negative prompt and a generator seeded with 0.
        Output width and height come from the first prepared image, so every
        image in a batch is assumed to end up the same size (as frames of one
        video do).

        Args:
            input_images: Sequence of inputs accepted by ``resize_and_upscale``.
            resolution: Target length of the shorter side before Real-ESRGAN.
            num_inference_steps, strength, guidance_scale: Forwarded to the pipeline.
            hdr: Strength of the HDR effect, 0 disables it.

        Returns:
            The pipeline's output images, in input order. Empty list for empty input.
        """
        if not input_images:
            # Nothing to do; also avoids indexing condition_images[0] below.
            return []
        condition_images = [self.prepare_image(img, resolution, hdr) for img in input_images]
        prompt = "masterpiece, best quality, highres"
        negative_prompt = "low quality, normal quality, ugly, blurry, blur, lowres, bad anatomy, bad hands, cropped, worst quality, verybadimagenegative_v1.3, JuggernautNegative-neg"
        options = {
            "prompt": [prompt] * len(input_images),
            "negative_prompt": [negative_prompt] * len(input_images),
            "image": condition_images,
            "control_image": condition_images,
            "width": condition_images[0].size[0],
            "height": condition_images[0].size[1],
            "strength": strength,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "generator": torch.Generator(device=device).manual_seed(0),
        }
        print("Running inference on batch...")
        results = self.pipe(**options).images
        print("Batch processing completed successfully")
        return results

    def process_image(self, input_image, resolution, num_inference_steps, strength, hdr, guidance_scale):
        """Upscale a single image by running ``process_image_batch`` on a batch of one.

        Returns:
            The single output image.
        """
        return self.process_image_batch(
            [input_image], resolution, num_inference_steps, strength, hdr, guidance_scale
        )[0]

    def prepare_image(self, input_image, resolution, hdr):
        """Resize and Real-ESRGAN-upscale ``input_image``, then apply the HDR effect."""
        condition_image = self.resize_and_upscale(input_image, resolution)
        condition_image = self.create_hdr_effect(condition_image, hdr)
        return condition_image

    @timer_func
    def resize_and_upscale(self, input_image, resolution):
        """Resize so the shorter side is about ``resolution``, then upscale with Real-ESRGAN.

        Accepted inputs, all converted to RGB:
        - a path (``str`` or ``os.PathLike``)
        - an open file object
        - a PIL image
        - a NumPy array

        Both sides are rounded to the nearest multiple of 64. The image is then
        upscaled x2 when ``resolution <= 2048``, otherwise x4.

        Raises:
            ValueError: For any other input type.
        """
        scale = 2 if resolution <= 2048 else 4
        if isinstance(input_image, (str, os.PathLike, io.IOBase)):
            input_image = Image.open(input_image).convert("RGB")
        elif isinstance(input_image, Image.Image):
            input_image = input_image.convert("RGB")
        elif isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image).convert("RGB")
        else:
            raise ValueError(f"Unsupported input type for input_image: {type(input_image)}")
        W, H = input_image.size
        k = float(resolution) / min(H, W)
        # Multiples of 64 keep the dimensions compatible with the diffusion model's downsampling.
        H = int(round(H * k / 64.0)) * 64
        W = int(round(W * k / 64.0)) * 64
        img = input_image.resize((W, H), resample=Image.LANCZOS)
        if scale == 2:
            img = self.realesrgan_x2.predict(img)
        else:
            img = self.realesrgan_x4.predict(img)
        return img

    @timer_func
    def create_hdr_effect(self, original_image, hdr):
        """Apply a pseudo-HDR look by exposure fusion; ``hdr == 0`` returns the image unchanged.

        Nine copies of the image are made with pixel values scaled by factors
        spread around 1.0; ``hdr`` widens the spread. The copies are fused with
        OpenCV's Mertens merge and the result is converted back to an 8-bit RGB
        PIL image.
        """
        if hdr == 0:
            return original_image
        cv_original = cv2.cvtColor(np.array(original_image), cv2.COLOR_RGB2BGR)
        factors = [1.0 - 0.9 * hdr, 1.0 - 0.7 * hdr, 1.0 - 0.45 * hdr,
                   1.0 - 0.25 * hdr, 1.0, 1.0 + 0.2 * hdr,
                   1.0 + 0.4 * hdr, 1.0 + 0.6 * hdr, 1.0 + 0.8 * hdr]
        images = [cv2.convertScaleAbs(cv_original, alpha=factor) for factor in factors]
        merge_mertens = cv2.createMergeMertens()
        hdr_image = merge_mertens.process(images)
        # Mertens fusion yields floats in roughly [0, 1]; map back to 8-bit.
        hdr_image_8bit = np.clip(hdr_image * 255, 0, 255).astype('uint8')
        hdr_result = Image.fromarray(cv2.cvtColor(hdr_image_8bit, cv2.COLOR_BGR2RGB))
        return hdr_result
# Single shared ModelManager used by process_video() and gradio_process_media().
# Weights are loaded here, at import time, rather than on the first request.
model_manager = ModelManager()
model_manager.load_models() # Ensure models are loaded
def extract_frames(video_path, output_folder):
    """Dump the frames of ``video_path``, resampled to 30 fps, as PNG files.

    Files go to ``output_folder/frame_%06d.png``, using ffmpeg's own sequence
    numbering. The folder is created if needed.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    os.makedirs(output_folder, exist_ok=True)
    frame_pattern = f'{output_folder}/frame_%06d.png'
    ffmpeg_args = ['ffmpeg', '-i', video_path, '-vf', 'fps=30', frame_pattern]
    subprocess.run(ffmpeg_args, check=True)
def frames_to_video(input_folder, output_path, fps, original_video_path):
    """Encode ``input_folder/frame_%06d.png`` into ``output_path`` with the source audio.

    This takes two ffmpeg passes:
    1. Encode the frames to H.264 (yuv420p) in a temporary
       ``<output_path>_temp.mp4``.
    2. Copy that video stream and mux in the first audio stream of
       ``original_video_path``, re-encoded to AAC. The trailing ``?`` on the
       map makes the audio optional. ``-shortest`` ends the output with the
       shorter stream.

    The temporary file is removed even if a pass fails.

    Args:
        input_folder: Folder holding the frame sequence.
        output_path: Destination MP4; overwritten if it exists.
        fps: Frame rate for the sequence. Anything whose ``str()`` ffmpeg
            accepts as a rate works, e.g. ``30`` or ``"15/2"``.
        original_video_path: Video whose audio track is copied.

    Raises:
        subprocess.CalledProcessError: If either ffmpeg run fails.
    """
    temp_output_path = output_path + "_temp.mp4"
    # -y: overwrite existing outputs instead of ffmpeg prompting on stdin.
    video_command = [
        'ffmpeg',
        '-y',
        '-framerate', str(fps),
        '-i', f'{input_folder}/frame_%06d.png',
        '-c:v', 'libx264',
        '-pix_fmt', 'yuv420p',
        temp_output_path
    ]
    final_command = [
        'ffmpeg',
        '-y',
        '-i', temp_output_path,
        '-i', original_video_path,
        '-c:v', 'copy',
        '-c:a', 'aac',
        '-map', '0:v:0',
        '-map', '1:a:0?',
        '-shortest',
        output_path
    ]
    try:
        # First create the video from frames without audio,
        # then copy the audio from the original video into it.
        subprocess.run(video_command, check=True)
        subprocess.run(final_command, check=True)
    finally:
        # Remove the temporary file even when one of the passes failed.
        if os.path.exists(temp_output_path):
            os.remove(temp_output_path)
@timer_func
def process_video(input_video, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames=None, frame_interval=1, preserve_frames=False, batch_size=8, progress=gr.Progress()):
    """Upscale a video frame by frame and reassemble it with the original audio.

    Each call gets its own job folder ``jobs/<uuid4>/``. It holds
    ``config.json``, the extracted frames (``video_frames/``) and the upscaled
    frames (``processed_frames/``).

    Args:
        input_video: Path to the video, or an object exposing the path as ``.name``.
        resolution, num_inference_steps, strength, hdr, guidance_scale:
            Forwarded to ``model_manager.process_image_batch``.
        max_frames: Maximum number of frames to upscale. Falsy or non-positive
            means all selected frames.
        frame_interval: Upscale every ``frame_interval``-th extracted frame.
            The output frame rate is divided by the same factor, so the result
            keeps the source duration.
        preserve_frames: If true, a processed frame that already exists is not
            overwritten.
        batch_size: Number of frames sent to the pipeline per call.
        progress: Gradio progress tracker.

    Returns:
        Path to the reassembled MP4, or ``None`` if any step failed. An aborted
        job still returns a path; that video contains only the frames finished
        before the abort.
    """
    from fractions import Fraction

    # Must match the fps filter used by extract_frames().
    source_fps = 30

    abort_event.clear()  # Clear the abort flag at the start of a new job
    print("Starting video processing...")

    # Gradio widgets may hand over floats or None; range() and slicing need positive ints.
    batch_size = max(1, int(batch_size or 1))
    frame_interval = max(1, int(frame_interval or 1))
    max_frames = int(max_frames) if max_frames and max_frames > 0 else None

    # Create a new job folder
    job_id = str(uuid.uuid4())
    job_folder = os.path.join("jobs", job_id)
    os.makedirs(job_folder, exist_ok=True)

    # Save job config
    config = {
        "resolution": resolution,
        "num_inference_steps": num_inference_steps,
        "strength": strength,
        "hdr": hdr,
        "guidance_scale": guidance_scale,
        "max_frames": max_frames,
        "frame_interval": frame_interval,
        "preserve_frames": preserve_frames,
        "batch_size": batch_size
    }
    with open(os.path.join(job_folder, "config.json"), "w") as f:
        json.dump(config, f)

    # If input_video is a file object or has a 'name' attribute, use its name
    if isinstance(input_video, io.IOBase) or hasattr(input_video, 'name'):
        input_video = input_video.name

    # Set up folders
    frames_folder = os.path.join(job_folder, "video_frames")
    processed_frames_folder = os.path.join(job_folder, "processed_frames")
    os.makedirs(frames_folder, exist_ok=True)
    os.makedirs(processed_frames_folder, exist_ok=True)

    try:
        # Extract frames
        progress(0.1, desc="Extracting frames...")
        extract_frames(input_video, frames_folder)

        # Keep every frame_interval-th frame, then cap the count at max_frames.
        selected_frames = sorted(os.listdir(frames_folder))[::frame_interval]
        if max_frames is not None:
            selected_frames = selected_frames[:max_frames]
        frames_to_process = len(selected_frames)
        if frames_to_process == 0:
            raise RuntimeError("No frames were extracted from the input video.")
        num_batches = -(-frames_to_process // batch_size)  # ceiling division

        progress(0.2, desc="Processing frames...")
        for i in tqdm(range(0, frames_to_process, batch_size), desc="Processing batches"):
            if abort_event.is_set():
                print("Job aborted. Stopping processing of new frames.")
                break
            batch_frames = selected_frames[i:i + batch_size]
            input_images = [Image.open(os.path.join(frames_folder, frame)) for frame in batch_frames]
            processed_images = model_manager.process_image_batch(input_images, resolution, num_inference_steps, strength, hdr, guidance_scale)
            for offset, processed_image in enumerate(processed_images):
                # Number output frames contiguously from 1, as ffmpeg's image2
                # muxer does by default. frames_to_video reads them back as one
                # unbroken frame_%06d.png sequence, and the gaps left by
                # frame_interval > 1 would otherwise cut the video short. With
                # frame_interval == 1 these names equal the extracted names.
                output_frame_path = os.path.join(processed_frames_folder, f"frame_{i + offset + 1:06d}.png")
                if not preserve_frames or not os.path.exists(output_frame_path):
                    processed_image.save(output_frame_path)
            done = min(i + batch_size, frames_to_process)
            progress(0.2 + 0.7 * done / frames_to_process, desc=f"Processed batch {i // batch_size + 1}/{num_batches}")

        # Always attempt to reassemble video
        progress(0.9, desc="Reassembling video...")
        input_filename = os.path.splitext(os.path.basename(input_video))[0]
        output_video = os.path.join(job_folder, f"{input_filename}_upscaled.mp4")
        # N kept frames span N * frame_interval source frames, so divide the rate to keep the duration.
        # str(Fraction(30, 1)) == "30", so the default case passes the same rate as before.
        output_fps = Fraction(source_fps, frame_interval)
        frames_to_video(processed_frames_folder, output_video, output_fps, input_video)

        if abort_event.is_set():
            progress(1.0, desc="Video processing aborted, but partial result saved")
            print("Video processing aborted, but partial result saved")
        else:
            progress(1.0, desc="Video processing completed successfully")
            print("Video processing completed successfully")
        return output_video
    except Exception as e:
        print(f"An error occurred during processing: {str(e)}")
        progress(1.0, desc=f"Error: {str(e)}")
        return None
@spaces.GPU(duration=200)
def gradio_process_media(input_media, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames, frame_interval, preserve_frames, batch_size, progress=gr.Progress()):
    """Gradio entry point: upscale an uploaded image or video.

    Files with a video extension go through ``process_video``. Everything else
    is treated as an image and run through the pipeline as a batch of one.

    Returns:
        ``(output_path or None, status message)`` for the ``[output, status``]
        components.

    Raises:
        ValueError: If ``input_media`` is neither a path string nor has a ``.name``.
    """
    abort_event.clear()  # Clear the abort flag at the start of a new job
    if input_media is None:
        return None, "No input media provided."

    print(f"Input media type: {type(input_media)}")

    # Get the file path
    if isinstance(input_media, str):
        file_path = input_media
    elif isinstance(input_media, io.IOBase) or hasattr(input_media, 'name'):
        file_path = input_media.name
    else:
        raise ValueError(f"Unsupported input type: {type(input_media)}")

    print(f"File path: {file_path}")

    # Check if the file is a video
    video_extensions = ('.mp4', '.avi', '.mov', '.mkv')
    if file_path.lower().endswith(video_extensions):
        print("Processing video...")
        result = process_video(file_path, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames, frame_interval, preserve_frames, batch_size, progress)
        if not result:
            return None, "Error occurred during video processing."
        if abort_event.is_set():
            # process_video still returns the partially processed video after an abort.
            return result, "Video processing aborted; partial result saved."
        return result, "Video processing completed successfully."

    print("Processing image...")
    try:
        # ModelManager exposes a batch API, so run a batch of one image.
        result = model_manager.process_image_batch([file_path], resolution, num_inference_steps, strength, hdr, guidance_scale)[0]
        # A per-request subfolder keeps results of same-named uploads from
        # different users from overwriting each other. The file name stays the same.
        output_dir = os.path.join("processed_images", uuid.uuid4().hex)
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f"processed_{os.path.basename(file_path)}")
        result.save(output_path)
    except Exception as e:
        print(f"An error occurred during image processing: {str(e)}")
        return None, "Error occurred during image processing."
    return output_path, "Image processing completed successfully."
# HTML header shown at the top of the UI via gr.HTML(title): the app name, a
# usage caveat about processing time, and credit links.
title = """
<h1 align="center">Simple Slow Video Upscaler</h1>
<p align="center">Currently takes too long :( Please try 1-2 seconds videos only.</p>
<p align="center">
<a href="https://twitter.com/hrishioa" target="_blank">[Hrishi]</a>
<a href="https://huggingface.co/spaces/gokaygokay/Tile-Upscaler" target="_blank">[gokaygokay/Tile-Upscaler]</a>
<a href="https://github.com/philz1337x/clarity-upscaler" target="_blank">[philz1337x]</a>
<a href="https://github.com/BatouResearch/controlnet-tile-upscale" target="_blank">[BatouResearch]</a>
<a href="https://github.com/hrishioa/SimpleSlowVideoUpscaler" target="_blank">[Hrishi GitHub]</a>
</p>
"""
# Update the Gradio interface
# Gradio lays components out in the order they are created inside the `with`
# contexts, so statement order here defines the page layout.
with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue")) as iface:
    gr.HTML(title)
    with gr.Row():
        # Left column: input file and processing parameters.
        with gr.Column(scale=2):
            input_media = gr.File(label="Input Media (Image or Video)")
            resolution = gr.Slider(256, 2048, 512, step=256, label="Resolution")
            num_inference_steps = gr.Slider(1, 50, 20, step=1, label="Number of Inference Steps")
            strength = gr.Slider(0, 1, 0.05, step=0.01, label="Strength")
            hdr = gr.Slider(0, 1, 0, step=0.1, label="HDR Effect")
            guidance_scale = gr.Slider(0, 20, 5, step=0.5, label="Guidance Scale")
            max_frames = gr.Number(label="Max Frames to Process (leave empty for full video)", precision=0)
            frame_interval = gr.Slider(1, 30, 1, step=1, label="Frame Interval (process every nth frame)")
            preserve_frames = gr.Checkbox(label="Preserve Existing Processed Frames", value=True)
            batch_size = gr.Slider(1, 16, 8, step=1, label="Batch Size")
        # Right column: actions, result file and status text.
        with gr.Column(scale=1):
            submit_button = gr.Button("Process Media")
            abort_button = gr.Button("Abort Job")
            output = gr.File(label="Processed Media")
            status = gr.Markdown("Ready to process media.")
    # The inputs list is passed positionally and must stay in the same order as
    # gradio_process_media's parameters; its (path, message) return fills [output, status].
    submit_button.click(
        gradio_process_media,
        inputs=[input_media, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames, frame_interval, preserve_frames, batch_size],
        outputs=[output, status]
    )
    # abort_job takes no inputs; it sets the shared abort flag and returns a status message.
    abort_button.click(abort_job, inputs=[], outputs=status)
# Launch the Gradio app
iface.launch()