JOY-Huang commited on
Commit
7eb74c8
·
1 Parent(s): 1398a0f

update size processing

Browse files
Files changed (1) hide show
  1. app.py +63 -47
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import torch
3
  import spaces
4
 
@@ -7,7 +8,6 @@ import gradio as gr
7
  from PIL import Image
8
 
9
  from diffusers import DDPMScheduler
10
- from diffusers.utils import load_image
11
  from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
12
 
13
  from module.ip_adapter.utils import load_adapter_to_pipe
@@ -16,19 +16,33 @@ from pipelines.sdxl_instantir import InstantIRPipeline
16
  from huggingface_hub import hf_hub_download
17
 
18
 
19
- def resize_img(input_image, max_side=1280, min_side=1024, size=None,
20
  pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
21
 
22
  w, h = input_image.size
23
- if size is not None:
24
- w_resize_new, h_resize_new = size
 
 
 
 
 
 
 
25
  else:
26
- # ratio = min_side / min(h, w)
27
- # w, h = round(ratio*w), round(ratio*h)
28
- ratio = max_side / max(h, w)
29
- input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
30
- w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
31
- h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
 
 
 
 
 
 
 
32
  input_image = input_image.resize([w_resize_new, h_resize_new], mode)
33
 
34
  if pad_to_max_side:
@@ -37,7 +51,7 @@ def resize_img(input_image, max_side=1280, min_side=1024, size=None,
37
  offset_y = (max_side - h_resize_new) // 2
38
  res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
39
  input_image = Image.fromarray(res)
40
- return input_image
41
 
42
 
43
  if not os.path.exists("models/adapter.pt"):
@@ -52,10 +66,7 @@ sdxl_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
52
  dinov2_repo_id = "facebook/dinov2-large"
53
  lcm_repo_id = "latent-consistency/lcm-lora-sdxl"
54
 
55
- if torch.cuda.is_available():
56
- torch_dtype = torch.float16
57
- else:
58
- torch_dtype = torch.float32
59
 
60
  # Load pretrained models.
61
  print("Initializing pipeline...")
@@ -96,7 +107,8 @@ pipe.aggregator.load_state_dict(aggregator_state_dict)
96
  pipe.aggregator.to(device=device, dtype=torch_dtype)
97
 
98
  MAX_SEED = np.iinfo(np.int32).max
99
- MAX_IMAGE_SIZE = 1024
 
100
 
101
  PROMPT = "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, \
102
  ultra HD, extreme meticulous detailing, skin pore detailing, \
@@ -108,11 +120,15 @@ sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, \
108
  dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, \
109
  watermark, signature, jpeg artifacts, deformed, lowres"
110
 
 
 
 
 
 
111
  def unpack_pipe_out(preview_row, index):
112
  return preview_row[index][0]
113
 
114
  def dynamic_preview_slider(sampling_steps):
115
- print(sampling_steps)
116
  return gr.Slider(label="Restoration Previews", value=sampling_steps-1, minimum=0, maximum=sampling_steps-1, step=1)
117
 
118
  def dynamic_guidance_slider(sampling_steps):
@@ -121,10 +137,14 @@ def dynamic_guidance_slider(sampling_steps):
121
  def show_final_preview(preview_row):
122
  return preview_row[-1][0]
123
 
124
- @spaces.GPU
125
  def instantir_restore(
126
  lq, prompt="", steps=30, cfg_scale=7.0, guidance_end=1.0,
127
- creative_restoration=False, seed=3407, height=1024, width=1024, preview_start=0.0):
 
 
 
 
128
  if creative_restoration:
129
  if "lcm" not in pipe.unet.active_adapters():
130
  pipe.unet.set_adapter('lcm')
@@ -140,10 +160,8 @@ def instantir_restore(
140
  preview_start = preview_start / steps
141
  elif preview_start > 1.0:
142
  preview_start = preview_start / steps
143
- print(lq)
144
- lq = load_image(lq)
145
- print(type(lq))
146
- lq = [resize_img(lq.convert("RGB"), size=(width, height))]
147
  generator = torch.Generator(device=device).manual_seed(seed)
148
  timesteps = [
149
  i * (1000//steps) + pipe.scheduler.config.steps_offset for i in range(0, steps)
@@ -167,8 +185,10 @@ def instantir_restore(
167
  return_dict=False,
168
  save_preview_row=True,
169
  )
170
- for i, preview_img in enumerate(out[1]):
171
- preview_img.append(f"preview_{i}")
 
 
172
  return out[0][0], out[1]
173
 
174
  css="""
@@ -182,7 +202,6 @@ with gr.Blocks() as demo:
182
  gr.Markdown(
183
  """
184
  # InstantIR: Blind Image Restoration with Instant Generative Reference.
185
-
186
  ### **Official 🤗 Gradio demo of [InstantIR](https://arxiv.org/abs/2410.06551).**
187
  ### **InstantIR can not only help you restore your broken image, but also capable of imaginative re-creation following your text prompts. See advance usage for more details!**
188
  ## Basic usage: revitalize your image
@@ -191,34 +210,37 @@ with gr.Blocks() as demo:
191
  3. Click `InstantIR magic!`.
192
  """)
193
  with gr.Row():
194
- lq_img = gr.Image(label="Low-quality image", type="filepath")
195
  with gr.Column():
 
 
 
 
196
  with gr.Row():
197
  steps = gr.Number(label="Steps", value=30, step=1)
198
  cfg_scale = gr.Number(label="CFG Scale", value=7.0, step=0.1)
199
  with gr.Row():
200
- height = gr.Number(label="Height", value=1024, step=1)
201
- weight = gr.Number(label="Weight", value=1024, step=1)
202
  seed = gr.Number(label="Seed", value=42, step=1)
203
- # guidance_start = gr.Slider(label="Guidance Start", value=1.0, minimum=0.0, maximum=1.0, step=0.05)
204
  guidance_end = gr.Slider(label="Start Free Rendering", value=30, minimum=0, maximum=30, step=1)
205
  preview_start = gr.Slider(label="Preview Start", value=0, minimum=0, maximum=30, step=1)
206
- prompt = gr.Textbox(label="Restoration prompts (Optional)", placeholder="", value="")
207
  mode = gr.Checkbox(label="Creative Restoration", value=False)
208
- with gr.Row():
209
- with gr.Row():
210
- restore_btn = gr.Button("InstantIR magic!")
211
- clear_btn = gr.ClearButton()
212
- index = gr.Slider(label="Restoration Previews", value=29, minimum=0, maximum=29, step=1)
213
- with gr.Row():
214
- output = gr.Image(label="InstantIR restored", type="filepath")
215
- preview = gr.Image(label="Preview", type="filepath")
 
216
  pipe_out = gr.Gallery(visible=False)
217
  clear_btn.add([lq_img, output, preview])
218
  restore_btn.click(
219
  instantir_restore, inputs=[
220
  lq_img, prompt, steps, cfg_scale, guidance_end,
221
- mode, seed, height, weight, preview_start,
222
  ],
223
  outputs=[output, pipe_out], api_name="InstantIR"
224
  )
@@ -236,17 +258,11 @@ with gr.Blocks() as demo:
236
  1. Check the `Creative Restoration` checkbox;
237
  2. Input your text prompts in the `Restoration prompts` textbox;
238
  3. Set `Start Free Rendering` slider to a medium value (around half of the `steps`) to provide adequate room for InstantIR creation.
239
-
240
- ## Examples
241
- Here are some examplar usage of InstantIR:
242
  """)
243
- # examples = gr.Gallery(label="Examples")
244
-
245
  gr.Markdown(
246
  """
247
  ## Citation
248
  If InstantIR is helpful to your work, please cite our paper via:
249
-
250
  ```
251
  @article{huang2024instantir,
252
  title={InstantIR: Blind Image Restoration with Instant Generative Reference},
@@ -257,4 +273,4 @@ with gr.Blocks() as demo:
257
  ```
258
  """)
259
 
260
- demo.queue().launch(debug=True)
 
1
  import os
2
+ import random
3
  import torch
4
  import spaces
5
 
 
8
  from PIL import Image
9
 
10
  from diffusers import DDPMScheduler
 
11
  from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
12
 
13
  from module.ip_adapter.utils import load_adapter_to_pipe
 
16
  from huggingface_hub import hf_hub_download
17
 
18
 
19
+ def resize_img(input_image, max_side=1024, min_side=768, width=None, height=None,
20
  pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
21
 
22
  w, h = input_image.size
23
+ # Prepare output size
24
+ if width is not None and height is not None:
25
+ out_w, out_h = width, height
26
+ elif width is not None:
27
+ out_w = width
28
+ out_h = round(h * width / w)
29
+ elif height is not None:
30
+ out_h = height
31
+ out_w = round(w * height / h)
32
  else:
33
+ out_w, out_h = w, h
34
+
35
+ # Resize input to runtime size
36
+ w, h = out_w, out_h
37
+ if min(w, h) < min_side:
38
+ ratio = min_side / min(w, h)
39
+ w, h = round(ratio * w), round(ratio * h)
40
+ if max(w, h) > max_side:
41
+ ratio = max_side / max(w, h)
42
+ w, h = round(ratio * w), round(ratio * h)
43
+ # Resize to cope with UNet and VAE operations
44
+ w_resize_new = (w // base_pixel_number) * base_pixel_number
45
+ h_resize_new = (h // base_pixel_number) * base_pixel_number
46
  input_image = input_image.resize([w_resize_new, h_resize_new], mode)
47
 
48
  if pad_to_max_side:
 
51
  offset_y = (max_side - h_resize_new) // 2
52
  res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
53
  input_image = Image.fromarray(res)
54
+ return input_image, (out_w, out_h)
55
 
56
 
57
  if not os.path.exists("models/adapter.pt"):
 
66
  dinov2_repo_id = "facebook/dinov2-large"
67
  lcm_repo_id = "latent-consistency/lcm-lora-sdxl"
68
 
69
+ torch_dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
 
 
 
70
 
71
  # Load pretrained models.
72
  print("Initializing pipeline...")
 
107
  pipe.aggregator.to(device=device, dtype=torch_dtype)
108
 
109
  MAX_SEED = np.iinfo(np.int32).max
110
+ MAX_IMAGE_SIZE = 1280
111
+ MIN_IMAGE_SIZE = 1024
112
 
113
  PROMPT = "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, \
114
  ultra HD, extreme meticulous detailing, skin pore detailing, \
 
120
  dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, \
121
  watermark, signature, jpeg artifacts, deformed, lowres"
122
 
123
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
124
+ if randomize_seed:
125
+ seed = random.randint(0, MAX_SEED)
126
+ return seed
127
+
128
  def unpack_pipe_out(preview_row, index):
129
  return preview_row[index][0]
130
 
131
  def dynamic_preview_slider(sampling_steps):
 
132
  return gr.Slider(label="Restoration Previews", value=sampling_steps-1, minimum=0, maximum=sampling_steps-1, step=1)
133
 
134
  def dynamic_guidance_slider(sampling_steps):
 
137
  def show_final_preview(preview_row):
138
  return preview_row[-1][0]
139
 
140
+ @spaces.GPU(duration=70)
141
  def instantir_restore(
142
  lq, prompt="", steps=30, cfg_scale=7.0, guidance_end=1.0,
143
+ creative_restoration=False, seed=3407, height=None, width=None, preview_start=0.0):
144
+ print(type(height), type(width))
145
+ print(height, width)
146
+ print(type(prompt))
147
+ print(prompt)
148
  if creative_restoration:
149
  if "lcm" not in pipe.unet.active_adapters():
150
  pipe.unet.set_adapter('lcm')
 
160
  preview_start = preview_start / steps
161
  elif preview_start > 1.0:
162
  preview_start = preview_start / steps
163
+
164
+ lq, out_size = [resize_img(lq, width=width, height=height)]
 
 
165
  generator = torch.Generator(device=device).manual_seed(seed)
166
  timesteps = [
167
  i * (1000//steps) + pipe.scheduler.config.steps_offset for i in range(0, steps)
 
185
  return_dict=False,
186
  save_preview_row=True,
187
  )
188
+ out[0][0] = out[0][0].resize(out_size[0], out_size[1], Image.BILINEAR)
189
+ for i, preview_tuple in enumerate(out[1]):
190
+ preview_tuple[0] = preview_tuple[0].resize(out_size[0], out_size[1], Image.BILINEAR)
191
+ preview_tuple.append(f"preview_{i}")
192
  return out[0][0], out[1]
193
 
194
  css="""
 
202
  gr.Markdown(
203
  """
204
  # InstantIR: Blind Image Restoration with Instant Generative Reference.
 
205
  ### **Official 🤗 Gradio demo of [InstantIR](https://arxiv.org/abs/2410.06551).**
206
  ### **InstantIR can not only help you restore your broken image, but also capable of imaginative re-creation following your text prompts. See advance usage for more details!**
207
  ## Basic usage: revitalize your image
 
210
  3. Click `InstantIR magic!`.
211
  """)
212
  with gr.Row():
 
213
  with gr.Column():
214
+ lq_img = gr.Image(label="Low-quality image", type="pil")
215
+ with gr.Row():
216
+ restore_btn = gr.Button("InstantIR magic!")
217
+ clear_btn = gr.ClearButton()
218
  with gr.Row():
219
  steps = gr.Number(label="Steps", value=30, step=1)
220
  cfg_scale = gr.Number(label="CFG Scale", value=7.0, step=0.1)
221
  with gr.Row():
222
+ height = gr.Number(label="Height", step=1, placeholder="Auto", maximum=MAX_IMAGE_SIZE)
223
+ width = gr.Number(label="Width", step=1, placeholder="Auto", maximum=MAX_IMAGE_SIZE)
224
  seed = gr.Number(label="Seed", value=42, step=1)
 
225
  guidance_end = gr.Slider(label="Start Free Rendering", value=30, minimum=0, maximum=30, step=1)
226
  preview_start = gr.Slider(label="Preview Start", value=0, minimum=0, maximum=30, step=1)
227
+ prompt = gr.Textbox(label="Restoration prompts (Optional)", placeholder="")
228
  mode = gr.Checkbox(label="Creative Restoration", value=False)
229
+ # gr.Examples(
230
+ # examples = ["assets/lady.png", "assets/man.png", "assets/dog.png", "assets/panda.png", "assets/sculpture.png", "assets/cottage.png", "assets/Naruto.png", "assets/Konan.png"],
231
+ # inputs = [lq_img]
232
+ # )
233
+ with gr.Column():
234
+ output = gr.Image(label="InstantIR restored", type="pil")
235
+ index = gr.Slider(label="Restoration Previews", value=29, minimum=0, maximum=29, step=1)
236
+ preview = gr.Image(label="Preview", type="pil")
237
+
238
  pipe_out = gr.Gallery(visible=False)
239
  clear_btn.add([lq_img, output, preview])
240
  restore_btn.click(
241
  instantir_restore, inputs=[
242
  lq_img, prompt, steps, cfg_scale, guidance_end,
243
+ mode, seed, height, width, preview_start,
244
  ],
245
  outputs=[output, pipe_out], api_name="InstantIR"
246
  )
 
258
  1. Check the `Creative Restoration` checkbox;
259
  2. Input your text prompts in the `Restoration prompts` textbox;
260
  3. Set `Start Free Rendering` slider to a medium value (around half of the `steps`) to provide adequate room for InstantIR creation.
 
 
 
261
  """)
 
 
262
  gr.Markdown(
263
  """
264
  ## Citation
265
  If InstantIR is helpful to your work, please cite our paper via:
 
266
  ```
267
  @article{huang2024instantir,
268
  title={InstantIR: Blind Image Restoration with Instant Generative Reference},
 
273
  ```
274
  """)
275
 
276
+ demo.queue().launch()