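"""Generation helpers for Kohaku-XL-Epsilon (SDXL) via the k-diffusion pipeline.

Sets up a polyexponential sigma schedule, a DPM++ 2M SDE sampler, and a
long-prompt encoder that chunks prompts past the CLIP token limit.
"""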
from functools import partial
import torch
from diffusers import StableDiffusionXLKDiffusionPipeline
from k_diffusion.sampling import get_sigmas_polyexponential
from k_diffusion.sampling import sample_dpmpp_2m_sde
torch.set_float32_matmul_precision("medium")
def set_timesteps_polyexponential(self, orig_sigmas, num_inference_steps, device=None):
    """Replace the scheduler's sigma schedule with a polyexponential one."""
    self.num_inference_steps = num_inference_steps
    # orig_sigmas[-1] is 0, so the smallest real sigma sits at index -2.
    self.sigmas = get_sigmas_polyexponential(
        num_inference_steps + 1,
        sigma_min=orig_sigmas[-2],
        sigma_max=orig_sigmas[0],
        rho=0.666666,
        device=device or "cpu",
    )
    # Drop the trailing sigma_min and appended zero, then terminate at exactly 0.
    self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas.new_zeros([1])])
def model_forward(k_diffusion_model: torch.nn.Module):
    """Wrap the denoiser's forward pass in fp16 autocast, returning fp32."""
    orig_forward = k_diffusion_model.forward

    def forward(*args, **kwargs):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            result = orig_forward(*args, **kwargs)
        return result.float()

    return forward
def load_model(model_id="KBlueLeaf/Kohaku-XL-Epsilon", device="cuda"):
    """Load the SDXL k-diffusion pipeline with a custom schedule and sampler."""
    pipe: StableDiffusionXLKDiffusionPipeline
    pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to(device)
    # Route set_timesteps through the polyexponential schedule defined above.
    pipe.scheduler.set_timesteps = partial(
        set_timesteps_polyexponential, pipe.scheduler, pipe.scheduler.sigmas
    )
    # Sample with DPM++ 2M SDE (Heun variant) and a small amount of stochasticity.
    pipe.sampler = partial(sample_dpmpp_2m_sde, eta=0.35, solver_type="heun")
    pipe.k_diffusion_model.forward = model_forward(pipe.k_diffusion_model)
    return pipe
def encode_prompts(pipe: StableDiffusionXLKDiffusionPipeline, prompt, neg_prompt):
    """Encode prompts of arbitrary length by chunking them through both CLIP encoders."""
    max_length = pipe.tokenizer.model_max_length

    input_ids = pipe.tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    input_ids2 = pipe.tokenizer_2(prompt, return_tensors="pt").input_ids.to("cuda")

    # Pad the negative prompt to at least the positive prompt's length.
    negative_ids = pipe.tokenizer(
        neg_prompt,
        truncation=False,
        padding="max_length",
        max_length=input_ids.shape[-1],
        return_tensors="pt",
    ).input_ids.to("cuda")
    negative_ids2 = pipe.tokenizer_2(
        neg_prompt,
        truncation=False,
        padding="max_length",
        max_length=input_ids.shape[-1],
        return_tensors="pt",
    ).input_ids.to("cuda")

    # If the negative prompt is longer, re-pad the positive prompt to match,
    # so both sides split into the same number of windows.
    if negative_ids.shape[-1] > input_ids.shape[-1]:
        input_ids = pipe.tokenizer(
            prompt,
            truncation=False,
            padding="max_length",
            max_length=negative_ids.shape[-1],
            return_tensors="pt",
        ).input_ids.to("cuda")
        input_ids2 = pipe.tokenizer_2(
            prompt,
            truncation=False,
            padding="max_length",
            max_length=negative_ids.shape[-1],
            return_tensors="pt",
        ).input_ids.to("cuda")

    # Encode in windows of model_max_length tokens with the first text encoder.
    concat_embeds = []
    neg_embeds = []
    for i in range(0, input_ids.shape[-1], max_length):
        concat_embeds.append(pipe.text_encoder(input_ids[:, i : i + max_length])[0])
        neg_embeds.append(pipe.text_encoder(negative_ids[:, i : i + max_length])[0])

    # Same for the second text encoder, which also supplies pooled embeddings.
    concat_embeds2 = []
    neg_embeds2 = []
    pooled_embeds2 = []
    neg_pooled_embeds2 = []
    for i in range(0, input_ids.shape[-1], max_length):
        hidden_states = pipe.text_encoder_2(
            input_ids2[:, i : i + max_length], output_hidden_states=True
        )
        concat_embeds2.append(hidden_states.hidden_states[-2])
        pooled_embeds2.append(hidden_states[0])

        hidden_states = pipe.text_encoder_2(
            negative_ids2[:, i : i + max_length], output_hidden_states=True
        )
        neg_embeds2.append(hidden_states.hidden_states[-2])
        neg_pooled_embeds2.append(hidden_states[0])

    # Concatenate chunks along the sequence axis, then join the two encoders'
    # features along the channel axis, as SDXL expects.
    prompt_embeds = torch.cat(concat_embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
    prompt_embeds2 = torch.cat(concat_embeds2, dim=1)
    negative_prompt_embeds2 = torch.cat(neg_embeds2, dim=1)

    prompt_embeds = torch.cat([prompt_embeds, prompt_embeds2], dim=-1)
    negative_prompt_embeds = torch.cat(
        [negative_prompt_embeds, negative_prompt_embeds2], dim=-1
    )
    # Average the pooled embeddings across chunks.
    pooled_embeds2 = torch.mean(torch.stack(pooled_embeds2, dim=0), dim=0)
    neg_pooled_embeds2 = torch.mean(torch.stack(neg_pooled_embeds2, dim=0), dim=0)

    return prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2
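
# A minimal usage sketch (not part of the original file): the embedding tensors
# returned above plug into the standard SDXL pipeline call. The prompt strings
# and generation settings below are illustrative placeholders.
if __name__ == "__main__":
    pipe = load_model()
    prompt_embeds, neg_prompt_embeds, pooled, neg_pooled = encode_prompts(
        pipe, "1girl, solo, masterpiece", "lowres, bad anatomy"
    )
    image = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_prompt_embeds,
        pooled_prompt_embeds=pooled,
        negative_pooled_prompt_embeds=neg_pooled,
        num_inference_steps=24,
        guidance_scale=6.0,
        width=1024,
        height=1024,
    ).images[0]
    image.save("output.png")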