import os
from typing import List

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from PIL import Image
from torchvision import transforms

from lbm.models.embedders import (
    ConditionerWrapper,
    LatentsConcatEmbedder,
    LatentsConcatEmbedderConfig,
)
from lbm.models.lbm import LBMConfig, LBMModel
from lbm.models.unets import DiffusersUNet2DCondWrapper
from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig


def get_model_from_config(
    backbone_signature: str = "stabilityai/stable-diffusion-xl-base-1.0",
    vae_num_channels: int = 4,
    unet_input_channels: int = 4,
    timestep_sampling: str = "log_normal",
    selected_timesteps: List[float] = None,
    prob: List[float] = None,
    conditioning_images_keys: List[str] = [],
    conditioning_masks_keys: List[str] = ["mask"],
    source_key: str = "source_image",
    target_key: str = "source_image_paste",
    bridge_noise_sigma: float = 0.0,
):

    conditioners = []

    denoiser = DiffusersUNet2DCondWrapper(
        in_channels=unet_input_channels,  # Add downsampled_image
        out_channels=vae_num_channels,
        center_input_sample=False,
        flip_sin_to_cos=True,
        freq_shift=0,
        down_block_types=[
            "DownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
        ],
        mid_block_type="UNetMidBlock2DCrossAttn",
        up_block_types=["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
        only_cross_attention=False,
        block_out_channels=[320, 640, 1280],
        layers_per_block=2,
        downsample_padding=1,
        mid_block_scale_factor=1,
        dropout=0.0,
        act_fn="silu",
        norm_num_groups=32,
        norm_eps=1e-05,
        cross_attention_dim=[320, 640, 1280],
        transformer_layers_per_block=[1, 2, 10],
        reverse_transformer_layers_per_block=None,
        encoder_hid_dim=None,
        encoder_hid_dim_type=None,
        attention_head_dim=[5, 10, 20],
        num_attention_heads=None,
        dual_cross_attention=False,
        use_linear_projection=True,
        class_embed_type=None,
        addition_embed_type=None,
        addition_time_embed_dim=None,
        num_class_embeds=None,
        upcast_attention=None,
        resnet_time_scale_shift="default",
        resnet_skip_time_act=False,
        resnet_out_scale_factor=1.0,
        time_embedding_type="positional",
        time_embedding_dim=None,
        time_embedding_act_fn=None,
        timestep_post_act=None,
        time_cond_proj_dim=None,
        conv_in_kernel=3,
        conv_out_kernel=3,
        projection_class_embeddings_input_dim=None,
        attention_type="default",
        class_embeddings_concat=False,
        mid_block_only_cross_attention=None,
        cross_attention_norm=None,
        addition_embed_type_num_heads=64,
    ).to(torch.bfloat16)

    if conditioning_images_keys != [] or conditioning_masks_keys != []:

        latents_concat_embedder_config = LatentsConcatEmbedderConfig(
            image_keys=conditioning_images_keys,
            mask_keys=conditioning_masks_keys,
        )
        latent_concat_embedder = LatentsConcatEmbedder(latents_concat_embedder_config)
        latent_concat_embedder.freeze()
        conditioners.append(latent_concat_embedder)

    # Wrap conditioners and set to device
    conditioner = ConditionerWrapper(
        conditioners=conditioners,
    )

    ## VAE ##
    # Get VAE model
    vae_config = AutoencoderKLDiffusersConfig(
        version=backbone_signature,
        subfolder="vae",
        tiling_size=(128, 128),
    )
    vae = AutoencoderKLDiffusers(vae_config).to(torch.bfloat16)
    vae.freeze()
    vae.to(torch.bfloat16)

    ## Diffusion Model ##
    # Get diffusion model
    config = LBMConfig(
        source_key=source_key,
        target_key=target_key,
        timestep_sampling=timestep_sampling,
        selected_timesteps=selected_timesteps,
        prob=prob,
        bridge_noise_sigma=bridge_noise_sigma,
    )

    sampling_noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        backbone_signature,
        subfolder="scheduler",
    )

    model = LBMModel(
        config,
        denoiser=denoiser,
        sampling_noise_scheduler=sampling_noise_scheduler,
        vae=vae,
        conditioner=conditioner,
    ).to(torch.bfloat16)

    return model
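
# Hedged usage sketch (not part of the original file): building the model with its
# default arguments. Pretrained weights would typically be loaded afterwards via the
# standard torch.nn.Module API; the checkpoint path below is hypothetical and the
# checkpoint format is an assumption, not something defined in this module.
#
#   model = get_model_from_config()
#   state_dict = torch.load("relight_lbm.safetensors.pt", map_location="cpu")  # hypothetical path
#   model.load_state_dict(state_dict, strict=True)
#   model = model.to("cuda")
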
def extract_object(birefnet, img):
    # Data settings
    image_size = (1024, 1024)
    transform_image = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    image = img
    input_images = transform_image(image).unsqueeze(0).cuda()

    # Prediction
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image.size)
    image = Image.composite(image, Image.new("RGB", image.size, (127, 127, 127)), mask)
    return image, mask


def resize_and_center_crop(image, target_width, target_height):
    original_width, original_height = image.size
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = image.resize((resized_width, resized_height), Image.LANCZOS)
    left = (resized_width - target_width) / 2
    top = (resized_height - target_height) / 2
    right = (resized_width + target_width) / 2
    bottom = (resized_height + target_height) / 2
    cropped_image = resized_image.crop((left, top, right, bottom))
    return cropped_image
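
# Hedged end-to-end sketch (not part of the original file): preprocess an input photo
# and cut the foreground object out over a neutral grey background. The image path is
# hypothetical, and the segmentation checkpoint assumes the publicly released BiRefNet
# model on the Hugging Face Hub; a GPU is required because extract_object calls .cuda().
if __name__ == "__main__":
    from transformers import AutoModelForImageSegmentation

    # Assumed checkpoint name; swap in whichever BiRefNet weights are used in practice.
    birefnet = (
        AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        .cuda()
        .eval()
    )

    # Hypothetical input; crop/resize to the working resolution before segmentation.
    source = Image.open("input.jpg").convert("RGB")
    source = resize_and_center_crop(source, 1024, 1024)

    # Returns the object composited over grey plus its soft mask.
    object_image, object_mask = extract_object(birefnet, source)
    object_image.save("object.png")
    object_mask.save("mask.png")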