# KV-Edit / flux / sampling.py
import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from .model import Flux, Flux_kv
from .modules.conditioner import HFEmbedder
from tqdm.contrib import tzip
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=device,
dtype=dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
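
# Usage sketch (values are illustrative, not from the repo): a 1024x1024
# request yields a (1, 16, 128, 128) latent, since 2 * ceil(1024 / 16) = 128.
#   x = get_noise(1, 1024, 1024, device=torch.device("cuda"),
#                 dtype=torch.bfloat16, seed=42)
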
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
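    # Patchify the latent into 2x2 blocks ("b (h w) (c ph pw)") and build the
    # 3-channel positional ids (index 0 unused, 1 = row, 2 = column) used for
    # the model's positional embeddings; text ids are all zeros.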
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}
def prepare_flowedit(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, source_prompt: str | list[str], target_prompt: str | list[str]) -> dict[str, Tensor]:
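    # Same image packing as prepare(), but encodes both the source and the
    # target prompts so the two FlowEdit branches can share one image latent.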
bs, c, h, w = img.shape
if bs == 1 and not isinstance(source_prompt, str):
bs = len(source_prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(source_prompt, str):
source_prompt = [source_prompt]
source_txt = t5(source_prompt)
if source_txt.shape[0] == 1 and bs > 1:
source_txt = repeat(source_txt, "1 ... -> bs ...", bs=bs)
source_txt_ids = torch.zeros(bs, source_txt.shape[1], 3)
    source_vec = clip(source_prompt)
if source_vec.shape[0] == 1 and bs > 1:
source_vec = repeat(source_vec, "1 ... -> bs ...", bs=bs)
if isinstance(target_prompt, str):
target_prompt = [target_prompt]
target_txt = t5(target_prompt)
if target_txt.shape[0] == 1 and bs > 1:
target_txt = repeat(target_txt, "1 ... -> bs ...", bs=bs)
target_txt_ids = torch.zeros(bs, target_txt.shape[1], 3)
target_vec = clip(target_prompt)
if target_vec.shape[0] == 1 and bs > 1:
target_vec = repeat(target_vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"source_txt": source_txt.to(img.device),
"source_txt_ids": source_txt_ids.to(img.device),
"source_vec": source_vec.to(img.device),
"target_txt": target_txt.to(img.device),
"target_txt_ids": target_txt_ids.to(img.device),
"target_vec": target_vec.to(img.device)
}
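
# time_shift warps a uniform schedule t in (0, 1] via
#   t' = exp(mu) / (exp(mu) + (1 / t - 1) ** sigma),
# and get_lin_function linearly interpolates mu between (x1, y1) and (x2, y2)
# as a function of the packed image sequence length.
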
def time_shift(mu: float, sigma: float, t: Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
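
# Sketch (illustrative numbers): a 1024x1024 image packs to
# (1024 // 16) * (1024 // 16) = 4096 tokens, so mu = max_shift and the
# schedule is shifted furthest toward high noise. The list has
# num_steps + 1 entries:
#   sigmas = get_schedule(num_steps=28, image_seq_len=4096)
#   assert len(sigmas) == 29 and sigmas[0] == 1.0 and sigmas[-1] == 0.0
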
def denoise(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
guidance: float = 4.0,
):
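    # Plain Euler integration of the rectified-flow ODE:
    # img <- img + (t_prev - t_curr) * v_theta(img, t_curr), with t stepping
    # from 1 (pure noise) toward 0 (clean latent).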
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
img = img + (t_prev - t_curr) * pred
return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
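
# unpack() inverts the 2x2 patchify done in prepare(), restoring the
# (b, 16, 2 * ceil(h / 16), 2 * ceil(w / 16)) layout produced by get_noise().
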
def denoise_kv(
model: Flux_kv,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
    inverse: bool,
    info: dict,
guidance: float = 4.0
):
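    # inverse=True: walk the schedule from low to high noise, caching each
    # intermediate latent in info['feature'] under a timestep-derived key.
    # inverse=False: denoise only the masked tokens, re-injecting the cached
    # background latents outside the mask at every step.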
if inverse:
timesteps = timesteps[::-1]
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in tzip(timesteps[:-1], timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        info['t'] = t_prev if inverse else t_curr
        img_name = str(info['t']) + '_' + 'img'
        if inverse:
            # inversion: cache the current latent so the editing pass can reuse it
            info['feature'][img_name] = img.cpu()
        else:
            # editing: keep the cached background outside the mask, denoise inside it
            source_img = info['feature'][img_name].to(img.device)
            mask = info['mask'][:, info['mask_indices'], ...]
            img = source_img[:, info['mask_indices'], ...] * (1 - mask) + img * mask
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
info=info
)
img = img + (t_prev - t_curr) * pred
return img, info
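
# Hypothetical two-phase driver for denoise_kv (keys mirror this file; the
# real driver lives elsewhere in the repo, so names here are assumptions):
#   info = {'feature': {}, 'mask': mask, 'mask_indices': mask_indices}
#   z_t, info = denoise_kv(model, img=z0, ..., timesteps=sigmas,
#                          inverse=True, info=info)
#   # the editing pass operates on the masked tokens and reuses the cache
#   z_edit, info = denoise_kv(model, img=z_t[:, mask_indices, ...], ...,
#                             timesteps=sigmas, inverse=False, info=info)
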
def denoise_kv_inf(
model: Flux_kv,
# model input
img: Tensor,
img_ids: Tensor,
source_txt: Tensor,
source_txt_ids: Tensor,
source_vec: Tensor,
target_txt: Tensor,
target_txt_ids: Tensor,
target_vec: Tensor,
# sampling parameters
timesteps: list[float],
target_guidance: float = 4.0,
source_guidance: float = 4.0,
    info: dict | None = None,
):
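    # Inversion-free (FlowEdit-style) editing: at every step the clean latent
    # is re-noised to t_curr, a source and a target velocity are predicted,
    # and only their difference, restricted to the mask, is integrated into
    # z_fe. info must carry 'mask' and 'mask_indices'.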
    if info is None:
        info = {}
    target_guidance_vec = torch.full((img.shape[0],), target_guidance, device=img.device, dtype=img.dtype)
    source_guidance_vec = torch.full((img.shape[0],), source_guidance, device=img.device, dtype=img.dtype)
    mask_indices = info['mask_indices']
    init_img = img.clone()  # e.g. torch.Size([1, 4080, 64]): (batch, packed tokens, channels)
    z_fe = img[:, mask_indices, ...]  # running estimate of the edited (masked) tokens
    noise_list = []
    for i in range(len(timesteps)):
        # Draw the noise used to re-noise the clean latent at each step.
        # Note: re-seeding with 0 every iteration makes all entries identical.
        noise = torch.randn(init_img.size(), dtype=init_img.dtype,
                            layout=init_img.layout, device=init_img.device,
                            generator=torch.Generator(device=init_img.device).manual_seed(0))
        noise_list.append(noise)
    for i, (t_curr, t_prev) in enumerate(tzip(timesteps[:-1], timesteps[1:])):  # timesteps run from high to low noise
info['t'] = 'inf'
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        z_src = (1 - t_curr) * init_img + t_curr * noise_list[i]  # re-noise the clean latent to level t_curr
        z_tar = z_src[:, mask_indices, ...] - init_img[:, mask_indices, ...] + z_fe
        info['inverse'] = True
        info['feature'] = {}  # clear the cached KV features before the source pass
v_src = model(
img=z_src,
img_ids=img_ids,
txt=source_txt,
txt_ids=source_txt_ids,
y=source_vec,
timesteps=t_vec,
guidance=source_guidance_vec,
info=info
)
info['inverse'] = False
v_tar = model(
img=z_tar,
img_ids=img_ids,
txt=target_txt,
txt_ids=target_txt_ids,
y=target_vec,
timesteps=t_vec,
guidance=target_guidance_vec,
info=info
)
        v_fe = v_tar - v_src[:, mask_indices, ...]  # velocity difference drives only the edited tokens
        z_fe = z_fe + (t_prev - t_curr) * v_fe * info['mask'][:, mask_indices, ...]
return z_fe, info
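
# Note: callers presumably scatter z_fe back into the source latent at
# mask_indices and unpack() the result; that reconstruction happens outside
# this function (an assumption, not shown here).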