fffiloni commited on
Commit
d30159c
·
verified ·
1 Parent(s): 14fc79e

Create app_zero.py

Browse files
Files changed (1) hide show
  1. app_zero.py +254 -0
app_zero.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import types
3
+ torch.cuda.get_device_capability = lambda *args, **kwargs: (8, 6)
4
+ torch.cuda.get_device_properties = lambda *args, **kwargs: types.SimpleNamespace(name='NVIDIA A10G', major=8, minor=6, total_memory=23836033024, multi_processor_count=80)
5
+
6
+ import huggingface_hub
7
+ huggingface_hub.snapshot_download(
8
+ repo_id='camenduru/PASD',
9
+ allow_patterns=[
10
+ 'pasd/**',
11
+ 'pasd_light/**',
12
+ 'pasd_light_rrdb/**',
13
+ 'pasd_rrdb/**',
14
+ ],
15
+ local_dir='PASD/runs',
16
+ local_dir_use_symlinks=False,
17
+ )
18
+ huggingface_hub.hf_hub_download(
19
+ repo_id='camenduru/PASD',
20
+ filename='majicmixRealistic_v6.safetensors',
21
+ local_dir='PASD/checkpoints/personalized_models',
22
+ local_dir_use_symlinks=False,
23
+ )
24
+ huggingface_hub.hf_hub_download(
25
+ repo_id='akhaliq/RetinaFace-R50',
26
+ filename='RetinaFace-R50.pth',
27
+ local_dir='PASD/annotator/ckpts',
28
+ local_dir_use_symlinks=False,
29
+ )
30
+
31
+ import sys; sys.path.append('./PASD')
32
+ import spaces
33
+ import os
34
+ import datetime
35
+ import einops
36
+ import gradio as gr
37
+ from gradio_imageslider import ImageSlider
38
+ import numpy as np
39
+ import torch
40
+ import random
41
+ from PIL import Image
42
+ from pathlib import Path
43
+ from torchvision import transforms
44
+ import torch.nn.functional as F
45
+ from torchvision.models import resnet50, ResNet50_Weights
46
+
47
+ from pytorch_lightning import seed_everything
48
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
49
+ from diffusers import AutoencoderKL, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler
50
+
51
+ from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline
52
+ from myutils.misc import load_dreambooth_lora, rand_name
53
+ from myutils.wavelet_color_fix import wavelet_color_fix
54
+ from annotator.retinaface import RetinaFaceDetection
55
+
56
+ use_pasd_light = False
57
+ face_detector = RetinaFaceDetection()
58
+
59
+ if use_pasd_light:
60
+ from models.pasd_light.unet_2d_condition import UNet2DConditionModel
61
+ from models.pasd_light.controlnet import ControlNetModel
62
+ else:
63
+ from models.pasd.unet_2d_condition import UNet2DConditionModel
64
+ from models.pasd.controlnet import ControlNetModel
65
+
66
+ pretrained_model_path = "runwayml/stable-diffusion-v1-5"
67
+ ckpt_path = "PASD/runs/pasd/checkpoint-100000"
68
+ #dreambooth_lora_path = "checkpoints/personalized_models/toonyou_beta3.safetensors"
69
+ dreambooth_lora_path = "PASD/checkpoints/personalized_models/majicmixRealistic_v6.safetensors"
70
+ #dreambooth_lora_path = "checkpoints/personalized_models/Realistic_Vision_V5.1.safetensors"
71
+ weight_dtype = torch.float16
72
+ device = "cuda"
73
+
74
+ scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
75
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
76
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
77
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
78
+ feature_extractor = CLIPImageProcessor.from_pretrained(pretrained_model_path, subfolder="feature_extractor")
79
+ unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet")
80
+ controlnet = ControlNetModel.from_pretrained(ckpt_path, subfolder="controlnet")
81
+ vae.requires_grad_(False)
82
+ text_encoder.requires_grad_(False)
83
+ unet.requires_grad_(False)
84
+ controlnet.requires_grad_(False)
85
+
86
+ unet, vae, text_encoder = load_dreambooth_lora(unet, vae, text_encoder, dreambooth_lora_path)
87
+
88
+ text_encoder.to(device, dtype=weight_dtype)
89
+ vae.to(device, dtype=weight_dtype)
90
+ unet.to(device, dtype=weight_dtype)
91
+ controlnet.to(device, dtype=weight_dtype)
92
+
93
+ validation_pipeline = StableDiffusionControlNetPipeline(
94
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor,
95
+ unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False,
96
+ )
97
+ #validation_pipeline.enable_vae_tiling()
98
+ validation_pipeline._init_tiled_vae(decoder_tile_size=224)
99
+
100
+ weights = ResNet50_Weights.DEFAULT
101
+ preprocess = weights.transforms()
102
+ resnet = resnet50(weights=weights)
103
+ resnet.eval()
104
+
105
+ def resize_image(image_path, target_height):
106
+ # Open the image file
107
+ with Image.open(image_path) as img:
108
+ # Calculate the ratio to resize the image to the target height
109
+ ratio = target_height / float(img.size[1])
110
+ # Calculate the new width based on the aspect ratio
111
+ new_width = int(float(img.size[0]) * ratio)
112
+ # Resize the image
113
+ resized_img = img.resize((new_width, target_height), Image.LANCZOS)
114
+ # Save the resized image
115
+ #resized_img.save(output_path)
116
+ return resized_img
117
+
118
+ @spaces.GPU(enable_queue=True)
119
+ def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
120
+
121
+ #tempo fix for seed equals-1
122
+ if seed == -1:
123
+ seed = 0
124
+
125
+ input_image = resize_image(input_image, 512)
126
+ process_size = 768
127
+ resize_preproc = transforms.Compose([
128
+ transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
129
+ ])
130
+
131
+ # Get the current timestamp
132
+ timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
133
+
134
+ with torch.no_grad():
135
+ seed_everything(seed)
136
+ generator = torch.Generator(device=device)
137
+
138
+ input_image = input_image.convert('RGB')
139
+ batch = preprocess(input_image).unsqueeze(0)
140
+ prediction = resnet(batch).squeeze(0).softmax(0)
141
+ class_id = prediction.argmax().item()
142
+ score = prediction[class_id].item()
143
+ category_name = weights.meta["categories"][class_id]
144
+ if score >= 0.1:
145
+ prompt += f"{category_name}" if prompt=='' else f", {category_name}"
146
+
147
+ prompt = a_prompt if prompt=='' else f"{prompt}, {a_prompt}"
148
+
149
+ ori_width, ori_height = input_image.size
150
+ resize_flag = False
151
+
152
+ rscale = upscale
153
+ input_image = input_image.resize((input_image.size[0]*rscale, input_image.size[1]*rscale))
154
+
155
+ #if min(validation_image.size) < process_size:
156
+ # validation_image = resize_preproc(validation_image)
157
+
158
+ input_image = input_image.resize((input_image.size[0]//8*8, input_image.size[1]//8*8))
159
+ width, height = input_image.size
160
+ resize_flag = True #
161
+
162
+ try:
163
+ image = validation_pipeline(
164
+ None, prompt, input_image, num_inference_steps=denoise_steps, generator=generator, height=height, width=width, guidance_scale=cfg,
165
+ negative_prompt=n_prompt, conditioning_scale=alpha, eta=0.0,
166
+ ).images[0]
167
+
168
+ if True: #alpha<1.0:
169
+ image = wavelet_color_fix(image, input_image)
170
+
171
+ if resize_flag:
172
+ image = image.resize((ori_width*rscale, ori_height*rscale))
173
+ except Exception as e:
174
+ print(e)
175
+ image = Image.new(mode="RGB", size=(512, 512))
176
+
177
+ # Convert and save the image as JPEG
178
+ image.save(f'result_{timestamp}.jpg', 'JPEG')
179
+
180
+ # Convert and save the image as JPEG
181
+ input_image.save(f'input_{timestamp}.jpg', 'JPEG')
182
+
183
+ return (f"input_{timestamp}.jpg", f"result_{timestamp}.jpg"), f"result_{timestamp}.jpg"
184
+
185
+ title = "Pixel-Aware Stable Diffusion for Real-ISR"
186
+ description = "Gradio Demo for PASD Real-ISR. To use it, simply upload your image, or click one of the examples to load them."
187
+ article = "<a href='https://github.com/yangxy/PASD' target='_blank'>Github Repo Pytorch</a>"
188
+ #examples=[['samples/27d38eeb2dbbe7c9.png'],['samples/629e4da70703193b.png']]
189
+
190
+ css = """
191
+ #col-container{
192
+ margin: 0 auto;
193
+ max-width: 720px;
194
+ }
195
+ #project-links{
196
+ margin: 0 0 12px !important;
197
+ column-gap: 8px;
198
+ display: flex;
199
+ justify-content: center;
200
+ flex-wrap: nowrap;
201
+ flex-direction: row;
202
+ align-items: center;
203
+ }
204
+ """
205
+
206
+ with gr.Blocks(css=css) as demo:
207
+ with gr.Column(elem_id="col-container"):
208
+ gr.HTML(f"""
209
+ <h2 style="text-align: center;">
210
+ PASD Magnify
211
+ </h2>
212
+ <p style="text-align: center;">
213
+ Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
214
+ </p>
215
+ <p id="project-links" align="center">
216
+ <a href='https://github.com/yangxy/PASD'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://huggingface.co/papers/2308.14469'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
217
+ </p>
218
+ <p style="margin:12px auto;display: flex;justify-content: center;">
219
+ <a href="https://huggingface.co/spaces/fffiloni/PASD?duplicate=true"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space"></a>
220
+ </p>
221
+
222
+ """)
223
+ with gr.Row():
224
+ with gr.Column():
225
+ input_image = gr.Image(type="filepath", sources=["upload"], value="PASD/samples/frog.png")
226
+ prompt_in = gr.Textbox(label="Prompt", value="Frog")
227
+ with gr.Accordion(label="Advanced settings", open=False):
228
+ added_prompt = gr.Textbox(label="Added Prompt", value='clean, high-resolution, 8k, best quality, masterpiece')
229
+ neg_prompt = gr.Textbox(label="Negative Prompt",value='dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
230
+ denoise_steps = gr.Slider(label="Denoise Steps", minimum=10, maximum=50, value=20, step=1)
231
+ upsample_scale = gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1)
232
+ condition_scale = gr.Slider(label="Conditioning Scale", minimum=0.5, maximum=1.5, value=1.1, step=0.1)
233
+ classifier_free_guidance = gr.Slider(label="Classier-free Guidance", minimum=0.1, maximum=10.0, value=7.5, step=0.1)
234
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
235
+ submit_btn = gr.Button("Submit")
236
+ with gr.Column():
237
+ b_a_slider = ImageSlider(label="B/A result", position=0.5)
238
+ file_output = gr.File(label="Downloadable image result")
239
+
240
+ submit_btn.click(
241
+ fn = inference,
242
+ inputs = [
243
+ input_image, prompt_in,
244
+ added_prompt, neg_prompt,
245
+ denoise_steps,
246
+ upsample_scale, condition_scale,
247
+ classifier_free_guidance, seed
248
+ ],
249
+ outputs = [
250
+ b_a_slider,
251
+ file_output
252
+ ]
253
+ )
254
+ demo.queue(max_size=20).launch()