# TryOffAnyone / src/pipeline.py
# Last commit: "init" (74a242e, unverified) by 1aurent
# type: ignore
# Inspired by https://github.com/ixarchakos/try-off-anyone/blob/aa3045453013065573a647e4536922bac696b968/src/model/pipeline.py
# Inspired by https://github.com/ixarchakos/try-off-anyone/blob/aa3045453013065573a647e4536922bac696b968/src/model/attention.py
import torch
from accelerate import load_checkpoint_in_model
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils.torch_utils import randn_tensor
from huggingface_hub import hf_hub_download
from PIL import Image
class Skip(torch.nn.Module):
    """Pass-through attention processor.

    Calling an instance hands the incoming ``hidden_states`` straight back
    and ignores every other argument, so no attention is computed.
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(
        self,
        attn: torch.Tensor,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        temb: torch.Tensor = None,
    ) -> torch.Tensor:
        """Return ``hidden_states`` untouched.

        Args:
            attn: Unused. NOTE(review): annotated as a tensor, but presumably
                receives the calling attention layer — confirm against the caller.
            hidden_states: Value returned as-is (same object, no copy).
            encoder_hidden_states: Unused; accepted for signature compatibility.
            attention_mask: Unused; accepted for signature compatibility.
            temb: Unused; accepted for signature compatibility.

        Returns:
            The very object passed in as ``hidden_states``.
        """
        # ``__call__`` itself is overridden, so invoking the instance runs this
        # method directly rather than torch.nn.Module's usual path into ``forward``.
        return hidden_states
def fine_tuned_modules(unet: UNet2DConditionModel) -> torch.nn.ModuleList:
    """Gather the ``attentions`` sub-modules of a UNet into one ``ModuleList``.

    The down blocks, the mid block and the up blocks are visited in that
    order. A stage that exposes ``attentions`` itself contributes it directly;
    any other stage is iterated and each member that has ``attentions``
    contributes its own.

    Args:
        unet: Model providing ``down_blocks``, ``mid_block`` and ``up_blocks``.

    Returns:
        A ``torch.nn.ModuleList`` holding the collected ``attentions`` objects,
        in visiting order. NOTE(review): the position of each entry presumably
        matters to whoever loads weights into this list — confirm before
        reordering.
    """
    collected = torch.nn.ModuleList()
    stages = (unet.down_blocks, unet.mid_block, unet.up_blocks)
    for stage in stages:
        # Treat a stage that owns ``attentions`` as a one-element collection so
        # both shapes share the same inner loop.
        candidates = [stage] if hasattr(stage, "attentions") else stage
        for candidate in candidates:
            if hasattr(candidate, "attentions"):
                collected.append(candidate.attentions)
    return collected
def skip_cross_attentions(unet: UNet2DConditionModel) -> dict[str, AttnProcessor | Skip]:
attn_processors = {
name: unet.attn_processors[name] if name.endswith("attn1.processor") else Skip()
for name in unet.attn_processors.keys()
}
return attn_processors
def encode(image: torch.Tensor, vae: AutoencoderKL) -> torch.Tensor:
    """Encode ``image`` with ``vae`` and return a scaled latent sample.

    Args:
        image: Tensor to encode. It is made contiguous, cast to float32, then
            moved to the VAE's device and dtype before encoding.
        vae: Autoencoder providing ``encode``, ``device``, ``dtype`` and
            ``config.scaling_factor``.

    Returns:
        A sample drawn from the encoder's latent distribution, multiplied by
        ``vae.config.scaling_factor``. Computed with gradients disabled.
    """
    pixels = image.to(memory_format=torch.contiguous_format).float()
    pixels = pixels.to(vae.device, dtype=vae.dtype)
    with torch.no_grad():
        posterior = vae.encode(pixels).latent_dist
        # No generator is passed to ``sample``, so the draw uses the default RNG.
        return posterior.sample() * vae.config.scaling_factor
class TryOffAnyone:
    """Inference pipeline built from a DDIM scheduler, a VAE and an inpainting UNet.

    Construction loads the pretrained components, installs the processor
    mapping returned by ``skip_cross_attentions`` on the UNet, and loads the
    ``ixarchakos/tryOffAnyone`` checkpoint into the modules returned by
    ``fine_tuned_modules``. Calling the instance runs the denoising loop and
    returns PIL images.
    """

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        concat_dim: int = -2,
    ) -> None:
        """Load every model component and move it to ``device`` / ``dtype``.

        Args:
            device: Device the VAE and UNet are moved to and inference runs on.
            dtype: Dtype the VAE and UNet are cast to.
            concat_dim: Tensor dimension along which the masked latent and the
                image latent are concatenated (and later split again).
        """
        self.concat_dim = concat_dim
        self.device = device
        self.dtype = dtype

        # Scheduler and UNet both come from the same inpainting repository.
        inpainting_repo = "stable-diffusion-v1-5/stable-diffusion-inpainting"

        self.noise_scheduler = DDIMScheduler.from_pretrained(
            pretrained_model_name_or_path=inpainting_repo,
            subfolder="scheduler",
        )

        vae = AutoencoderKL.from_pretrained(
            pretrained_model_name_or_path="stabilityai/sd-vae-ft-mse",
        )
        self.vae = vae.to(device, dtype=dtype)

        unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_name_or_path=inpainting_repo,
            subfolder="unet",
            variant="fp16",
        )
        self.unet = unet.to(device, dtype=dtype)
        self.unet.set_attn_processor(skip_cross_attentions(self.unet))

        # Load the fine-tuned weights into the attention modules only.
        target_modules = fine_tuned_modules(unet=self.unet)
        checkpoint_path = hf_hub_download(
            repo_id="ixarchakos/tryOffAnyone",
            filename="model.safetensors",
        )
        load_checkpoint_in_model(model=target_modules, checkpoint=checkpoint_path)

    @torch.no_grad()
    def __call__(
        self,
        image: torch.Tensor,
        mask: torch.Tensor,
        inference_steps: int,
        scale: float,
        generator: torch.Generator,
    ) -> list[Image.Image]:
        """Run the denoising loop for one image/mask pair.

        Args:
            image: Input image tensor without a batch axis (one is added here).
                NOTE(review): presumably channel-first and normalized to
                [-1, 1], mirroring the ``/ 2 + 0.5`` applied to the decoded
                output — confirm against the caller.
            mask: Mask tensor without a batch axis; it is binarized at 0.5 and
                the region where it is set is zeroed out of ``image``.
            inference_steps: Number of scheduler timesteps.
            scale: Guidance scale; the two-branch guidance path is used only
                when it is greater than 1.0.
            generator: Random generator used for the initial noise and passed
                to every scheduler step.

        Returns:
            One ``PIL.Image`` per batch element of the decoded result.
        """
        axis = self.concat_dim

        image = image.unsqueeze(0).to(self.device, dtype=self.dtype)
        mask = (mask.unsqueeze(0) > 0.5).to(self.device, dtype=self.dtype)

        # Encode the masked image first, then the full image (order fixes RNG use).
        masked_latent = encode(image * (mask < 0.5), self.vae)
        image_latent = encode(image, self.vae)

        # Bring the mask down to the latent resolution.
        mask = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest")

        # Conditioning: masked latent next to the image latent, and the mask
        # next to an all-zero mask for the image half.
        cond_latents = torch.cat([masked_latent, image_latent], dim=axis)
        cond_mask = torch.cat([mask, torch.zeros_like(mask)], dim=axis)

        latents = randn_tensor(
            shape=cond_latents.shape,
            generator=generator,
            device=self.device,
            dtype=self.dtype,
        )

        self.noise_scheduler.set_timesteps(inference_steps, device=self.device)

        use_guidance = scale > 1.0
        if use_guidance:
            # First batch entry replaces the image latent with zeros; the
            # second keeps the real conditioning.
            zeroed_latents = torch.cat([masked_latent, torch.zeros_like(image_latent)], dim=axis)
            cond_latents = torch.cat([zeroed_latents, cond_latents])
            cond_mask = torch.cat([cond_mask, cond_mask])

        step_kwargs = {"generator": generator, "eta": 1.0}
        for t in self.noise_scheduler.timesteps:
            model_input = torch.cat([latents, latents]) if use_guidance else latents
            model_input = self.noise_scheduler.scale_model_input(model_input, t)
            # Channel-wise stack: noisy latents, mask, conditioning latents.
            model_input = torch.cat([model_input, cond_mask, cond_latents], dim=1)

            noise_pred = self.unet(
                model_input,
                t.to(self.device),
                encoder_hidden_states=None,
                return_dict=False,
            )[0]

            if use_guidance:
                pred_zeroed, pred_cond = noise_pred.chunk(2)
                noise_pred = pred_zeroed + scale * (pred_cond - pred_zeroed)

            latents = self.noise_scheduler.step(noise_pred, t, latents, **step_kwargs).prev_sample

        # Keep only the first half along the concatenation axis.
        half = latents.shape[axis] // 2
        latents = latents.split(half, dim=axis)[0]
        latents = 1 / self.vae.config.scaling_factor * latents

        decoded = self.vae.decode(latents.to(self.device, dtype=self.dtype)).sample
        decoded = (decoded / 2 + 0.5).clamp(0, 1)

        # Channels-last uint8 arrays for PIL.
        pixels = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
        pixels = (pixels * 255).round().astype("uint8")
        return [Image.fromarray(frame) for frame in pixels]