TIPO-DEMO / diff.py
Kohaku-Blueleaf
Fix encode prompt impl
06f0d78
raw
history blame
5.59 kB
import math
from functools import partial, wraps

import torch
from diffusers import StableDiffusionXLKDiffusionPipeline, UNet2DConditionModel
from k_diffusion.sampling import get_sigmas_polyexponential
from k_diffusion.sampling import sample_dpmpp_2m_sde
# Runs at import time: let PyTorch use lower-precision internal math for
# float32 matmuls ("medium" favors speed over accuracy; see the
# torch.set_float32_matmul_precision docs for the exact trade-off).
torch.set_float32_matmul_precision("medium")
def set_timesteps_polyexponential(self, orig_sigmas, num_inference_steps, device=None):
    """Install a polyexponential sigma schedule on a scheduler.

    Intended to be bound with ``functools.partial`` and assigned as
    ``scheduler.set_timesteps``. It records ``num_inference_steps`` on the
    scheduler, then asks k-diffusion's ``get_sigmas_polyexponential`` for a
    schedule of ``num_inference_steps + 1`` points between ``orig_sigmas[0]``
    (max) and ``orig_sigmas[-2]`` (min) with rho ~= 2/3. The last two values
    of that schedule are dropped and a single zero is appended.

    Args:
        self: Scheduler whose ``num_inference_steps`` and ``sigmas`` are set.
        orig_sigmas: Reference sigmas. Index -2 is used as the minimum,
            presumably because index -1 is a terminal zero (TODO confirm).
        num_inference_steps: Number of sampling steps.
        device: Target device for the new sigmas; falls back to "cpu".
    """
    self.num_inference_steps = num_inference_steps
    schedule = get_sigmas_polyexponential(
        num_inference_steps + 1,
        sigma_min=orig_sigmas[-2],
        sigma_max=orig_sigmas[0],
        rho=0.666666,
        device=device if device else "cpu",
    )
    # Drop the last two generated values and terminate with one exact zero.
    terminal_zero = schedule.new_zeros([1])
    self.sigmas = torch.cat([schedule[:-2], terminal_zero])
def model_forward(
    k_diffusion_model: torch.nn.Module,
    device_type: str = "cuda",
    dtype: torch.dtype = torch.float16,
):
    """Wrap a model's ``forward`` so it runs under autocast and returns float32.

    The returned callable is meant to replace ``k_diffusion_model.forward``.
    It calls the model's original ``forward`` inside ``torch.autocast`` and
    casts the result with ``.float()``, so callers always get float32 output.

    Args:
        k_diffusion_model: Module whose current ``forward`` is wrapped.
        device_type: Device type handed to ``torch.autocast`` ("cuda", "cpu",
            ...). Autocast only affects ops on that device type, so this should
            match where the model actually runs. Defaults to "cuda".
        dtype: Autocast compute dtype. Defaults to ``torch.float16``.

    Returns:
        A callable accepting the same arguments as the original ``forward``.
    """
    orig_forward = k_diffusion_model.forward

    @wraps(orig_forward)
    def forward(*args, **kwargs):
        with torch.autocast(device_type=device_type, dtype=dtype):
            result = orig_forward(*args, **kwargs)
        # Hand full-precision results back to the sampler.
        return result.float()

    return forward
def load_model(model_id="KBlueLeaf/Kohaku-XL-Zeta", device="cuda"):
    """Load an SDXL k-diffusion pipeline and patch its sampling setup.

    The pipeline is loaded with ``torch_dtype=torch.float16`` and moved to
    ``device``. Three patches are then applied:

    * ``scheduler.set_timesteps`` becomes ``set_timesteps_polyexponential``
      bound to this scheduler and to its sigmas as they were at load time.
    * ``pipe.sampler`` becomes ``sample_dpmpp_2m_sde`` with ``eta=0.35`` and
      ``solver_type="heun"``.
    * ``k_diffusion_model.forward`` is replaced by the ``model_forward``
      wrapper (autocast + float32 output).

    Args:
        model_id: Model id or path passed to ``from_pretrained``.
        device: Device the pipeline is moved to.

    Returns:
        The patched pipeline.
    """
    pipe: StableDiffusionXLKDiffusionPipeline = (
        StableDiffusionXLKDiffusionPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16
        ).to(device)
    )
    # Capture the load-time sigmas now: the patched set_timesteps overwrites
    # scheduler.sigmas, so every later call must read the original range.
    pipe.scheduler.set_timesteps = partial(
        set_timesteps_polyexponential, pipe.scheduler, pipe.scheduler.sigmas
    )
    pipe.sampler = partial(sample_dpmpp_2m_sde, eta=0.35, solver_type="heun")
    pipe.k_diffusion_model.forward = model_forward(pipe.k_diffusion_model)
    return pipe
def _encode_in_chunks(
    tokenizer, text_encoder, prompts, target_length, chunk_length, device
):
    """Encode ``prompts`` with one text encoder, ``chunk_length`` tokens at a time.

    The prompts are padded to ``target_length``. Every window of prompt
    tokens is wrapped as ``[BOS, window, EOS]`` and encoded on its own. When
    the per-window ``hidden_states[-2]`` are stitched back together, the BOS
    position is kept only for the first window and the EOS position only for
    the last one. A single window therefore keeps both, and the result always
    has ``target_length`` positions.

    Returns:
        ``(hidden, firsts)``: ``hidden`` is the stitched penultimate hidden
        states concatenated on dim 1. ``firsts`` is the list of each window's
        ``output[0]``, which the caller treats as the pooled embedding for
        ``text_encoder_2``.
    """
    ids = tokenizer(
        prompts, padding="max_length", max_length=target_length, return_tensors="pt"
    ).input_ids
    bos = ids[:, :1]
    # The last padded column is a real EOS only when a prompt fills
    # target_length exactly; otherwise it is the tokenizer's pad token. Use the
    # tokenizer's EOS id so every window is terminated like a normal input
    # (NOTE(review): CLIP-style pooled outputs are typically read at EOS).
    eos_id = getattr(tokenizer, "eos_token_id", None)
    eos = ids[:, -1:] if eos_id is None else torch.full_like(ids[:, -1:], eos_id)
    body = ids[:, 1:-1]
    body_length = body.size(-1)

    hidden_chunks = []
    first_outputs = []
    for start in range(0, body_length, chunk_length):
        window = torch.cat(
            (bos, body[:, start : start + chunk_length], eos), dim=-1
        ).to(device)
        output = text_encoder(window, output_hidden_states=True)
        hidden = output.hidden_states[-2]
        is_first = start == 0
        is_last = start + chunk_length >= body_length
        # Keep BOS only on the first window and EOS only on the last one.
        hidden_chunks.append(hidden[:, (0 if is_first else 1) : (None if is_last else -1)])
        first_outputs.append(output[0])
    return torch.cat(hidden_chunks, dim=1), first_outputs


@torch.no_grad()
def encode_prompts(
    pipe: "StableDiffusionXLKDiffusionPipeline", prompt: str, neg_prompt: str = ""
):
    """Encode a prompt/negative-prompt pair that may exceed one tokenizer window.

    Both prompts are tokenized with ``pipe.tokenizer`` and ``pipe.tokenizer_2``
    and padded to a shared ``target_length = num_chunks * max_length + 2``.
    Here ``max_length`` is the tokenizer's ``model_max_length`` minus the BOS
    and EOS slots. Each encoder then processes the prompt tokens window by
    window, and the windows are stitched back together.

    Args:
        pipe: Pipeline providing ``tokenizer``/``tokenizer_2``,
            ``text_encoder``/``text_encoder_2`` and ``device``.
        prompt: Positive prompt.
        neg_prompt: Negative prompt (empty by default).

    Returns:
        ``(prompt_embeds, pooled_embeds2)``. Row 0 of each belongs to
        ``prompt`` and row 1 to ``neg_prompt``.

        * ``prompt_embeds`` has ``target_length`` positions and concatenates
          both encoders' penultimate hidden states on the last dim.
        * ``pooled_embeds2`` is the mean over windows of ``text_encoder_2``'s
          first output.
    """
    prompts = [prompt, neg_prompt]
    # Prompt tokens that fit in one window once BOS and EOS are reserved.
    max_length = pipe.tokenizer.model_max_length - 2

    # Longest tokenized prompt (BOS/EOS included) across both tokenizers.
    length = max(
        pipe.tokenizer(prompts, padding=True, return_tensors="pt").input_ids.size(-1),
        pipe.tokenizer_2(prompts, padding=True, return_tensors="pt").input_ids.size(-1),
    )
    # Size the windows from the prompt tokens only (length - 2). This way a
    # prompt that exactly fits a window gets no padding-only extra window.
    # Always encode at least one window, even for empty prompts.
    num_chunks = max(1, math.ceil((length - 2) / max_length))
    target_length = num_chunks * max_length + 2

    prompt_embeds, _ = _encode_in_chunks(
        pipe.tokenizer, pipe.text_encoder, prompts, target_length, max_length, pipe.device
    )
    prompt_embeds2, pooled_chunks = _encode_in_chunks(
        pipe.tokenizer_2,
        pipe.text_encoder_2,
        prompts,
        target_length,
        max_length,
        pipe.device,
    )
    prompt_embeds = torch.cat([prompt_embeds, prompt_embeds2], dim=-1)
    pooled_embeds2 = torch.mean(torch.stack(pooled_chunks, dim=0), dim=0)
    return prompt_embeds, pooled_embeds2
if __name__ == "__main__":
    # Demo script: load the patched pipeline, encode a long prompt in chunks,
    # and render a single 1024x1024 image to test.png.
    # NOTE(review): `meta` is a project-local module not visible here.
    from meta import DEFAULT_NEGATIVE_PROMPT

    # Positive prompt; .strip() removes the leading/trailing newlines.
    prompt = """
1girl,
king halo (umamusume), umamusume,
ogipote, misu kasumi, fuzichoco, ciloranko, ninjin nouka, ningen mame, ask (askzy), kita (kitairoha), amano kokoko, maccha (mochancc),
solo, leaning forward, cleavage, sky, cowboy shot, outdoors, cloud, long hair, looking at viewer, brown hair, day, horse girl, black bikini, cloudy sky, stomach, collarbone, blue sky, swimsuit, navel, thighs, blush, ocean, animal ears, standing, smile, breasts, open mouth, :d, red eyes, horse ears, tail, bare shoulders, wavy hair, bikini, medium breasts,
masterpiece, newest, absurdres, sensitive
""".strip()
    # NOTE(review): "KBlueLeaf/xxx" looks like a placeholder model id; the
    # commented-out line below loads load_model's default
    # ("KBlueLeaf/Kohaku-XL-Zeta") instead.
    sdxl_pipe = load_model("KBlueLeaf/xxx")
    # sdxl_pipe = load_model()
    # encode_prompts returns row 0 for the positive prompt and row 1 for the
    # negative prompt; they are split back apart below.
    prompt_embeds, pooled_embeds2 = encode_prompts(
        sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT
    )
    result = sdxl_pipe(
        prompt_embeds=prompt_embeds[0:1],
        negative_prompt_embeds=prompt_embeds[1:],
        pooled_prompt_embeds=pooled_embeds2[0:1],
        negative_pooled_prompt_embeds=pooled_embeds2[1:],
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    ).images[0]
    result.save("test.png")
    # NOTE(review): torch.compile returns an OptimizedModule only for
    # nn.Module inputs. If the pipeline object is not an nn.Module, this
    # branch is skipped. `module` / `original_module` are not used in the
    # visible code; confirm whether this tail is leftover experimentation.
    module = torch.compile(sdxl_pipe)
    if isinstance(module, torch._dynamo.OptimizedModule):
        original_module = module._orig_mod