# Safetensors
# FLUX.1-dev-fp8-flumina / sampling.py
# Last commit by aredden
# Commit c2ecfb5: "Add img2img, more options, gradio interface"
# Page links: raw | history | blame; file size 4.25 kB
import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from modules.flux_model import Flux
from modules.conditioner import HFEmbedder
@torch.inference_mode()
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    """Draw seeded standard-normal latent noise for a batch of images.

    Each spatial side of the latent is ``2 * ceil(side / 16)``, so it is always
    even and can be split into 2x2 patches later. The channel count is fixed
    at 16.

    Args:
        num_samples: Batch size of the returned tensor.
        height: Target image height in pixels.
        width: Target image width in pixels.
        device: Device for both the random generator and the output tensor.
        dtype: Dtype of the output tensor.
        seed: Seed for a fresh ``torch.Generator``; equal seeds reproduce the
            same draw on the same device.

    Returns:
        Tensor of shape
        ``(num_samples, 16, 2 * ceil(height / 16), 2 * ceil(width / 16))``.
    """
    # Even latent sides so the 2x2 patch packing divides cleanly.
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)
    rng = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(
        num_samples,
        16,
        latent_h,
        latent_w,
        device=device,
        dtype=dtype,
        generator=rng,
    )
@torch.inference_mode()
def prepare(
    t5: HFEmbedder,
    clip: HFEmbedder,
    img: Tensor,
    prompt: str | list[str],
    target_device: torch.device = torch.device("cuda:0"),
    target_dtype: torch.dtype = torch.float16,
) -> dict[str, Tensor]:
    """Pack a latent batch into patch tokens and build the text conditioning.

    The 4-D latent ``img`` ``(batch, c, h, w)`` is split into non-overlapping
    2x2 spatial patches. Each patch is flattened into one token of ``c * 4``
    features, giving ``(batch, (h // 2) * (w // 2), c * 4)``. If ``img`` has
    batch size 1 and ``prompt`` is a list, the batch size becomes
    ``len(prompt)``. Any of the packed latent / encoder outputs that still has
    batch size 1 is then tiled up to that size.

    Args:
        t5: Encoder called on the prompt list; its output becomes ``txt``.
        clip: Encoder called on the prompt list; its output becomes ``vec``.
        img: Latent tensor ``(batch, c, h, w)``; ``h`` and ``w`` must be even
            for the 2x2 split.
        prompt: A single prompt or a list of prompts.
        target_device: Device for the id tensors and the encoder outputs.
        target_dtype: Dtype for the id tensors and the encoder outputs.

    Returns:
        Dict with keys ``img`` (packed tokens), ``img_ids`` (one ``[0, row,
        col]`` id per token on the patch grid), ``txt``, ``txt_ids`` (zeros of
        shape ``(batch, txt.shape[1], 3)``), and ``vec``.

    Note:
        ``img`` itself is NOT moved to ``target_device`` / ``target_dtype``.
    """
    batch, _, lat_h, lat_w = img.shape
    # One latent with several prompts: produce one sample per prompt.
    if batch == 1 and not isinstance(prompt, str):
        batch = len(prompt)

    def _tile(t: Tensor) -> Tensor:
        # Repeat a batch-of-one tensor along dim 0 up to ``batch`` samples.
        if t.shape[0] == 1 and batch > 1:
            return repeat(t, "1 ... -> bs ...", bs=batch)
        return t

    packed = _tile(
        rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    )

    # Position ids on the (h/2, w/2) patch grid: channel 0 stays zero,
    # channel 1 holds the patch row and channel 2 the patch column.
    grid_h, grid_w = lat_h // 2, lat_w // 2
    img_ids = torch.zeros(
        grid_h, grid_w, 3, device=target_device, dtype=target_dtype
    )
    img_ids[..., 1] = torch.arange(
        grid_h, device=target_device, dtype=target_dtype
    )[:, None]
    img_ids[..., 2] = torch.arange(
        grid_w, device=target_device, dtype=target_dtype
    )[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch)

    prompts = [prompt] if isinstance(prompt, str) else prompt
    txt = _tile(t5(prompts).to(target_device, dtype=target_dtype))
    txt_ids = torch.zeros(
        batch, txt.shape[1], 3, device=target_device, dtype=target_dtype
    )
    vec = _tile(clip(prompts).to(target_device, dtype=target_dtype))

    return {
        "img": packed,
        "img_ids": img_ids,
        "txt": txt,
        "txt_ids": txt_ids,
        "vec": vec,
    }
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps ``t`` elementwise with shift ``mu`` and exponent ``sigma``.

    Computes ``e**mu / (e**mu + (1 / t - 1) ** sigma)``. For a floating tensor
    and positive ``sigma``, ``t = 1`` maps to 1 and ``t = 0`` maps to 0, because
    ``1 / 0`` evaluates to ``inf``. A larger ``mu`` pushes interior values
    toward 1.
    """
    scale = math.exp(mu)
    odds = (1 / t - 1) ** sigma
    return scale / (scale + odds)
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Return the straight line through ``(x1, y1)`` and ``(x2, y2)``.

    The returned callable evaluates ``slope * x + intercept``. Inputs outside
    ``[x1, x2]`` are extrapolated, not clamped. ``x1 == x2`` raises
    ``ZeroDivisionError`` when the slope is computed.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Build the descending timestep schedule for ``num_steps`` sampling steps.

    Returns ``num_steps + 1`` values evenly spaced from 1 down to 0. The extra
    entry is the terminal 0. When ``shift`` is true, the values are warped by
    ``time_shift`` (with ``sigma = 1``). Its ``mu`` is interpolated linearly
    from ``image_seq_len``, hitting ``base_shift`` at 256 and ``max_shift`` at
    4096 (``get_lin_function``'s default anchors), and extrapolating outside
    that range. With ``max_shift > base_shift``, longer sequences get a larger
    ``mu``, which biases the schedule toward high timesteps.
    """
    # One extra point so the schedule ends exactly at 0.
    timesteps = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return timesteps.tolist()
    # Estimate mu by linear interpolation between the two anchor points.
    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, timesteps).tolist()
@torch.inference_mode()
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device = torch.device("cuda:0"),
):
    """Run the sampling loop, stepping ``img`` through ``timesteps`` in order.

    For each consecutive pair ``(t_curr, t_prev)``, the model is queried at
    ``t_curr``. ``img`` is then advanced with an explicit Euler step
    ``img + (t_prev - t_curr) * pred``. A ``tqdm`` progress bar tracks the
    ``len(timesteps) - 1`` steps.

    Args:
        model: Callable invoked with keyword args ``img``, ``img_ids``,
            ``txt``, ``txt_ids``, ``y`` (receives ``vec``), ``timesteps`` and
            ``guidance``.
        img: Starting state; it is the only input updated between steps.
        img_ids, txt, txt_ids, vec: Passed to the model unchanged every step.
        timesteps: Schedule values, visited in list order.
        guidance: Scalar filled into a per-sample tensor for every call.
        dtype: Dtype of the per-sample timestep and guidance tensors.
        device: Device of the per-sample timestep and guidance tensors.

    Returns:
        ``img`` after the last step. If fewer than two timesteps are given,
        this is the input tensor itself.
    """
    from tqdm import tqdm

    # Per the original author, guidance is ignored by schnell models.
    guidance_vec = torch.full((img.shape[0],), guidance, device=device, dtype=dtype)
    step_pairs = zip(timesteps[:-1], timesteps[1:])
    for t_curr, t_prev in tqdm(step_pairs, total=len(timesteps) - 1):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=dtype, device=device)
        velocity = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        img = img + (t_prev - t_curr) * velocity
    return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Turn packed 2x2-patch tokens back into a spatial latent.

    ``x`` is expected as ``(b, grid_h * grid_w, c * 4)``, where
    ``grid_h = ceil(height / 16)`` and ``grid_w = ceil(width / 16)``.
    The result is ``(b, c, 2 * grid_h, 2 * grid_w)``.

    Args:
        x: Packed token tensor.
        height: Target image height in pixels.
        width: Target image width in pixels.

    Returns:
        The unpacked latent tensor.
    """
    pattern = "b (h w) (c ph pw) -> b c (h ph) (w pw)"
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    return rearrange(x, pattern, h=grid_h, w=grid_w, ph=2, pw=2)