Huiwenshi commited on
Commit
435927f
·
verified ·
1 Parent(s): 0694d37

Delete folder mvd/.ipynb_checkpoints with huggingface_hub

Browse files
mvd/.ipynb_checkpoints/hunyuan3d_mvd_lite_pipeline-checkpoint.py DELETED
@@ -1,392 +0,0 @@
1
- # Open Source Model Licensed under the Apache License Version 2.0
2
- # and Other Licenses of the Third-Party Components therein:
3
- # The below Model in this distribution may have been modified by THL A29 Limited
4
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
-
6
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
- # The below software and/or models in this distribution may have been
8
- # modified by THL A29 Limited ("Tencent Modifications").
9
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
-
11
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
- # except for the third-party components listed below.
13
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
- # in the repsective licenses of these third-party components.
15
- # Users must comply with all terms and conditions of original licenses of these third-party
16
- # components and must ensure that the usage of the third party components adheres to
17
- # all relevant laws and regulations.
18
-
19
- # For avoidance of doubts, Hunyuan 3D means the large language models and
20
- # their software and algorithms, including trained model weights, parameters (including
21
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
- # fine-tuning enabling code and other elements of the foregoing made publicly available
23
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
-
25
- import math
26
- import numpy
27
- import torch
28
- import inspect
29
- import warnings
30
- from PIL import Image
31
- from einops import rearrange
32
- import torch.nn.functional as F
33
- from diffusers.utils.torch_utils import randn_tensor
34
- from diffusers.configuration_utils import FrozenDict
35
- from diffusers.image_processor import VaeImageProcessor
36
- from typing import Any, Callable, Dict, List, Optional, Union
37
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
38
- from diffusers.schedulers import KarrasDiffusionSchedulers
39
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
40
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
41
- from diffusers import DDPMScheduler, EulerAncestralDiscreteScheduler, ImagePipelineOutput
42
- from diffusers.loaders import (
43
- FromSingleFileMixin,
44
- LoraLoaderMixin,
45
- TextualInversionLoaderMixin
46
- )
47
- from transformers import (
48
- CLIPImageProcessor,
49
- CLIPTextModel,
50
- CLIPTokenizer,
51
- CLIPVisionModelWithProjection
52
- )
53
- from diffusers.models.attention_processor import (
54
- Attention,
55
- AttnProcessor,
56
- XFormersAttnProcessor,
57
- AttnProcessor2_0
58
- )
59
-
60
- from .utils import to_rgb_image, white_out_background, recenter_img
61
-
62
-
63
- EXAMPLE_DOC_STRING = """
64
- Examples:
65
- ```py
66
- >>> import torch
67
- >>> from here import Hunyuan3d_MVD_Lite_Pipeline
68
-
69
- >>> pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained(
70
- ... "weights/mvd_lite", torch_dtype=torch.float16
71
- ... )
72
- >>> pipe.to("cuda")
73
-
74
- >>> img = Image.open("demo.png")
75
- >>> res_img = pipe(img).images[0]
76
- """
77
-
78
- def unscale_latents(latents): return latents / 0.75 + 0.22
79
- def unscale_image (image ): return image / 0.50 * 0.80
80
-
81
-
82
- def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
83
- std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
84
- std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
85
- noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
86
- noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
87
- return noise_cfg
88
-
89
-
90
-
91
- class ReferenceOnlyAttnProc(torch.nn.Module):
92
- # reference attention
93
- def __init__(self, chained_proc, enabled=False, name=None):
94
- super().__init__()
95
- self.enabled = enabled
96
- self.chained_proc = chained_proc
97
- self.name = name
98
-
99
- def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None):
100
- if encoder_hidden_states is None: encoder_hidden_states = hidden_states
101
- if self.enabled:
102
- if mode == 'w':
103
- ref_dict[self.name] = encoder_hidden_states
104
- elif mode == 'r':
105
- encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
106
- res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
107
- return res
108
-
109
-
110
- class RefOnlyNoisedUNet(torch.nn.Module):
111
- def __init__(self, unet, train_sched, val_sched):
112
- super().__init__()
113
- self.unet = unet
114
- self.train_sched = train_sched
115
- self.val_sched = val_sched
116
-
117
- unet_lora_attn_procs = dict()
118
- for name, _ in unet.attn_processors.items():
119
- unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(AttnProcessor2_0(),
120
- enabled=name.endswith("attn1.processor"),
121
- name=name)
122
- unet.set_attn_processor(unet_lora_attn_procs)
123
-
124
- def __getattr__(self, name: str):
125
- try:
126
- return super().__getattr__(name)
127
- except AttributeError:
128
- return getattr(self.unet, name)
129
-
130
- def forward(self, sample, timestep, encoder_hidden_states, *args, cross_attention_kwargs, **kwargs):
131
- cond_lat = cross_attention_kwargs['cond_lat']
132
- noise = torch.randn_like(cond_lat)
133
- if self.training:
134
- noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
135
- noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
136
- else:
137
- noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
138
- noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
139
-
140
- ref_dict = {}
141
- self.unet(noisy_cond_lat,
142
- timestep,
143
- encoder_hidden_states,
144
- *args,
145
- cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
146
- **kwargs)
147
- return self.unet(sample,
148
- timestep,
149
- encoder_hidden_states,
150
- *args,
151
- cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict),
152
- **kwargs)
153
-
154
-
155
- class Hunyuan3d_MVD_Lite_Pipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
156
- def __init__(
157
- self,
158
- vae: AutoencoderKL,
159
- text_encoder: CLIPTextModel,
160
- tokenizer: CLIPTokenizer,
161
- unet: UNet2DConditionModel,
162
- scheduler: KarrasDiffusionSchedulers,
163
- vision_encoder: CLIPVisionModelWithProjection,
164
- feature_extractor_clip: CLIPImageProcessor,
165
- feature_extractor_vae: CLIPImageProcessor,
166
- ramping_coefficients: Optional[list] = None,
167
- safety_checker=None,
168
- ):
169
- DiffusionPipeline.__init__(self)
170
- self.register_modules(
171
- vae=vae,
172
- unet=unet,
173
- tokenizer=tokenizer,
174
- scheduler=scheduler,
175
- text_encoder=text_encoder,
176
- vision_encoder=vision_encoder,
177
- feature_extractor_vae=feature_extractor_vae,
178
- feature_extractor_clip=feature_extractor_clip
179
- )
180
- # rewrite the stable diffusion pipeline
181
- # vae: vae
182
- # unet: unet
183
- # tokenizer: tokenizer
184
- # scheduler: scheduler
185
- # text_encoder: text_encoder
186
- # vision_encoder: vision_encoder
187
- # feature_extractor_vae: feature_extractor_vae
188
- # feature_extractor_clip: feature_extractor_clip
189
- self.register_to_config(ramping_coefficients=ramping_coefficients)
190
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
191
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
192
-
193
- def prepare_extra_step_kwargs(self, generator, eta):
194
- extra_step_kwargs = {}
195
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
196
- if accepts_eta: extra_step_kwargs["eta"] = eta
197
-
198
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
199
- if accepts_generator: extra_step_kwargs["generator"] = generator
200
- return extra_step_kwargs
201
-
202
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
203
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
204
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
205
- latents = latents * self.scheduler.init_noise_sigma
206
- return latents
207
-
208
- @torch.no_grad()
209
- def _encode_prompt(
210
- self,
211
- prompt,
212
- device,
213
- num_images_per_prompt,
214
- do_classifier_free_guidance,
215
- negative_prompt=None,
216
- prompt_embeds: Optional[torch.FloatTensor] = None,
217
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
218
- lora_scale: Optional[float] = None,
219
- ):
220
- if lora_scale is not None and isinstance(self, LoraLoaderMixin):
221
- self._lora_scale = lora_scale
222
-
223
- if prompt is not None and isinstance(prompt, str):
224
- batch_size = 1
225
- elif prompt is not None and isinstance(prompt, list):
226
- batch_size = len(prompt)
227
- else:
228
- batch_size = prompt_embeds.shape[0]
229
-
230
- if prompt_embeds is None:
231
- if isinstance(self, TextualInversionLoaderMixin):
232
- prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
233
-
234
- text_inputs = self.tokenizer(
235
- prompt,
236
- padding="max_length",
237
- max_length=self.tokenizer.model_max_length,
238
- truncation=True,
239
- return_tensors="pt",
240
- )
241
- text_input_ids = text_inputs.input_ids
242
-
243
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
244
- attention_mask = text_inputs.attention_mask.to(device)
245
- else:
246
- attention_mask = None
247
-
248
- prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)[0]
249
-
250
- if self.text_encoder is not None:
251
- prompt_embeds_dtype = self.text_encoder.dtype
252
- elif self.unet is not None:
253
- prompt_embeds_dtype = self.unet.dtype
254
- else:
255
- prompt_embeds_dtype = prompt_embeds.dtype
256
-
257
- prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
258
- bs_embed, seq_len, _ = prompt_embeds.shape
259
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
260
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
261
-
262
- if do_classifier_free_guidance and negative_prompt_embeds is None:
263
- uncond_tokens: List[str]
264
- if negative_prompt is None: uncond_tokens = [""] * batch_size
265
- elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError()
266
- elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt]
267
- elif batch_size != len(negative_prompt): raise ValueError()
268
- else: uncond_tokens = negative_prompt
269
- if isinstance(self, TextualInversionLoaderMixin):
270
- uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
271
-
272
- max_length = prompt_embeds.shape[1]
273
- uncond_input = self.tokenizer(uncond_tokens,
274
- padding="max_length",
275
- max_length=max_length,
276
- truncation=True,
277
- return_tensors="pt")
278
-
279
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
280
- attention_mask = uncond_input.attention_mask.to(device)
281
- else:
282
- attention_mask = None
283
-
284
- negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device), attention_mask=attention_mask)
285
- negative_prompt_embeds = negative_prompt_embeds[0]
286
-
287
- if do_classifier_free_guidance:
288
- seq_len = negative_prompt_embeds.shape[1]
289
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
290
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
291
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
292
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
293
-
294
- return prompt_embeds
295
-
296
- @torch.no_grad()
297
- def encode_condition_image(self, image: torch.Tensor): return self.vae.encode(image).latent_dist.sample()
298
-
299
- @torch.no_grad()
300
- def __call__(self, image=None,
301
- width=640,
302
- height=960,
303
- num_inference_steps=75,
304
- return_dict=True,
305
- generator=None,
306
- **kwargs):
307
- batch_size = 1
308
- num_images_per_prompt = 1
309
- output_type = 'pil'
310
- do_classifier_free_guidance = True
311
- guidance_rescale = 0.
312
- if isinstance(self.unet, UNet2DConditionModel):
313
- self.unet = RefOnlyNoisedUNet(self.unet, None, self.scheduler).eval()
314
-
315
- cond_image = recenter_img(image)
316
- cond_image = to_rgb_image(image)
317
- image = cond_image
318
- image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
319
- image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
320
- image_1 = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
321
- image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
322
-
323
- cond_lat = self.encode_condition_image(image_1)
324
- negative_lat = self.encode_condition_image(torch.zeros_like(image_1))
325
- cond_lat = torch.cat([negative_lat, cond_lat])
326
- cross_attention_kwargs = dict(cond_lat=cond_lat)
327
-
328
- global_embeds = self.vision_encoder(image_2, output_hidden_states=False).image_embeds.unsqueeze(-2)
329
- encoder_hidden_states = self._encode_prompt('', self.device, num_images_per_prompt, False)
330
- ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
331
- prompt_embeds = torch.cat([encoder_hidden_states, encoder_hidden_states + global_embeds * ramp])
332
-
333
- device = self._execution_device
334
- self.scheduler.set_timesteps(num_inference_steps, device=device)
335
- timesteps = self.scheduler.timesteps
336
- num_channels_latents = self.unet.config.in_channels
337
- latents = self.prepare_latents(batch_size * num_images_per_prompt,
338
- num_channels_latents,
339
- height,
340
- width,
341
- prompt_embeds.dtype,
342
- device,
343
- generator,
344
- None)
345
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)
346
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
347
-
348
- # set adaptive cfg
349
- # the image order is:
350
- # [0, 60,
351
- # 120, 180,
352
- # 240, 300]
353
- # the cfg is set as 3, 2.5, 2, 1.5
354
-
355
- tmp_guidance_scale = torch.ones_like(latents)
356
- tmp_guidance_scale[:, :, :40, :40] = 3
357
- tmp_guidance_scale[:, :, :40, 40:] = 2.5
358
- tmp_guidance_scale[:, :, 40:80, :40] = 2
359
- tmp_guidance_scale[:, :, 40:80, 40:] = 1.5
360
- tmp_guidance_scale[:, :, 80:120, :40] = 2
361
- tmp_guidance_scale[:, :, 80:120, 40:] = 2.5
362
-
363
- with self.progress_bar(total=num_inference_steps) as progress_bar:
364
- for i, t in enumerate(timesteps):
365
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
366
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
367
-
368
- noise_pred = self.unet(latent_model_input, t,
369
- encoder_hidden_states=prompt_embeds,
370
- cross_attention_kwargs=cross_attention_kwargs,
371
- return_dict=False)[0]
372
-
373
- adaptive_guidance_scale = (2 + 16 * (t / 1000) ** 5) / 3
374
- if do_classifier_free_guidance:
375
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
376
- noise_pred = noise_pred_uncond + \
377
- tmp_guidance_scale * adaptive_guidance_scale * \
378
- (noise_pred_text - noise_pred_uncond)
379
-
380
- if do_classifier_free_guidance and guidance_rescale > 0.0:
381
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
382
-
383
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
384
- if i==len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order==0):
385
- progress_bar.update()
386
-
387
- latents = unscale_latents(latents)
388
- image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
389
- image = self.image_processor.postprocess(image, output_type='pil')[0]
390
- image = [image, cond_image]
391
- return ImagePipelineOutput(images=image) if return_dict else (image,)
392
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mvd/.ipynb_checkpoints/hunyuan3d_mvd_std_pipeline-checkpoint.py DELETED
@@ -1,473 +0,0 @@
1
- # Open Source Model Licensed under the Apache License Version 2.0
2
- # and Other Licenses of the Third-Party Components therein:
3
- # The below Model in this distribution may have been modified by THL A29 Limited
4
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
-
6
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
- # The below software and/or models in this distribution may have been
8
- # modified by THL A29 Limited ("Tencent Modifications").
9
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
-
11
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
- # except for the third-party components listed below.
13
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
- # in the repsective licenses of these third-party components.
15
- # Users must comply with all terms and conditions of original licenses of these third-party
16
- # components and must ensure that the usage of the third party components adheres to
17
- # all relevant laws and regulations.
18
-
19
- # For avoidance of doubts, Hunyuan 3D means the large language models and
20
- # their software and algorithms, including trained model weights, parameters (including
21
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
- # fine-tuning enabling code and other elements of the foregoing made publicly available
23
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
-
25
- import inspect
26
- from typing import Any, Dict, Optional
27
- from typing import Any, Dict, List, Optional, Tuple, Union
28
-
29
- import os
30
- import torch
31
- import numpy as np
32
- from PIL import Image
33
-
34
- import diffusers
35
- from diffusers.image_processor import VaeImageProcessor
36
- from diffusers.utils.import_utils import is_xformers_available
37
- from diffusers.schedulers import KarrasDiffusionSchedulers
38
- from diffusers.utils.torch_utils import randn_tensor
39
- from diffusers.utils.import_utils import is_xformers_available
40
- from diffusers.models.attention_processor import (
41
- Attention,
42
- AttnProcessor,
43
- XFormersAttnProcessor,
44
- AttnProcessor2_0
45
- )
46
- from diffusers import (
47
- AutoencoderKL,
48
- DDPMScheduler,
49
- DiffusionPipeline,
50
- EulerAncestralDiscreteScheduler,
51
- UNet2DConditionModel,
52
- ImagePipelineOutput
53
- )
54
- import transformers
55
- from transformers import (
56
- CLIPImageProcessor,
57
- CLIPTextModel,
58
- CLIPTokenizer,
59
- CLIPVisionModelWithProjection,
60
- CLIPTextModelWithProjection
61
- )
62
-
63
- from .utils import to_rgb_image, white_out_background, recenter_img
64
-
65
- EXAMPLE_DOC_STRING = """
66
- Examples:
67
- ```py
68
- >>> import torch
69
- >>> from diffusers import Hunyuan3d_MVD_XL_Pipeline
70
-
71
- >>> pipe = Hunyuan3d_MVD_XL_Pipeline.from_pretrained(
72
- ... "Tencent-Hunyuan-3D/MVD-XL", torch_dtype=torch.float16
73
- ... )
74
- >>> pipe.to("cuda")
75
-
76
- >>> img = Image.open("demo.png")
77
- >>> res_img = pipe(img).images[0]
78
- ```
79
- """
80
-
81
-
82
-
83
- def scale_latents(latents): return (latents - 0.22) * 0.75
84
- def unscale_latents(latents): return (latents / 0.75) + 0.22
85
- def scale_image(image): return (image - 0.5) / 0.5
86
- def scale_image_2(image): return (image * 0.5) / 0.8
87
- def unscale_image(image): return (image * 0.5) + 0.5
88
- def unscale_image_2(image): return (image * 0.8) / 0.5
89
-
90
-
91
-
92
-
93
- class ReferenceOnlyAttnProc(torch.nn.Module):
94
- def __init__(self, chained_proc, enabled=False, name=None):
95
- super().__init__()
96
- self.enabled = enabled
97
- self.chained_proc = chained_proc
98
- self.name = name
99
-
100
- def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None):
101
- encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
102
- if self.enabled:
103
- if mode == 'w': ref_dict[self.name] = encoder_hidden_states
104
- elif mode == 'r': encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
105
- else: raise Exception(f"mode should not be {mode}")
106
- return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
107
-
108
-
109
- class RefOnlyNoisedUNet(torch.nn.Module):
110
- def __init__(self, unet, scheduler) -> None:
111
- super().__init__()
112
- self.unet = unet
113
- self.scheduler = scheduler
114
-
115
- unet_attn_procs = dict()
116
- for name, _ in unet.attn_processors.items():
117
- if torch.__version__ >= '2.0': default_attn_proc = AttnProcessor2_0()
118
- elif is_xformers_available(): default_attn_proc = XFormersAttnProcessor()
119
- else: default_attn_proc = AttnProcessor()
120
- unet_attn_procs[name] = ReferenceOnlyAttnProc(
121
- default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
122
- )
123
- unet.set_attn_processor(unet_attn_procs)
124
-
125
- def __getattr__(self, name: str):
126
- try:
127
- return super().__getattr__(name)
128
- except AttributeError:
129
- return getattr(self.unet, name)
130
-
131
- def forward(
132
- self,
133
- sample: torch.FloatTensor,
134
- timestep: Union[torch.Tensor, float, int],
135
- encoder_hidden_states: torch.Tensor,
136
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
137
- class_labels: Optional[torch.Tensor] = None,
138
- down_block_res_samples: Optional[Tuple[torch.Tensor]] = None,
139
- mid_block_res_sample: Optional[Tuple[torch.Tensor]] = None,
140
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
141
- return_dict: bool = True,
142
- **kwargs
143
- ):
144
-
145
- dtype = self.unet.dtype
146
-
147
- # cond_lat add same level noise
148
- cond_lat = cross_attention_kwargs['cond_lat']
149
- noise = torch.randn_like(cond_lat)
150
-
151
- noisy_cond_lat = self.scheduler.add_noise(cond_lat, noise, timestep.reshape(-1))
152
- noisy_cond_lat = self.scheduler.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
153
-
154
- ref_dict = {}
155
-
156
- _ = self.unet(
157
- noisy_cond_lat,
158
- timestep,
159
- encoder_hidden_states = encoder_hidden_states,
160
- class_labels = class_labels,
161
- cross_attention_kwargs = dict(mode="w", ref_dict=ref_dict),
162
- added_cond_kwargs = added_cond_kwargs,
163
- return_dict = return_dict,
164
- **kwargs
165
- )
166
-
167
- res = self.unet(
168
- sample,
169
- timestep,
170
- encoder_hidden_states,
171
- class_labels=class_labels,
172
- cross_attention_kwargs = dict(mode="r", ref_dict=ref_dict),
173
- down_block_additional_residuals = [
174
- sample.to(dtype=dtype) for sample in down_block_res_samples
175
- ] if down_block_res_samples is not None else None,
176
- mid_block_additional_residual = (
177
- mid_block_res_sample.to(dtype=dtype)
178
- if mid_block_res_sample is not None else None),
179
- added_cond_kwargs = added_cond_kwargs,
180
- return_dict = return_dict,
181
- **kwargs
182
- )
183
- return res
184
-
185
-
186
-
187
- class HunYuan3D_MVD_Std_Pipeline(diffusers.DiffusionPipeline):
188
- def __init__(
189
- self,
190
- vae: AutoencoderKL,
191
- unet: UNet2DConditionModel,
192
- scheduler: KarrasDiffusionSchedulers,
193
- feature_extractor_vae: CLIPImageProcessor,
194
- vision_processor: CLIPImageProcessor,
195
- vision_encoder: CLIPVisionModelWithProjection,
196
- vision_encoder_2: CLIPVisionModelWithProjection,
197
- ramping_coefficients: Optional[list] = None,
198
- add_watermarker: Optional[bool] = None,
199
- safety_checker = None,
200
- ):
201
- DiffusionPipeline.__init__(self)
202
-
203
- self.register_modules(
204
- vae=vae, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor_vae=feature_extractor_vae,
205
- vision_processor=vision_processor, vision_encoder=vision_encoder, vision_encoder_2=vision_encoder_2,
206
- )
207
- self.register_to_config( ramping_coefficients = ramping_coefficients)
208
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
209
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
210
- self.default_sample_size = self.unet.config.sample_size
211
- self.watermark = None
212
- self.prepare_init = False
213
-
214
- def prepare(self):
215
- assert isinstance(self.unet, UNet2DConditionModel), "unet should be UNet2DConditionModel"
216
- self.unet = RefOnlyNoisedUNet(self.unet, self.scheduler).eval()
217
- self.prepare_init = True
218
-
219
- def encode_image(self, image: torch.Tensor, scale_factor: bool = False):
220
- latent = self.vae.encode(image).latent_dist.sample()
221
- return (latent * self.vae.config.scaling_factor) if scale_factor else latent
222
-
223
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
224
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
225
- shape = (
226
- batch_size,
227
- num_channels_latents,
228
- int(height) // self.vae_scale_factor,
229
- int(width) // self.vae_scale_factor,
230
- )
231
- if isinstance(generator, list) and len(generator) != batch_size:
232
- raise ValueError(
233
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
234
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
235
- )
236
-
237
- if latents is None:
238
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
239
- else:
240
- latents = latents.to(device)
241
-
242
- # scale the initial noise by the standard deviation required by the scheduler
243
- latents = latents * self.scheduler.init_noise_sigma
244
- return latents
245
-
246
- def _get_add_time_ids(
247
- self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
248
- ):
249
- add_time_ids = list(original_size + crops_coords_top_left + target_size)
250
-
251
- passed_add_embed_dim = (
252
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
253
- )
254
- expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
255
-
256
- if expected_add_embed_dim != passed_add_embed_dim:
257
- raise ValueError(
258
- f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " \
259
- f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config." \
260
- f" Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
261
- )
262
-
263
- add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
264
- return add_time_ids
265
-
266
- def prepare_extra_step_kwargs(self, generator, eta):
267
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
268
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
269
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
270
- # and should be between [0, 1]
271
-
272
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
273
- extra_step_kwargs = {}
274
- if accepts_eta: extra_step_kwargs["eta"] = eta
275
-
276
- # check if the scheduler accepts generator
277
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
278
- if accepts_generator: extra_step_kwargs["generator"] = generator
279
- return extra_step_kwargs
280
-
281
- @property
282
- def guidance_scale(self):
283
- return self._guidance_scale
284
-
285
- @property
286
- def interrupt(self):
287
- return self._interrupt
288
-
289
- @property
290
- def do_classifier_free_guidance(self):
291
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
292
-
293
- @torch.no_grad()
294
- def __call__(
295
- self,
296
- image: Image.Image = None,
297
- guidance_scale = 2.0,
298
- output_type: Optional[str] = "pil",
299
- num_inference_steps: int = 50,
300
- return_dict: bool = True,
301
- eta: float = 0.0,
302
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
303
- crops_coords_top_left: Tuple[int, int] = (0, 0),
304
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
305
- latent: torch.Tensor = None,
306
- guidance_curve = None,
307
- **kwargs
308
- ):
309
- if not self.prepare_init:
310
- self.prepare()
311
-
312
- here = dict(device=self.vae.device, dtype=self.vae.dtype)
313
-
314
- batch_size = 1
315
- num_images_per_prompt = 1
316
- width, height = 512 * 2, 512 * 3
317
- target_size = original_size = (height, width)
318
-
319
- self._guidance_scale = guidance_scale
320
- self._cross_attention_kwargs = cross_attention_kwargs
321
- self._interrupt = False
322
-
323
- device = self._execution_device
324
-
325
- # Prepare timesteps
326
- self.scheduler.set_timesteps(num_inference_steps, device=device)
327
- timesteps = self.scheduler.timesteps
328
-
329
- # Prepare latent variables
330
- num_channels_latents = self.unet.config.in_channels
331
- latents = self.prepare_latents(
332
- batch_size * num_images_per_prompt,
333
- num_channels_latents,
334
- height,
335
- width,
336
- self.vae.dtype,
337
- device,
338
- generator,
339
- latents=latent,
340
- )
341
-
342
- # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
343
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
344
-
345
-
346
- # Prepare added time ids & embeddings
347
- text_encoder_projection_dim = 1280
348
- add_time_ids = self._get_add_time_ids(
349
- original_size,
350
- crops_coords_top_left,
351
- target_size,
352
- dtype=self.vae.dtype,
353
- text_encoder_projection_dim=text_encoder_projection_dim,
354
- )
355
- negative_add_time_ids = add_time_ids
356
-
357
- # hw: preprocess
358
- cond_image = recenter_img(image)
359
- cond_image = to_rgb_image(image)
360
- image_vae = self.feature_extractor_vae(images=cond_image, return_tensors="pt").pixel_values.to(**here)
361
- image_clip = self.vision_processor(images=cond_image, return_tensors="pt").pixel_values.to(**here)
362
-
363
- # hw: get cond_lat from cond_img using vae
364
- cond_lat = self.encode_image(image_vae, scale_factor=False)
365
- negative_lat = self.encode_image(torch.zeros_like(image_vae), scale_factor=False)
366
- cond_lat = torch.cat([negative_lat, cond_lat])
367
-
368
- # hw: get visual global embedding using clip
369
- global_embeds_1 = self.vision_encoder(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2)
370
- global_embeds_2 = self.vision_encoder_2(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2)
371
- global_embeds = torch.concat([global_embeds_1, global_embeds_2], dim=-1)
372
-
373
- ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
374
- prompt_embeds = self.uc_text_emb.to(**here)
375
- pooled_prompt_embeds = self.uc_text_emb_2.to(**here)
376
-
377
- prompt_embeds = prompt_embeds + global_embeds * ramp
378
- add_text_embeds = pooled_prompt_embeds
379
-
380
- if self.do_classifier_free_guidance:
381
- negative_prompt_embeds = torch.zeros_like(prompt_embeds)
382
- negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
383
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
384
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
385
- add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
386
-
387
- prompt_embeds = prompt_embeds.to(device)
388
- add_text_embeds = add_text_embeds.to(device)
389
- add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
390
-
391
- # Denoising loop
392
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
393
- timestep_cond = None
394
- self._num_timesteps = len(timesteps)
395
-
396
- if guidance_curve is None:
397
- guidance_curve = lambda t: guidance_scale
398
-
399
- with self.progress_bar(total=num_inference_steps) as progress_bar:
400
- for i, t in enumerate(timesteps):
401
- if self.interrupt:
402
- continue
403
-
404
- # expand the latents if we are doing classifier free guidance
405
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
406
-
407
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
408
-
409
- # predict the noise residual
410
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
411
-
412
- noise_pred = self.unet(
413
- latent_model_input,
414
- t,
415
- encoder_hidden_states=prompt_embeds,
416
- timestep_cond=timestep_cond,
417
- cross_attention_kwargs=dict(cond_lat=cond_lat),
418
- added_cond_kwargs=added_cond_kwargs,
419
- return_dict=False,
420
- )[0]
421
-
422
- # perform guidance
423
-
424
- # cur_guidance_scale = self.guidance_scale
425
- cur_guidance_scale = guidance_curve(t) # 1.5 + 2.5 * ((t/1000)**2)
426
-
427
- if self.do_classifier_free_guidance:
428
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
429
- noise_pred = noise_pred_uncond + cur_guidance_scale * (noise_pred_text - noise_pred_uncond)
430
-
431
- # cur_guidance_scale_topleft = (cur_guidance_scale - 1.0) * 4 + 1.0
432
- # noise_pred_top_left = noise_pred_uncond +
433
- # cur_guidance_scale_topleft * (noise_pred_text - noise_pred_uncond)
434
- # _, _, h, w = noise_pred.shape
435
- # noise_pred[:, :, :h//3, :w//2] = noise_pred_top_left[:, :, :h//3, :w//2]
436
-
437
- # compute the previous noisy sample x_t -> x_t-1
438
- latents_dtype = latents.dtype
439
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
440
-
441
- # call the callback, if provided
442
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
443
- progress_bar.update()
444
-
445
- latents = unscale_latents(latents)
446
-
447
- if output_type=="latent":
448
- image = latents
449
- else:
450
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
451
- image = unscale_image(unscale_image_2(image)).clamp(0, 1)
452
- image = [
453
- Image.fromarray((image[0]*255+0.5).clamp_(0, 255).permute(1, 2, 0).cpu().numpy().astype("uint8")),
454
- # self.image_processor.postprocess(image, output_type=output_type)[0],
455
- cond_image.resize((512, 512))
456
- ]
457
-
458
- if not return_dict: return (image,)
459
- return ImagePipelineOutput(images=image)
460
-
461
- def save_pretrained(self, save_directory):
462
- # uc_text_emb.pt and uc_text_emb_2.pt are inferenced and saved in advance
463
- super().save_pretrained(save_directory)
464
- torch.save(self.uc_text_emb, os.path.join(save_directory, "uc_text_emb.pt"))
465
- torch.save(self.uc_text_emb_2, os.path.join(save_directory, "uc_text_emb_2.pt"))
466
-
467
- @classmethod
468
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
469
- # uc_text_emb.pt and uc_text_emb_2.pt are inferenced and saved in advance
470
- pipeline = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
471
- pipeline.uc_text_emb = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb.pt"))
472
- pipeline.uc_text_emb_2 = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb_2.pt"))
473
- return pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mvd/.ipynb_checkpoints/utils-checkpoint.py DELETED
@@ -1,87 +0,0 @@
1
- # Open Source Model Licensed under the Apache License Version 2.0
2
- # and Other Licenses of the Third-Party Components therein:
3
- # The below Model in this distribution may have been modified by THL A29 Limited
4
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
-
6
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
- # The below software and/or models in this distribution may have been
8
- # modified by THL A29 Limited ("Tencent Modifications").
9
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
-
11
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
- # except for the third-party components listed below.
13
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
- # in the repsective licenses of these third-party components.
15
- # Users must comply with all terms and conditions of original licenses of these third-party
16
- # components and must ensure that the usage of the third party components adheres to
17
- # all relevant laws and regulations.
18
-
19
- # For avoidance of doubts, Hunyuan 3D means the large language models and
20
- # their software and algorithms, including trained model weights, parameters (including
21
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
- # fine-tuning enabling code and other elements of the foregoing made publicly available
23
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
-
25
- import numpy as np
26
- from PIL import Image
27
-
28
- def to_rgb_image(maybe_rgba: Image.Image):
29
- '''
30
- convert a PIL.Image to rgb mode with white background
31
- maybe_rgba: PIL.Image
32
- return: PIL.Image
33
- '''
34
- if maybe_rgba.mode == 'RGB':
35
- return maybe_rgba
36
- elif maybe_rgba.mode == 'RGBA':
37
- rgba = maybe_rgba
38
- img = np.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=np.uint8)
39
- img = Image.fromarray(img, 'RGB')
40
- img.paste(rgba, mask=rgba.getchannel('A'))
41
- return img
42
- else:
43
- raise ValueError("Unsupported image type.", maybe_rgba.mode)
44
-
45
- def white_out_background(pil_img, is_gray_fg=True):
46
- data = pil_img.getdata()
47
- new_data = []
48
- # convert fore-ground white to gray
49
- for r, g, b, a in data:
50
- if a < 16:
51
- new_data.append((255, 255, 255, 0)) # back-ground to be black
52
- else:
53
- is_white = is_gray_fg and (r>235) and (g>235) and (b>235)
54
- new_r = 235 if is_white else r
55
- new_g = 235 if is_white else g
56
- new_b = 235 if is_white else b
57
- new_data.append((new_r, new_g, new_b, a))
58
- pil_img.putdata(new_data)
59
- return pil_img
60
-
61
- def recenter_img(img, size=512, color=(255,255,255)):
62
- img = white_out_background(img)
63
- mask = np.array(img)[..., 3]
64
- image = np.array(img)[..., :3]
65
-
66
- H, W, C = image.shape
67
- coords = np.nonzero(mask)
68
- x_min, x_max = coords[0].min(), coords[0].max()
69
- y_min, y_max = coords[1].min(), coords[1].max()
70
- h = x_max - x_min
71
- w = y_max - y_min
72
- if h == 0 or w == 0: raise ValueError
73
- roi = image[x_min:x_max, y_min:y_max]
74
-
75
- border_ratio = 0.15 # 0.2
76
- pad_h = int(h * border_ratio)
77
- pad_w = int(w * border_ratio)
78
-
79
- result_tmp = np.full((h + pad_h, w + pad_w, C), color, dtype=np.uint8)
80
- result_tmp[pad_h // 2: pad_h // 2 + h, pad_w // 2: pad_w // 2 + w] = roi
81
-
82
- cur_h, cur_w = result_tmp.shape[:2]
83
- side = max(cur_h, cur_w)
84
- result = np.full((side, side, C), color, dtype=np.uint8)
85
- result[(side-cur_h)//2:(side-cur_h)//2+cur_h, (side-cur_w)//2:(side - cur_w)//2+cur_w,:] = result_tmp
86
- result = Image.fromarray(result)
87
- return result.resize((size, size), Image.LANCZOS) if size else result