|
import os |
|
import einops |
|
import gradio as gr |
|
from gradio_imageslider import ImageSlider |
|
import numpy as np |
|
import torch |
|
import random |
|
from PIL import Image |
|
from pathlib import Path |
|
from torchvision import transforms |
|
import torch.nn.functional as F |
|
from torchvision.models import resnet50, ResNet50_Weights |
|
|
|
from pytorch_lightning import seed_everything |
|
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor |
|
from diffusers import AutoencoderKL, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler |
|
|
|
from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline |
|
from myutils.misc import load_dreambooth_lora, rand_name |
|
from myutils.wavelet_color_fix import wavelet_color_fix |
|
from annotator.retinaface import RetinaFaceDetection |
|
|
|
# Model-variant switch: when True the UNet/ControlNet classes are imported
# from ``models.pasd_light`` (presumably a lighter PASD variant), otherwise
# from ``models.pasd``.
use_pasd_light = False

# NOTE(review): the face detector is built at import time, but nothing else in
# the visible code references ``face_detector``.
face_detector = RetinaFaceDetection()

# Both branches bind the same two names so the rest of the script is agnostic
# to which variant was selected.
if not use_pasd_light:
    from models.pasd.unet_2d_condition import UNet2DConditionModel
    from models.pasd.controlnet import ControlNetModel
else:
    from models.pasd_light.unet_2d_condition import UNet2DConditionModel
    from models.pasd_light.controlnet import ControlNetModel
|
|
|
# Checkpoint locations: base Stable Diffusion 1.5 weights, the trained PASD
# checkpoint (UNet + ControlNet), and a personalized checkpoint that is
# applied through ``load_dreambooth_lora``.
pretrained_model_path = "checkpoints/stable-diffusion-v1-5"
ckpt_path = "runs/pasd/checkpoint-100000"
dreambooth_lora_path = "checkpoints/personalized_models/majicmixRealistic_v6.safetensors"

# Every network is moved to the GPU in float16.
weight_dtype = torch.float16
device = "cuda"

# Components taken from the base Stable Diffusion directory.
scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
feature_extractor = CLIPImageProcessor.from_pretrained(f"{pretrained_model_path}/feature_extractor")

# PASD-specific networks come from the trained checkpoint.
unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet")
controlnet = ControlNetModel.from_pretrained(ckpt_path, subfolder="controlnet")

# No training happens here, so disable gradients on every network.
for _module in (vae, text_encoder, unet, controlnet):
    _module.requires_grad_(False)

# The helper returns the three modules it touched; rebind to what it returns.
unet, vae, text_encoder = load_dreambooth_lora(unet, vae, text_encoder, dreambooth_lora_path)

# Move to the inference device / dtype (nn.Module.to works in place).
for _module in (text_encoder, vae, unet, controlnet):
    _module.to(device, dtype=weight_dtype)
del _module

validation_pipeline = StableDiffusionControlNetPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=None,
    requires_safety_checker=False,
)

# Tiled VAE decoding with 224-pixel tiles (private pipeline helper).
validation_pipeline._init_tiled_vae(decoder_tile_size=224)
|
|
|
# ImageNet-pretrained ResNet-50 plus its matching preprocessing; ``inference``
# uses it to guess an object category that can be appended to the prompt.
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
# ``eval()`` returns the module itself, so construction and mode switch chain.
resnet = resnet50(weights=weights).eval()
|
|
|
def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
    """Super-resolve ``input_image`` with the PASD pipeline.

    Args:
        input_image: PIL image from the upload widget (``None`` if empty).
        prompt: user prompt; a ResNet-50 ImageNet label is appended to it when
            the classifier's top score is at least 0.1.
        a_prompt: additional positive prompt appended after ``prompt``.
        n_prompt: negative prompt.
        denoise_steps: number of inference steps (cast to ``int``).
        upscale: integer upscaling factor (cast to ``int``).
        alpha: conditioning scale passed to the pipeline.
        cfg: classifier-free guidance scale.
        seed: RNG seed. It is passed through ``seed_everything`` and the seed
            that call returns also seeds the pipeline's generator.

    Returns:
        ``(("input.jpg", "result.jpg"), "result.jpg")``: the before/after pair
        for the image slider and the path of the downloadable result.

    Raises:
        ValueError: if no image was provided.
    """
    if input_image is None:
        raise ValueError("Please upload an image first.")

    # Slider values may arrive as floats; PIL sizes and step counts need ints.
    rscale = int(upscale)
    num_steps = int(denoise_steps)

    with torch.no_grad():
        # seed_everything may substitute a different value for out-of-range
        # seeds (e.g. the slider's -1); use the seed it actually applied.
        used_seed = seed_everything(int(seed))
        # A freshly constructed torch.Generator always starts from the same
        # default seed, so it must be seeded explicitly or the Seed slider
        # has no effect on what the pipeline draws from it.
        generator = torch.Generator(device=device).manual_seed(used_seed)

        input_image = input_image.convert("RGB")

        # Classify the input; a sufficiently confident label enriches the prompt.
        batch = preprocess(input_image).unsqueeze(0)
        prediction = resnet(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        if score >= 0.1:
            category_name = weights.meta["categories"][class_id]
            prompt = ", ".join(part for part in (prompt, category_name) if part)

        # Join only non-empty parts so an empty added prompt leaves no dangling ", ".
        prompt = ", ".join(part for part in (prompt, a_prompt) if part)

        ori_width, ori_height = input_image.size

        # Upscale, then snap both sides down to a multiple of 8 before handing
        # the image to the pipeline (presumably required by the latent stride).
        input_image = input_image.resize((ori_width * rscale, ori_height * rscale))
        input_image = input_image.resize((input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
        width, height = input_image.size

        try:
            image = validation_pipeline(
                None, prompt, input_image,
                num_inference_steps=num_steps, generator=generator,
                height=height, width=width, guidance_scale=cfg,
                negative_prompt=n_prompt, conditioning_scale=alpha, eta=0.0,
            ).images[0]

            # Color correction against the (resized) input image.
            image = wavelet_color_fix(image, input_image)

            # Undo the multiple-of-8 snapping: restore the exact upscaled size.
            image = image.resize((ori_width * rscale, ori_height * rscale))
        except Exception as e:
            # Deliberate best-effort: report the failure and show a blank
            # placeholder instead of crashing the UI callback.
            print(e)
            image = Image.new(mode="RGB", size=(512, 512))

    # NOTE(review): these fixed paths in the working directory are shared by
    # every session; concurrent requests could overwrite each other's files.
    image.save("result.jpg", "JPEG")
    input_image.save("input.jpg", "JPEG")

    return ("input.jpg", "result.jpg"), "result.jpg"
|
|
|
# Page metadata strings. In the visible code, title/description/article are
# not passed to the Blocks layout; only ``css`` is consumed.
title = "Pixel-Aware Stable Diffusion for Real-ISR"
description = (
    "Gradio Demo for PASD Real-ISR. "
    "To use it, simply upload your image, or click one of the examples to load them."
)
article = (
    "<a href='https://github.com/yangxy/PASD' target='_blank'>"
    "Github Repo Pytorch</a>"
)

# Stylesheet for the main column: centered horizontally, at most 720px wide.
css = """
#col-container{
    margin: 0 auto;
    max-width: 720px;
}
"""
|
|
|
# Gradio UI: inputs and advanced settings on the left, the before/after
# slider and downloadable result on the right; Submit runs ``inference``.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Plain string: the header has no placeholders, so no f-string needed.
        gr.HTML("""
        <h2 style="text-align: center;">
            PASD Gradio demo
        </h2>
        <p style="text-align: center;">
            Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
        </p>
        """)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", sources=["upload"], value="samples/frog.png")
                prompt_in = gr.Textbox(label="Prompt", value="Frog")
                with gr.Accordion(label="Advanced settings", open=False):
                    added_prompt = gr.Textbox(
                        label="Added Prompt",
                        value='clean, high-resolution, 8k, best quality, masterpiece',
                    )
                    neg_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value='dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality',
                    )
                    denoise_steps = gr.Slider(label="Denoise Steps", minimum=10, maximum=50, value=20, step=1)
                    upsample_scale = gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1)
                    condition_scale = gr.Slider(label="Conditioning Scale", minimum=0.5, maximum=1.5, value=1.1, step=0.1)
                    # Label typo fixed: "Classier" -> "Classifier".
                    classifier_free_guidance = gr.Slider(label="Classifier-free Guidance", minimum=0.1, maximum=10.0, value=7.5, step=0.1)
                    # -1 is allowed by the slider; the seed actually used is
                    # whatever seed_everything accepts for it.
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                submit_btn = gr.Button("Submit")
            with gr.Column():
                b_a_slider = ImageSlider(label="B/A result", position=0.5)
                file_output = gr.File(label="Downloadable image result")

    # Input order must match the positional parameters of ``inference``.
    submit_btn.click(
        fn=inference,
        inputs=[
            input_image, prompt_in,
            added_prompt, neg_prompt,
            denoise_steps,
            upsample_scale, condition_scale,
            classifier_free_guidance, seed,
        ],
        outputs=[
            b_a_slider,
            file_output,
        ],
    )

demo.queue().launch()