# tight-inversion
# Initial commit
# 4ebc565
import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from tqdm import tqdm
from .model import Flux
from .modules.conditioner import HFEmbedder
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    """Draw a reproducible standard-normal tensor.

    Args:
        num_samples: Size of the leading (batch) dimension.
        height: Height used to derive the third dimension.
        width: Width used to derive the fourth dimension.
        device: Device for both the generator and the returned tensor.
        dtype: Dtype of the returned tensor.
        seed: Seed given to the random generator.

    Returns:
        Tensor of shape
        ``(num_samples, 16, 2 * ceil(height / 16), 2 * ceil(width / 16))``.
    """
    rng = torch.Generator(device=device).manual_seed(seed)
    # allow for packing: both spatial sizes are forced to be even
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)
    return torch.randn(
        num_samples,
        16,
        latent_h,
        latent_w,
        device=device,
        dtype=dtype,
        generator=rng,
    )
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str) -> dict[str, Tensor]:
    """Pack an image tensor and embed the prompt(s) into model inputs.

    Args:
        t5: Callable invoked with the list of prompts; its result is used as
            ``txt`` (assumed to be batch-first with the sequence length on
            dim 1 -- TODO confirm against ``HFEmbedder``).
        clip: Callable invoked with the list of prompts; its result is used
            as ``vec``.
        img: 4-D tensor ``(b, c, H, W)`` with ``H`` and ``W`` divisible by 2.
        prompt: A single string, or (despite the annotation) a sequence of
            strings. With a batch-of-one ``img`` and several prompts, the
            batch size becomes the number of prompts.

    Returns:
        Dict with keys ``img`` (2x2-patch-packed tokens), ``img_ids``,
        ``txt``, ``txt_ids`` and ``vec``; everything except ``img`` is moved
        to ``img.device``. Both id tensors end up as ``torch.bfloat16``.
    """
    bs, _, h, w = img.shape
    # one image with several prompts: the prompt count sets the batch size
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    def _fan_out(t: Tensor) -> Tensor:
        # replicate a batch-of-one tensor across the effective batch
        if t.shape[0] == 1 and bs > 1:
            return repeat(t, "1 ... -> bs ...", bs=bs)
        return t

    img = _fan_out(rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2))

    # per-token ids: channel 0 stays zero, channel 1 is the row, channel 2 the column
    rows, cols = h // 2, w // 2
    img_ids = torch.zeros(rows, cols, 3)
    img_ids[..., 1] = torch.arange(rows)[:, None]
    img_ids[..., 2] = torch.arange(cols)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
    # cast through img.dtype first, then to the hard-coded bfloat16
    img_ids = img_ids.to(img.dtype).to(torch.bfloat16)

    prompts = [prompt] if isinstance(prompt, str) else prompt

    txt = _fan_out(t5(prompts))
    txt_ids = torch.zeros(bs, txt.shape[1], 3).to(txt.dtype).to(torch.bfloat16)

    vec = _fan_out(clip(prompts))

    target = img.device
    return {
        "img": img,
        "img_ids": img_ids.to(target),
        "txt": txt.to(target),
        "txt_ids": txt_ids.to(target),
        "vec": vec.to(target),
    }
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps with ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.

    Args:
        mu: Shift parameter; enters the formula as ``exp(mu)``.
        sigma: Exponent applied to ``1 / t - 1``.
        t: Timestep(s) to warp; the arithmetic is elementwise for tensors.

    Returns:
        The warped timestep(s), same container kind as ``t``.
    """
    exp_mu = math.exp(mu)
    inverse_odds = (1 / t - 1) ** sigma
    return exp_mu / (exp_mu + inverse_odds)
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Build the straight line passing through ``(x1, y1)`` and ``(x2, y2)``.

    Args:
        x1: Abscissa of the first point.
        y1: Ordinate of the first point.
        x2: Abscissa of the second point (must differ from ``x1``, otherwise
            the slope computation divides by zero).
        y2: Ordinate of the second point.

    Returns:
        A function mapping ``x`` to ``slope * x + intercept``.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x):
        return slope * x + intercept

    return line
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Return ``num_steps + 1`` timesteps running from 1 down to 0.

    Args:
        num_steps: Number of integration steps; the list has one more entry.
        image_seq_len: Sequence length fed to the linear ``mu`` estimate.
        base_shift: ``y1`` of the linear ``mu`` estimate.
        max_shift: ``y2`` of the linear ``mu`` estimate.
        shift: When true, warp the uniform grid with ``time_shift``.

    Returns:
        The schedule as a plain list of floats.
    """
    # extra step for zero
    uniform = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return uniform.tolist()

    # shifting the schedule to favor high timesteps for higher signal images:
    # estimate mu based on linear interpolation between two points
    estimate_mu = get_lin_function(y1=base_shift, y2=max_shift)
    shifted = time_shift(estimate_mu(image_seq_len), 1.0, uniform)
    return shifted.tolist()
def rf_inversion(
    model: Flux,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    timesteps: list[float],
    guidance: float = 4.0,
    id_weight=1.0,
    id=None,
    start_step=0,
    uncond_id=None,
    true_cfg=1.0,
    timestep_to_start_cfg=1,
    neg_txt=None,
    neg_txt_ids=None,
    neg_vec=None,
    aggressive_offload=False,
    y_1: Tensor = None,
    gamma: float = 0.5,
):
    """Run the controlled inversion loop over the reversed schedule.

    At step ``i`` the model is queried at time ``i / len(timesteps)``, its
    prediction is optionally combined with a negative-prompt prediction, and
    the result is blended with ``(y_1 - img) / (1 - t_i)`` using weight
    ``gamma`` before an Euler update.

    Args:
        model: Callable invoked with keyword arguments ``img``, ``img_ids``,
            ``txt``, ``txt_ids``, ``y``, ``timesteps``, ``guidance``, ``id``,
            ``id_weight`` and ``aggressive_offload``.
        img: Starting state; updated every step and returned.
        img_ids, txt, txt_ids, vec: Forwarded to the model unchanged.
        timesteps: Schedule; it is traversed in reverse order.
        guidance: Value broadcast into the per-sample guidance vector.
        id_weight: Forwarded to the model.
        id: Passed to the model only while
            ``len(timesteps) - 1 - i >= start_step``; otherwise ``None``.
        start_step: Threshold used in the ``id`` gating above.
        uncond_id: Counterpart of ``id`` for the negative-prompt call.
        true_cfg: Guidance scale; the negative-prompt call only happens when
            it differs from 1.0 by more than 1e-2.
        timestep_to_start_cfg: First step index at which the negative-prompt
            call is made.
        neg_txt, neg_txt_ids, neg_vec: Inputs for the negative-prompt call.
        aggressive_offload: Forwarded to the model.
        y_1: Target that the update is steered toward (presumably a noise
            sample -- confirm with the caller). Used unconditionally inside
            the loop, so it must be supplied despite the ``None`` default.
        gamma: Blend weight between the model prediction (0) and the
            conditional field toward ``y_1`` (1).

    Returns:
        The state after the final step.
    """
    # reverse the timesteps
    schedule = timesteps[::-1]
    num_points = len(schedule)
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    use_true_cfg = abs(true_cfg - 1.0) > 1e-2

    def predict(latents, t_vec, text, text_ids, pooled, identity):
        # single place where the model is queried for either prompt branch
        return model(
            img=latents,
            img_ids=img_ids,
            txt=text,
            txt_ids=text_ids,
            y=pooled,
            timesteps=t_vec,
            guidance=guidance_vec,
            id=identity,
            id_weight=id_weight,
            aggressive_offload=aggressive_offload,
        )

    for i in tqdm(range(num_points - 1), desc="Inverting"):
        # the model time is the uniform fraction i / N, not the schedule value
        t_i = i / num_points
        t_vec = torch.full((img.shape[0],), t_i, dtype=img.dtype, device=img.device)
        id_enabled = (num_points - 1 - i) >= start_step

        pred = predict(img, t_vec, txt, txt_ids, vec, id if id_enabled else None)
        if use_true_cfg and i >= timestep_to_start_cfg:
            neg_pred = predict(
                img, t_vec, neg_txt, neg_txt_ids, neg_vec, uncond_id if id_enabled else None
            )
            pred = neg_pred + true_cfg * (pred - neg_pred)

        remaining = 1 - t_i
        assert remaining != 0
        # conditional field pointing from the current state toward y_1
        u_t_i_cond = (y_1 - img) / remaining
        pred = pred + gamma * (u_t_i_cond - pred)

        img = img + (schedule[i + 1] - schedule[i]) * pred
    return img
def rf_denoise(
    model: Flux,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    timesteps: list[float],
    guidance: float = 4.0,
    id_weight=1.0,
    id=None,
    start_step=0,
    uncond_id=None,
    true_cfg=1.0,
    timestep_to_start_cfg=1,
    neg_txt=None,
    neg_txt_ids=None,
    neg_vec=None,
    aggressive_offload=False,
    y_0: Tensor = None,
    eta=0.9,
    s=0,
    tau=6,
):
    """Run the controlled denoising loop.

    At step ``i`` the model is queried at time ``1 - i / len(timesteps)``;
    its (optionally CFG-combined) prediction is negated and, for step indices
    in ``[s, tau)``, blended with ``(y_0 - img) / (1 - t_i)`` using weight
    ``eta`` before an Euler update.

    Args:
        model: Callable invoked with keyword arguments ``img``, ``img_ids``,
            ``txt``, ``txt_ids``, ``y``, ``timesteps``, ``guidance``, ``id``,
            ``id_weight`` and ``aggressive_offload``.
        img: Starting state; updated every step and returned.
        img_ids, txt, txt_ids, vec: Forwarded to the model unchanged.
        timesteps: Schedule supplying the step sizes
            ``timesteps[i] - timesteps[i + 1]``.
        guidance: Value broadcast into the per-sample guidance vector.
        id_weight: Forwarded to the model.
        id: Passed to the model once ``i >= start_step``; ``None`` before.
        start_step: First step index at which ``id`` / ``uncond_id`` are used.
        uncond_id: Counterpart of ``id`` for the negative-prompt call.
        true_cfg: Guidance scale; the negative-prompt call only happens when
            it differs from 1.0 by more than 1e-2.
        timestep_to_start_cfg: First step index at which the negative-prompt
            call is made.
        neg_txt, neg_txt_ids, neg_vec: Inputs for the negative-prompt call.
        aggressive_offload: Forwarded to the model.
        y_0: Target that the update is steered toward (presumably the
            reference image latents -- confirm with the caller). Used
            unconditionally inside the loop, so it must be supplied despite
            the ``None`` default.
        eta: Blend weight applied while ``s <= i < tau``.
        s: First step index (inclusive) at which ``eta`` is applied.
        tau: Step index (exclusive) at which ``eta`` stops being applied.

    Returns:
        The state after the final step.
    """
    num_points = len(timesteps)
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    use_true_cfg = abs(true_cfg - 1.0) > 1e-2

    def predict(latents, t_vec, text, text_ids, pooled, identity):
        # single place where the model is queried for either prompt branch
        return model(
            img=latents,
            img_ids=img_ids,
            txt=text,
            txt_ids=text_ids,
            y=pooled,
            timesteps=t_vec,
            guidance=guidance_vec,
            id=identity,
            id_weight=id_weight,
            aggressive_offload=aggressive_offload,
        )

    for i in tqdm(range(num_points - 1), desc="Denoising"):
        t_i = i / num_points
        remaining = 1 - t_i
        # the model sees the complementary time 1 - t_i
        t_vec = torch.full((img.shape[0],), remaining, dtype=img.dtype, device=img.device)
        id_enabled = i >= start_step

        pred = predict(img, t_vec, txt, txt_ids, vec, id if id_enabled else None)
        if use_true_cfg and i >= timestep_to_start_cfg:
            neg_pred = predict(
                img, t_vec, neg_txt, neg_txt_ids, neg_vec, uncond_id if id_enabled else None
            )
            pred = neg_pred + true_cfg * (pred - neg_pred)

        # flip the sign of the model output before blending
        pred = -pred

        assert remaining != 0
        # conditional field pointing from the current state toward y_0
        v_t_cond = (y_0 - img) / remaining
        if s <= i < tau:
            eta_t = eta
        else:
            eta_t = 0
        pred = pred + eta_t * (v_t_cond - pred)

        img = img + (timesteps[i] - timesteps[i + 1]) * pred
    return img
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    timesteps: list[float],
    guidance: float = 4.0,
    id_weight=1.0,
    id=None,
    start_step=0,
    uncond_id=None,
    true_cfg=1.0,
    timestep_to_start_cfg=1,
    neg_txt=None,
    neg_txt_ids=None,
    neg_vec=None,
    aggressive_offload=False,
):
    """Integrate the model prediction along ``timesteps`` with Euler steps.

    For each consecutive pair ``(t_curr, t_prev)`` the model is queried at
    ``t_curr`` and the state advances by ``(t_prev - t_curr) * pred``.

    Args:
        model: Callable invoked with keyword arguments ``img``, ``img_ids``,
            ``txt``, ``txt_ids``, ``y``, ``timesteps``, ``guidance``, ``id``,
            ``id_weight`` and ``aggressive_offload``.
        img: Starting state; updated every step and returned.
        img_ids, txt, txt_ids, vec: Forwarded to the model unchanged.
        timesteps: Schedule; consecutive entries define each step.
        guidance: Value broadcast into the per-sample guidance vector.
        id_weight: Forwarded to the model.
        id: Passed to the model once ``i >= start_step``; ``None`` before.
        start_step: First step index at which ``id`` / ``uncond_id`` are used.
        uncond_id: Counterpart of ``id`` for the negative-prompt call.
        true_cfg: Guidance scale; the negative-prompt call only happens when
            it differs from 1.0 by more than 1e-2.
        timestep_to_start_cfg: First step index at which the negative-prompt
            call is made.
        neg_txt, neg_txt_ids, neg_vec: Inputs for the negative-prompt call.
        aggressive_offload: Forwarded to the model.

    Returns:
        The state after the final step.
    """
    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    use_true_cfg = abs(true_cfg - 1.0) > 1e-2

    def predict(latents, t_vec, text, text_ids, pooled, identity):
        # single place where the model is queried for either prompt branch
        return model(
            img=latents,
            img_ids=img_ids,
            txt=text,
            txt_ids=text_ids,
            y=pooled,
            timesteps=t_vec,
            guidance=guidance_vec,
            id=identity,
            id_weight=id_weight,
            aggressive_offload=aggressive_offload,
        )

    step_pairs = zip(timesteps[:-1], timesteps[1:])
    for i, (t_curr, t_prev) in enumerate(step_pairs):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        id_enabled = i >= start_step

        pred = predict(img, t_vec, txt, txt_ids, vec, id if id_enabled else None)
        if use_true_cfg and i >= timestep_to_start_cfg:
            neg_pred = predict(
                img, t_vec, neg_txt, neg_txt_ids, neg_vec, uncond_id if id_enabled else None
            )
            pred = neg_pred + true_cfg * (pred - neg_pred)

        step = t_prev - t_curr
        img = img + step * pred
    return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Fold a packed token sequence back into a 4-D map.

    Applies the mirror of the 2x2 patch pattern used by ``prepare``:
    ``(b, h * w, c * 2 * 2)`` becomes ``(b, c, 2 * h, 2 * w)`` with
    ``h = ceil(height / 16)`` and ``w = ceil(width / 16)``.

    Args:
        x: Packed tensor whose second dimension has ``h * w`` entries.
        height: Height used to derive the patch-grid row count.
        width: Width used to derive the patch-grid column count.

    Returns:
        The unpacked tensor.
    """
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    pattern = "b (h w) (c ph pw) -> b c (h ph) (w pw)"
    return rearrange(x, pattern, h=grid_h, w=grid_w, ph=2, pw=2)