KevinQu7 committed on
Commit a8e6640
1 Parent(s): 40be8c2

update .gitattributes
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,29 @@
 ---
-title: Marigold Iid Private
-emoji: 🏢
-colorFrom: gray
-colorTo: yellow
+title: Marigold Intrinsic Image Decomposition
+emoji: 🏵️
+colorFrom: purple
+colorTo: blue
 sdk: gradio
-sdk_version: 5.9.1
+sdk_version: 4.21.0
 app_file: app.py
-pinned: false
+pinned: true
+license: cc-by-sa-4.0
+hf_oauth: true
+hf_oauth_expiration_minutes: 43200
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+This is a demo of Marigold-IID, the state-of-the-art intrinsic image decomposition model for images in the wild.
+We provide two models:
+- Marigold-IID-Appearance, which predicts albedo, metallic, and roughness
+- Marigold-IID-Lighting, which predicts albedo, diffuse shading, and a non-diffuse residual
+
+Find out more in our CVPR 2024 paper, ["Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation"](https://arxiv.org/abs/2312.02145):
+
+```
+@InProceedings{ke2023repurposing,
+    title={Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation},
+    author={Bingxin Ke and Anton Obukhov and Shengyu Huang and Nando Metzger and Rodrigo Caye Daudt and Konrad Schindler},
+    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+    year={2024}
+}
+```
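
For reference, a minimal sketch of how the two pipelines shipped in this commit are invoked (class name, checkpoint ID, and call arguments are taken from app.py below; the checkpoints may require an access token, which app.py reads from an environment variable):

```
# Minimal usage sketch, assuming the pipeline classes from this repo and the
# checkpoint names used in app.py; argument values mirror app.py's defaults.
import torch
from PIL import Image

from marigold_iid_appearance import MarigoldIIDAppearancePipeline

pipe = MarigoldIIDAppearancePipeline.from_pretrained(
    "prs-eth/marigold-iid-appearance-v1-1"  # token=... may be required
)
pipe = pipe.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

pipe_out = pipe(
    Image.open("files/image/cat.jpg"),
    denoising_steps=4,    # default_image_denoise_steps
    ensemble_size=1,      # default_image_ensemble_size
    processing_res=768,   # default_image_processing_res
)
pipe_out.albedo_colored.save("albedo.png")      # PIL.Image
pipe_out.material_colored.save("material.png")  # metallic & roughness visualization
```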
app.py CHANGED
@@ -14,7 +14,7 @@
 # --------------------------------------------------------------------------
 # If you find this code useful, we kindly ask you to cite our paper in your work.
 # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
-# More information about the method can be found at https://marigoldmonodepth.github.io
+# More information about the method can be found at https://marigoldcomputervision.github.io
 # --------------------------------------------------------------------------
 from __future__ import annotations
 
@@ -28,7 +28,6 @@ import gradio as gr
 import numpy as np
 import torch as torch
 from PIL import Image
-from diffusers import UNet2DConditionModel
 
 from gradio_imageslider import ImageSlider
 from huggingface_hub import login
@@ -36,7 +35,7 @@ from huggingface_hub import login
 from gradio_patches.examples import Examples
 from gradio_patches.flagging import HuggingFaceDatasetSaver, FlagMethod
 from marigold_iid_appearance import MarigoldIIDAppearancePipeline
-from marigold_iid_residual import MarigoldIIDResidualPipeline
+from marigold_iid_lighting import MarigoldIIDLightingPipeline
 
 warnings.filterwarnings(
     "ignore", message=".*LoginButton created outside of a Blocks context.*"
@@ -48,36 +47,53 @@ default_image_denoise_steps = 4
 default_image_ensemble_size = 1
 default_image_processing_res = 768
 default_image_reproducuble = True
-default_model_type="appearance"
+default_model_type = "appearance"
 
 default_share_always_show_hf_logout_btn = True
 default_share_always_show_accordion = False
 
 loaded_pipelines = {}  # Cache to store loaded pipelines
-def process_with_loaded_pipeline(image_path, denoise_steps, ensemble_size, processing_res, model_type):
+
+
+def process_with_loaded_pipeline(
+    image_path,
+    model_type=default_model_type,
+    denoise_steps=default_image_denoise_steps,
+    ensemble_size=default_image_ensemble_size,
+    processing_res=default_image_processing_res,
+):
     # Load and cache the pipeline based on the model type.
     if model_type not in loaded_pipelines.keys():
-        auth_token = os.environ.get("KEV_DEV")
+        auth_token = os.environ.get("KEV_TOKEN")
         if model_type == "appearance":
-            loaded_pipelines[model_type] = MarigoldIIDAppearancePipeline.from_pretrained(
-                "prs-eth/marigold-iid-appearance-v1-1", token=auth_token
+            if "lighting" in loaded_pipelines.keys():
+                del loaded_pipelines[
+                    "lighting"
+                ]  # to save GPU memory. Can be removed if enough memory is available for faster switching between models
+                torch.cuda.empty_cache()
+            loaded_pipelines[model_type] = (
+                MarigoldIIDAppearancePipeline.from_pretrained(
+                    "prs-eth/marigold-iid-appearance-v1-1", token=auth_token
+                )
             )
-        elif model_type == "residual":
-            loaded_pipelines[model_type] = MarigoldIIDResidualPipeline.from_pretrained(
-                "prs-eth/marigold-iid-residual-v1-1", token=auth_token
+        elif model_type == "lighting":
+            if "appearance" in loaded_pipelines.keys():
+                del loaded_pipelines[
+                    "appearance"
+                ]  # to save GPU memory. Can be removed if enough memory is available for faster switching between models
+                torch.cuda.empty_cache()
+            loaded_pipelines[model_type] = MarigoldIIDLightingPipeline.from_pretrained(
+                "prs-eth/marigold-iid-lighting-v1-1", token=auth_token
             )
-
+
         # Move the pipeline to GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         loaded_pipelines[model_type] = loaded_pipelines[model_type].to(device)
         try:
-            import xformers
-
             loaded_pipelines[model_type].enable_xformers_memory_efficient_attention()
         except:
             pass  # run without xformers
-
+
     pipe = loaded_pipelines[model_type]
 
     # Process the image using the preloaded pipeline.
@@ -90,12 +106,14 @@ def process_with_loaded_pipeline(image_path, denoise_steps, ensemble_size, proce
         model_type=model_type,
     )
 
+
 def process_image_check(path_input):
     if path_input is None:
         raise gr.Error(
             "Missing image in the first pane: upload a file or use one from the gallery below."
         )
 
+
 def process_image(
     pipe,
     path_input,
@@ -111,73 +129,108 @@ def process_image(
 
     input_image = Image.open(path_input)
 
-
     pipe_out = pipe(
         input_image,
         denoising_steps=denoise_steps,
         ensemble_size=ensemble_size,
         processing_res=processing_res,
-        batch_size=1 if processing_res == 0 else 0,  # TODO: do we abuse "batch size" notation here?
+        batch_size=1
+        if processing_res == 0
+        else 0,  # TODO: do we abuse "batch size" notation here?
         seed=default_seed,
         show_progress_bar=True,
     )
-
+
     path_output_dir = os.path.splitext(path_input)[0] + "_output"
     os.makedirs(path_output_dir, exist_ok=True)
-
-    path_albedo_out = os.path.join(path_output_dir, f"{name_base}_albedo_fp32.npy")
-    path_albedo_out_vis = os.path.join(path_output_dir, f"{name_base}_albedo.png")
-
-    albedo = pipe_out.albedo
-    albedo_colored = pipe_out.albedo_colored
-
-    np.save(path_albedo_out, albedo)
-    albedo_colored.save(path_albedo_out_vis)
-
-
+
     if model_type == "appearance":
-        path_material_out = os.path.join(path_output_dir, f"{name_base}_material_fp32.npy")
-        path_material_out_vis = os.path.join(path_output_dir, f"{name_base}_material.png")
-
+        path_albedo_out = os.path.join(
+            path_output_dir, f"{name_base}_albedo_app_fp32.npy"
+        )
+        path_albedo_out_vis = os.path.join(
+            path_output_dir, f"{name_base}_albedo_app.png"
+        )
+        path_material_out = os.path.join(
+            path_output_dir, f"{name_base}_material_fp32.npy"
+        )
+        path_material_out_vis = os.path.join(
+            path_output_dir, f"{name_base}_material.png"
+        )
+
+        albedo = pipe_out.albedo
+        albedo_colored = pipe_out.albedo_colored
         material = pipe_out.material
         material_colored = pipe_out.material_colored
-
+
+        np.save(path_albedo_out, albedo)
+        albedo_colored.save(path_albedo_out_vis)
         np.save(path_material_out, material)
         material_colored.save(path_material_out_vis)
-
+
         return (
             [path_input, path_albedo_out_vis],
             [path_input, path_material_out_vis],
-            None,
-            [path_albedo_out_vis, path_material_out_vis, path_albedo_out, path_material_out],
+            [path_input, path_material_out_vis],  # placeholder which is not displayed
+            [
+                path_albedo_out_vis,
+                path_material_out_vis,
+                path_albedo_out,
+                path_material_out,
+            ],
         )
-
-    elif model_type == "residual":
-        path_shading_out = os.path.join(path_output_dir, f"{name_base}_shading_fp32.npy")
+
+    elif model_type == "lighting":
+        path_albedo_out = os.path.join(
+            path_output_dir, f"{name_base}_albedo_res_fp32.npy"
+        )
+        path_albedo_out_vis = os.path.join(
+            path_output_dir, f"{name_base}_albedo_res.png"
+        )
+        path_shading_out = os.path.join(
+            path_output_dir, f"{name_base}_shading_fp32.npy"
+        )
         path_shading_out_vis = os.path.join(path_output_dir, f"{name_base}_shading.png")
-        path_residual_out = os.path.join(path_output_dir, f"{name_base}_residual_fp32.npy")
-        path_residual_out_vis = os.path.join(path_output_dir, f"{name_base}_residual.png")
-
+        path_residual_out = os.path.join(
+            path_output_dir, f"{name_base}_residual_fp32.npy"
+        )
+        path_residual_out_vis = os.path.join(
+            path_output_dir, f"{name_base}_residual.png"
+        )
+
+        albedo = pipe_out.albedo
+        albedo_colored = pipe_out.albedo_colored
         shading = pipe_out.shading
         shading_colored = pipe_out.shading_colored
         residual = pipe_out.residual
         residual_colored = pipe_out.residual_colored
-
+
+        np.save(path_albedo_out, albedo)
+        albedo_colored.save(path_albedo_out_vis)
         np.save(path_shading_out, shading)
         shading_colored.save(path_shading_out_vis)
         np.save(path_residual_out, residual)
         residual_colored.save(path_residual_out_vis)
-
+
         return (
             [path_input, path_albedo_out_vis],
             [path_input, path_shading_out_vis],
             [path_input, path_residual_out_vis],
-            [path_albedo_out_vis, path_shading_out_vis, path_residual_out_vis, path_albedo_out, path_shading_out, path_residual_out],
+            [
+                path_albedo_out_vis,
+                path_shading_out_vis,
+                path_residual_out_vis,
+                path_albedo_out,
+                path_shading_out,
+                path_residual_out,
+            ],
         )
 
 
 def run_demo_server(hf_writer=None):
-    process_pipe_image = spaces.GPU(functools.partial(process_with_loaded_pipeline), duration=120)
+    process_pipe_image = spaces.GPU(
+        functools.partial(process_with_loaded_pipeline), duration=120
+    )
     gradio_theme = gr.themes.Default()
 
     with gr.Blocks(
@@ -233,7 +286,7 @@ def run_demo_server(hf_writer=None):
 
         gr.Markdown(
             """
-            # Marigold Normals Estimation
+            # Marigold Intrinsic Image Decomposition (IID)
 
            <p align="center">
                 <a title="Website" href="https://marigoldcomputervision.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
@@ -280,25 +333,25 @@ def run_demo_server(hf_writer=None):
                 )
                 model_type = gr.Radio(
                     [
-                        ("Appearance (Albedo & Material)", "appearance"),
-                        ("Residual (Albedo, Shading & Residual)", "residual"),
+                        ("Appearance (albedo & material)", "appearance"),
+                        ("Lighting (albedo, shading & residual)", "lighting"),
                     ],
-                    label="Model Type",
+                    label="Model type: Marigold-IID-Appearance or Marigold-IID-Lighting",
                     value=default_model_type,
                 )
-
+
                 with gr.Accordion("Advanced options", open=True):
                     image_ensemble_size = gr.Slider(
                         label="Ensemble size",
                         minimum=1,
-                        maximum=10,
+                        maximum=5,
                         step=1,
                         value=default_image_ensemble_size,
                     )
                     image_denoise_steps = gr.Slider(
                         label="Number of denoising steps",
                         minimum=1,
-                        maximum=20,
+                        maximum=10,
                         step=1,
                         value=default_image_denoise_steps,
                     )
@@ -311,7 +364,7 @@ def run_demo_server(hf_writer=None):
                         value=default_image_processing_res,
                     )
                 with gr.Row():
-                    image_submit_btn = gr.Button(value="Compute Normals", variant="primary")
+                    image_submit_btn = gr.Button(value="Compute IID", variant="primary")
                     image_reset_btn = gr.Button(value="Reset")
             with gr.Column():
                 image_output_slider1 = ImageSlider(
@@ -322,7 +375,7 @@ def run_demo_server(hf_writer=None):
                     interactive=False,
                     elem_classes="slider",
                     position=0.25,
-                    visible=True
+                    visible=True,
                 )
                 image_output_slider2 = ImageSlider(
                     label="Predicted Material",
@@ -332,7 +385,7 @@ def run_demo_server(hf_writer=None):
                     interactive=False,
                     elem_classes="slider",
                     position=0.25,
-                    visible=True
+                    visible=True,
                 )
                 image_output_slider3 = ImageSlider(
                     label="Predicted Residual",
@@ -342,7 +395,7 @@ def run_demo_server(hf_writer=None):
                     interactive=False,
                     elem_classes="slider",
                     position=0.25,
-                    visible=False
+                    visible=False,
                 )
                 image_output_files = gr.Files(
                     label="Output files",
@@ -352,9 +405,9 @@ def run_demo_server(hf_writer=None):
 
         if hf_writer is not None:
             with gr.Accordion(
-                    "Feedback",
-                    open=False,
-                    visible=default_share_always_show_accordion,
+                "Feedback",
+                open=False,
+                visible=default_share_always_show_accordion,
             ) as share_box:
                 share_instructions = gr.Markdown(
                     get_share_instructions(is_full=True),
@@ -362,16 +415,16 @@ def run_demo_server(hf_writer=None):
                 )
                 share_transfer_of_rights = gr.Checkbox(
                     label="(Optional) I own or hold necessary rights to the submitted image. By "
-                          "checking this box, I grant an irrevocable, non-exclusive, transferable, "
-                          "royalty-free, worldwide license to use the uploaded image, including for "
-                          "publishing, reproducing, and model training. [transfer_of_rights]",
+                    "checking this box, I grant an irrevocable, non-exclusive, transferable, "
+                    "royalty-free, worldwide license to use the uploaded image, including for "
+                    "publishing, reproducing, and model training. [transfer_of_rights]",
                     scale=1,
                 )
                 share_content_is_legal = gr.Checkbox(
                     label="By checking this box, I acknowledge that my uploaded content is legal and "
-                          "safe, and that I am solely responsible for ensuring it complies with all "
-                          "applicable laws and regulations. Additionally, I am aware that my Hugging Face "
-                          "username is collected. [content_is_legal]",
+                    "safe, and that I am solely responsible for ensuring it complies with all "
+                    "applicable laws and regulations. Additionally, I am aware that my Hugging Face "
+                    "username is collected. [content_is_legal]",
                     scale=1,
                 )
                 share_reason = gr.Textbox(
@@ -384,7 +437,7 @@ def run_demo_server(hf_writer=None):
                 share_share_btn = gr.Button(
                     "Share", variant="stop", scale=1
                 )
-
+
         # Function to toggle visibility and set dynamic labels
         def toggle_sliders_and_labels(model_type):
             if model_type == "appearance":
@@ -393,7 +446,7 @@ def run_demo_server(hf_writer=None):
                     gr.update(visible=True, label="Predicted Material"),
                     gr.update(visible=False),  # Hide third slider
                 )
-            elif model_type == "residual":
+            elif model_type == "lighting":
                 return (
                     gr.update(visible=True, label="Predicted Albedo"),
                     gr.update(visible=True, label="Predicted Shading"),
@@ -407,36 +460,35 @@ def run_demo_server(hf_writer=None):
             outputs=[image_output_slider1, image_output_slider2, image_output_slider3],
             show_progress=False,
         )
-
+
         Examples(
             fn=process_pipe_image,
             examples=[
-                os.path.join("files", "image", name)
+                [os.path.join("files", "image", name), _model_type]
                 for name in [
-                    "berries.jpeg",
+                    "livingroom.jpg",
+                    "books.jpg",
+                    "food_counter.png",
+                    "cat2.png",
                     "costumes.png",
+                    "icecream.jpg",
+                    "juices.jpeg",
                     "cat.jpg",
-                    "einstein.jpg",
                     "food.jpeg",
-                    "food_counter.png",
                     "puzzle.jpeg",
-                    "rocket.png",
-                    "scientists.jpg",
-                    "cat2.png",
                     "screw.png",
-                    "statues.png",
-                    "swings.jpg"
                 ]
+                for _model_type in ["appearance", "lighting"]
             ],
-            inputs=[image_input],
-            outputs= [
+            inputs=[image_input, model_type],
+            outputs=[
                 image_output_slider1,
                 image_output_slider2,
                 image_output_slider3,
-                image_output_files
+                image_output_files,
             ],
-            cache_examples=False,  # TODO: toggle later
-            directory_name="examples_image",
+            cache_examples=True,  # TODO: toggle later
+            directory_name="examples_images",
         )
 
     ### Image tab
@@ -474,17 +526,17 @@ def run_demo_server(hf_writer=None):
             fn=process_pipe_image,
             inputs=[
                 image_input,
+                model_type,
                 image_denoise_steps,
                 image_ensemble_size,
                 image_processing_res,
-                model_type
             ],
-            outputs= [
-                image_output_slider1,
-                image_output_slider2,
-                image_output_slider3,
-                image_output_files
-            ],
+            outputs=[
+                image_output_slider1,
+                image_output_slider2,
+                image_output_slider3,
+                image_output_files,
+            ],
             concurrency_limit=1,
         )
     else:
@@ -498,17 +550,17 @@ def run_demo_server(hf_writer=None):
             fn=process_pipe_image,
             inputs=[
                 image_input,
+                model_type,
                 image_denoise_steps,
                 image_ensemble_size,
                 image_processing_res,
-                model_type
             ],
-            outputs= [
-                image_output_slider1,
-                image_output_slider2,
-                image_output_slider3,
-                image_output_files
-            ],
+            outputs=[
+                image_output_slider1,
+                image_output_slider2,
+                image_output_slider3,
+                image_output_files,
+            ],
             concurrency_limit=1,
         )
 
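
The core of this change is the single-slot pipeline cache above: only one of the two models is kept resident, and the other is evicted before loading. The pattern in isolation (a sketch, not the demo's exact code; `load_fn` is a hypothetical stand-in for the two `from_pretrained` calls):

```
# Sketch of app.py's load-on-demand cache with eviction: at most one pipeline
# stays in memory; switching model types frees the other's GPU memory first.
import torch

loaded_pipelines = {}

def get_pipeline(model_type, load_fn):
    if model_type not in loaded_pipelines:
        for other in [k for k in loaded_pipelines if k != model_type]:
            del loaded_pipelines[other]   # evict the resident model
        torch.cuda.empty_cache()          # return freed blocks to the driver
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        loaded_pipelines[model_type] = load_fn(model_type).to(device)
    return loaded_pipelines[model_type]
```

The trade-off, as the in-code comment notes, is slower switching between models in exchange for a smaller peak GPU footprint.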
 
files/image/berries.jpeg ADDED

Git LFS Details

  • SHA256: dac1411ea48cf83b7a59c6424032f95b2ff496b3a98cdccf168bbed1c8f0aed4
  • Pointer size: 131 Bytes
  • Size of remote file: 940 kB
files/image/books.jpg ADDED

Git LFS Details

  • SHA256: 1d2648160e85956a5fb6e04b78c241be9662cdb7388bdf92af1b2d72af4506d1
  • Pointer size: 131 Bytes
  • Size of remote file: 743 kB
files/image/cat.jpg ADDED

Git LFS Details

  • SHA256: 794796a86e56a4b372287661dc934daa2d15e988d01afe88afc50b32644c007a
  • Pointer size: 131 Bytes
  • Size of remote file: 236 kB
files/image/cat2.png ADDED

Git LFS Details

  • SHA256: 04a24d72cf9599348d2e3e31e08684ea8a18fcec1e05c3e287e8678f8745fc9e
  • Pointer size: 131 Bytes
  • Size of remote file: 758 kB
files/image/costumes.png ADDED

Git LFS Details

  • SHA256: fc3197481cf925cc02a662dff6d7f8395223e43c249ca9c0b823e3dbc97adf55
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
files/image/einstein.jpg ADDED

Git LFS Details

  • SHA256: d4a4543c0fffb2ca5ea3c17e23e88fcfcf66eae8b487173fbc5c25d0d614bdb6
  • Pointer size: 131 Bytes
  • Size of remote file: 367 kB
files/image/food.jpeg ADDED

Git LFS Details

  • SHA256: a26151050a574b0dc0014e9c4806da3d6f6bc1297ee1035a16b9ace007a179af
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
files/image/food_counter.png ADDED

Git LFS Details

  • SHA256: 1ba51cd83534e42203c463614b2ea62a0b6ab39202042175714ea45e6e2061e6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.73 MB
files/image/icecream.jpg ADDED

Git LFS Details

  • SHA256: 1d7d0865b532267a62f9a3ecd67afec5246d4839242c1ef5717f53747b104f0b
  • Pointer size: 132 Bytes
  • Size of remote file: 2.74 MB
files/image/juices.jpeg ADDED

Git LFS Details

  • SHA256: 906c561aadaffd78ae2aa3b5d8aaf6986e8d890a5ed1ed4a26329f364ab60c97
  • Pointer size: 132 Bytes
  • Size of remote file: 1.29 MB
files/image/livingroom.jpg ADDED

Git LFS Details

  • SHA256: fd05910b4c9aa60af1e05c0985a3ecf7685662f1145eed972f14782a89a05e1d
  • Pointer size: 131 Bytes
  • Size of remote file: 815 kB
files/image/puzzle.jpeg ADDED

Git LFS Details

  • SHA256: 60b66432124a0936c6143301a9f9b793af4184bc9340c567d11fdd5a22cc98cc
  • Pointer size: 131 Bytes
  • Size of remote file: 374 kB
files/image/rocket.png ADDED

Git LFS Details

  • SHA256: 27faa0f9263fbdf13e57a2e4ee70211dae5afba8f919763f9fe3afb8c82ae627
  • Pointer size: 131 Bytes
  • Size of remote file: 620 kB
files/image/scientists.jpg ADDED

Git LFS Details

  • SHA256: 7b164dfbc4ab6e491ce81972b8c0e076fdc4af622289d0aa3cb43ee3c2be4030
  • Pointer size: 131 Bytes
  • Size of remote file: 444 kB
files/image/screw.png ADDED

Git LFS Details

  • SHA256: 550ac366acdbd07376c8215d7e09e621598639abb78fdcdaf85b1bb87e6786e4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
files/image/statues.png ADDED

Git LFS Details

  • SHA256: 143ded9acabd996f91f11c2fcf7bf7c240552551ef4e66308a49c225f1d81fec
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
files/image/swings.jpg ADDED

Git LFS Details

  • SHA256: cae2ac669c948313eae8aca53017f10b64b42f87c53b9c34639962b218fdf1f1
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
files/image/tabletennis.jpg ADDED

Git LFS Details

  • SHA256: cd0d95612636e9ee7e431246480314f873ee1a431c572886100da42bcda72ed2
  • Pointer size: 131 Bytes
  • Size of remote file: 695 kB
files/image/tent.jpg ADDED

Git LFS Details

  • SHA256: 3d0869e11523dfa405afa134078b344195c89c8a6195ad663d393570e8e6d405
  • Pointer size: 132 Bytes
  • Size of remote file: 1.87 MB
gradio_patches/examples.py ADDED
@@ -0,0 +1,13 @@
+from pathlib import Path
+
+import gradio
+from gradio.utils import get_cache_folder
+
+
+class Examples(gradio.helpers.Examples):
+    def __init__(self, *args, directory_name=None, **kwargs):
+        super().__init__(*args, **kwargs, _initiated_directly=False)
+        if directory_name is not None:
+            self.cached_folder = get_cache_folder() / directory_name
+            self.cached_file = Path(self.cached_folder) / "log.csv"
+        self.create()
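
This patch subclasses `gradio.helpers.Examples` only to redirect the cache folder; everything else is stock Gradio. Usage mirrors app.py (the echo demo below is hypothetical):

```
# Hypothetical minimal demo of the patched Examples class: identical to
# gradio.helpers.Examples except for the extra directory_name keyword,
# which picks the subfolder under Gradio's cache root.
import gradio as gr
from gradio_patches.examples import Examples

def echo(text):
    return text

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    Examples(
        fn=echo,
        examples=[["hello"]],
        inputs=[inp],
        outputs=[out],
        cache_examples=True,
        directory_name="examples_text",  # cache lands in <cache>/examples_text
    )
```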
gradio_patches/flagging.py ADDED
@@ -0,0 +1,165 @@
+from __future__ import annotations
+
+import datetime
+import json
+import time
+import uuid
+from collections import OrderedDict
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+
+import gradio
+import gradio as gr
+import huggingface_hub
+from gradio import FlaggingCallback
+from gradio_client import utils as client_utils
+
+
+class HuggingFaceDatasetSaver(gradio.HuggingFaceDatasetSaver):
+    def flag(
+        self,
+        flag_data: list[Any],
+        flag_option: str = "",
+        username: str | None = None,
+    ) -> int:
+        if self.separate_dirs:
+            # JSONL files to support dataset preview on the Hub
+            current_utc_time = datetime.now(timezone.utc)
+            iso_format_without_microseconds = current_utc_time.strftime(
+                "%Y-%m-%dT%H:%M:%S"
+            )
+            milliseconds = int(current_utc_time.microsecond / 1000)
+            unique_id = f"{iso_format_without_microseconds}.{milliseconds:03}Z"
+            if username not in (None, ""):
+                unique_id += f"_U_{username}"
+            else:
+                unique_id += f"_{str(uuid.uuid4())[:8]}"
+            components_dir = self.dataset_dir / unique_id
+            data_file = components_dir / "metadata.jsonl"
+            path_in_repo = unique_id  # upload in sub folder (safer for concurrency)
+        else:
+            # Unique CSV file
+            components_dir = self.dataset_dir
+            data_file = components_dir / "data.csv"
+            path_in_repo = None  # upload at root level
+
+        return self._flag_in_dir(
+            data_file=data_file,
+            components_dir=components_dir,
+            path_in_repo=path_in_repo,
+            flag_data=flag_data,
+            flag_option=flag_option,
+            username=username or "",
+        )
+
+    def _deserialize_components(
+        self,
+        data_dir: Path,
+        flag_data: list[Any],
+        flag_option: str = "",
+        username: str = "",
+    ) -> tuple[dict[Any, Any], list[Any]]:
+        """Deserialize components and return the corresponding row for the flagged sample.
+
+        Images/audio are saved to disk as individual files.
+        """
+        # Components that can have a preview on dataset repos
+        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
+
+        # Generate the row corresponding to the flagged sample
+        features = OrderedDict()
+        row = []
+        for component, sample in zip(self.components, flag_data):
+            # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
+            label = component.label or ""
+            save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
+            save_dir.mkdir(exist_ok=True, parents=True)
+            deserialized = component.flag(sample, save_dir)
+
+            # Base component .flag method returns JSON; extract path from it when it is FileData
+            if component.data_model:
+                data = component.data_model.from_json(json.loads(deserialized))
+                if component.data_model == gr.data_classes.FileData:
+                    deserialized = data.path
+
+            # Add deserialized object to row
+            features[label] = {"dtype": "string", "_type": "Value"}
+            try:
+                deserialized_path = Path(deserialized)
+                if not deserialized_path.exists():
+                    raise FileNotFoundError(f"File {deserialized} not found")
+                row.append(str(deserialized_path.relative_to(self.dataset_dir)))
+            except (FileNotFoundError, TypeError, ValueError):
+                deserialized = "" if deserialized is None else str(deserialized)
+                row.append(deserialized)
+
+            # If component is eligible for a preview, add the URL of the file
+            # Be mindful that images and audio can be None
+            if isinstance(component, tuple(file_preview_types)):  # type: ignore
+                for _component, _type in file_preview_types.items():
+                    if isinstance(component, _component):
+                        features[label + " file"] = {"_type": _type}
+                        break
+                if deserialized:
+                    path_in_repo = str(  # returned filepath is absolute, we want it relative to compute URL
+                        Path(deserialized).relative_to(self.dataset_dir)
+                    ).replace("\\", "/")
+                    row.append(
+                        huggingface_hub.hf_hub_url(
+                            repo_id=self.dataset_id,
+                            filename=path_in_repo,
+                            repo_type="dataset",
+                        )
+                    )
+                else:
+                    row.append("")
+        features["flag"] = {"dtype": "string", "_type": "Value"}
+        features["username"] = {"dtype": "string", "_type": "Value"}
+        row.append(flag_option)
+        row.append(username)
+        return features, row
+
+
+class FlagMethod:
+    """
+    Helper class that contains the flagging options and calls the flagging method. Also
+    provides visual feedback to the user when flag is clicked.
+    """
+
+    def __init__(
+        self,
+        flagging_callback: FlaggingCallback,
+        label: str,
+        value: str,
+        visual_feedback: bool = True,
+    ):
+        self.flagging_callback = flagging_callback
+        self.label = label
+        self.value = value
+        self.__name__ = "Flag"
+        self.visual_feedback = visual_feedback
+
+    def __call__(
+        self,
+        request: gr.Request,
+        profile: gr.OAuthProfile | None,
+        *flag_data,
+    ):
+        username = None
+        if profile is not None:
+            username = profile.username
+        try:
+            self.flagging_callback.flag(
+                list(flag_data), flag_option=self.value, username=username
+            )
+        except Exception as e:
+            print(f"Error while sharing: {e}")
+            if self.visual_feedback:
+                return gr.Button(value="Sharing error", interactive=False)
+        if not self.visual_feedback:
+            return
+        time.sleep(0.8)  # to provide enough time for the user to observe button change
+        return gr.Button(value="Sharing complete", interactive=False)
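
app.py receives this saver as `hf_writer`; the wiring is not part of this commit, but it presumably follows Gradio's standard `HuggingFaceDatasetSaver` constructor, which this class subclasses (a sketch; the token variable and dataset name below are hypothetical):

```
# Hypothetical wiring of the patched saver; constructor arguments follow
# gradio.HuggingFaceDatasetSaver.
import os

from gradio_patches.flagging import HuggingFaceDatasetSaver

hf_writer = HuggingFaceDatasetSaver(
    hf_token=os.environ["HF_TOKEN"],           # hypothetical env var
    dataset_name="user/marigold-iid-flagged",  # hypothetical dataset
    private=True,
    separate_dirs=True,  # one subfolder + metadata.jsonl per flagged sample
)
```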
marigold_iid_appearance.py CHANGED
@@ -278,12 +278,12 @@ class MarigoldIIDAppearancePipeline(DiffusionPipeline):
         )
 
         albedo_colored = (albedo + 1.0) * 0.5
-        albedo_colored = (albedo_colored * 255).to(np.uint8)
+        albedo_colored = (albedo_colored * 255).astype(np.uint8)
         albedo_colored = self.chw2hwc(albedo_colored)
         albedo_colored_img = Image.fromarray(albedo_colored)
 
         material_colored = (material + 1.0) * 0.5
-        material_colored = (material_colored * 255).to(np.uint8)
+        material_colored = (material_colored * 255).astype(np.uint8)
         material_colored = self.chw2hwc(material_colored)
         material_colored_img = Image.fromarray(material_colored)
 
@@ -436,7 +436,7 @@ class MarigoldIIDAppearancePipeline(DiffusionPipeline):
         assert target_latents.shape[1] == 8  # self.n_targets * 4
 
         # scale latent
-        target_latents = target_latents / self.rgb_latent_scale_factor
+        target_latents = target_latents / self.latent_scale_factor
         # decode
         targets = []
         for i in range(self.n_targets):
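
The `.to(np.uint8)` → `.astype(np.uint8)` changes fix a real bug rather than style: by this point the arrays are NumPy arrays, and `.to` is a `torch.Tensor` method that `numpy.ndarray` does not have. A minimal repro of the distinction:

```
import numpy as np
import torch

t = torch.rand(3)
t = t.to(torch.uint8)  # fine: .to() exists on torch.Tensor

a = np.random.rand(3)
# a.to(np.uint8)       # AttributeError: 'numpy.ndarray' object has no attribute 'to'
a = (a * 255).astype(np.uint8)  # correct NumPy spelling
```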
marigold_iid_residual.py → marigold_iid_lighting.py RENAMED
@@ -38,9 +38,9 @@ from transformers import CLIPTextModel, CLIPTokenizer
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.27.0.dev0")
 
-class MarigoldIIDResidualOutput(BaseOutput):
+class MarigoldIIDLightingOutput(BaseOutput):
     """
-    Output class for Marigold IID Residual pipeline.
+    Output class for Marigold-IID-Lighting pipeline.
 
     Args:
         albedo (`np.ndarray`):
@@ -65,7 +65,7 @@ class MarigoldIIDResidualOutput(BaseOutput):
     residual: np.ndarray
     residual_colored: Image.Image
 
-class MarigoldIIDResidualPipeline(DiffusionPipeline):
+class MarigoldIIDLightingPipeline(DiffusionPipeline):
     """
     Pipeline for Intrinsic Image Decomposition (Albedo, diffuse shading and non-diffuse residual) using Marigold: https://marigoldcomputervision.github.io.
 
@@ -124,7 +124,7 @@ class MarigoldIIDResidualPipeline(DiffusionPipeline):
         color_map: str = "Spectral",  # TODO change colorization api based on modality
         show_progress_bar: bool = True,
         **kwargs,
-    ) -> MarigoldIIDResidualOutput:
+    ) -> MarigoldIIDLightingOutput:
         """
         Function invoked when calling the pipeline.
 
@@ -155,7 +155,7 @@ class MarigoldIIDResidualPipeline(DiffusionPipeline):
         show_progress_bar (`bool`, *optional*, defaults to `True`):
             Display a progress bar of diffusion denoising.
         Returns:
-            `MarigoldIIDResidualOutput`: Output class for Marigold monocular intrinsic image decomposition (Residual) prediction pipeline, including:
+            `MarigoldIIDLightingOutput`: Output class for Marigold monocular intrinsic image decomposition (lighting) prediction pipeline, including:
             - **albedo** (`np.ndarray`) Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1]
             - **albedo_colored** (`PIL.Image.Image`) Colorized albedo map with the shape of [3, H, W] values in the range of [0, 1]
             - **material** (`np.ndarray`) Predicted material map with the shape of [3, H, W] and values in [0, 1]
@@ -276,24 +276,25 @@ class MarigoldIIDResidualPipeline(DiffusionPipeline):
         shading = final_pred[3:6, :, :]
         residual = final_pred[6:, :, :]
 
-        albedo_colored = (albedo + 1.0) * 0.5
-        albedo_colored = (albedo_colored * 255).to(np.uint8)
+        albedo_colored = (albedo + 1.0) * 0.5  # [-1,1] -> [0,1]
+        albedo_colored = albedo_colored ** (1 / 2.2)  # from linear to sRGB (to be consistent with IID-Appearance model)
+        albedo_colored = (albedo_colored * 255).astype(np.uint8)
         albedo_colored = self.chw2hwc(albedo_colored)
         albedo_colored_img = Image.fromarray(albedo_colored)
 
         shading_colored = (shading + 1.0) * 0.5
         shading_colored = shading_colored / shading_colored.max()  # rescale for better visualization
-        shading_colored = (shading_colored * 255).to(np.uint8)
+        shading_colored = (shading_colored * 255).astype(np.uint8)
         shading_colored = self.chw2hwc(shading_colored)
         shading_colored_img = Image.fromarray(shading_colored)
 
         residual_colored = (residual + 1.0) * 0.5
         residual_colored = residual_colored / residual_colored.max()  # rescale for better visualization
-        residual_colored = (residual_colored * 255).to(np.uint8)
+        residual_colored = (residual_colored * 255).astype(np.uint8)
         residual_colored = self.chw2hwc(residual_colored)
         residual_colored_img = Image.fromarray(residual_colored)
 
-        out = MarigoldIIDResidualOutput(
+        out = MarigoldIIDLightingOutput(
             albedo=albedo,
             albedo_colored=albedo_colored_img,
             shading=shading,
@@ -444,7 +445,7 @@ class MarigoldIIDResidualPipeline(DiffusionPipeline):
         assert target_latents.shape[1] == 12  # self.n_targets * 4
 
         # scale latent
-        target_latents = target_latents / self.rgb_latent_scale_factor
+        target_latents = target_latents / self.latent_scale_factor
         # decode
         targets = []
         for i in range(self.n_targets):
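
Besides the rename, the lighting pipeline gains a gamma step so its displayed albedo matches the appearance model's sRGB output. A sketch of that visualization path on a dummy array (assuming `chw2hwc` is a channel-last transpose, which is what the surrounding code implies):

```
import numpy as np
from PIL import Image

albedo = np.random.uniform(-1.0, 1.0, (3, 64, 64)).astype(np.float32)

vis = (albedo + 1.0) * 0.5          # [-1, 1] -> [0, 1]
vis = vis ** (1 / 2.2)              # linear -> approximate sRGB
vis = (vis * 255).astype(np.uint8)
vis = np.moveaxis(vis, 0, -1)       # CHW -> HWC, as chw2hwc presumably does
Image.fromarray(vis).save("albedo_vis.png")
```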