File size: 4,326 Bytes
7d4afe8
 
 
 
 
 
 
dba151a
9746259
7d4afe8
 
 
 
 
 
 
 
 
 
 
 
 
 
9746259
 
0eabe5d
9746259
 
 
 
 
 
 
 
7d4afe8
 
 
 
 
 
 
 
 
9746259
7d4afe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from functools import partial

import torch
from diffusers import StableDiffusionXLKDiffusionPipeline
from k_diffusion.sampling import get_sigmas_polyexponential
from k_diffusion.sampling import sample_dpmpp_2m_sde

# Allow lower-precision internal float32 matmuls (module-wide side effect on import).
torch.set_float32_matmul_precision(precision="medium")


def set_timesteps_polyexponential(
    self, orig_sigmas, num_inference_steps, device=None, *, rho=0.666666
):
    """Replace a scheduler's sigma schedule with a polyexponential one.

    Meant to be bound onto a scheduler with ``functools.partial`` so that
    ``self`` is the scheduler and ``orig_sigmas`` is its sigma tensor captured
    at bind time. Callers then invoke it as
    ``scheduler.set_timesteps(num_inference_steps, device=...)``.

    Args:
        self: Scheduler object. Its ``num_inference_steps`` and ``sigmas``
            attributes are overwritten.
        orig_sigmas: Reference sigma schedule. ``orig_sigmas[0]`` is used as
            ``sigma_max`` and ``orig_sigmas[-2]`` as ``sigma_min``. Indexing
            ``[-2]`` presumably skips a trailing zero in the original schedule;
            TODO confirm against the scheduler in use.
        num_inference_steps: Number of sampling steps.
        device: Device for the new sigma tensor. ``None`` means ``"cpu"``.
        rho: Ramp exponent forwarded to ``get_sigmas_polyexponential``.
            Defaults to ``0.666666`` (~2/3), the value previously hard-coded.

    The schedule is generated for ``num_inference_steps + 1`` points. Its last
    two entries are then dropped and a single zero is appended. k-diffusion's
    implementation appends a trailing zero to the requested sigmas. With it,
    this removes that zero and the ``sigma_min`` point, leaving
    ``num_inference_steps + 1`` values that end in 0.
    """
    self.num_inference_steps = num_inference_steps

    sigmas = get_sigmas_polyexponential(
        num_inference_steps + 1,
        sigma_min=orig_sigmas[-2],
        sigma_max=orig_sigmas[0],
        rho=rho,
        device=device or "cpu",
    )
    # Drop the final two generated values and finish on one explicit zero.
    self.sigmas = torch.cat([sigmas[:-2], sigmas.new_zeros([1])])


def model_forward(
    k_diffusion_model: torch.nn.Module,
    device_type: str = "cuda",
    dtype: torch.dtype = torch.float16,
):
    """Build a replacement ``forward`` that runs the model under autocast.

    The returned callable forwards all arguments to the model's current
    ``forward`` inside ``torch.autocast(device_type=device_type, dtype=dtype)``.
    It then casts the result with ``.float()``.

    Args:
        k_diffusion_model: Module whose ``forward`` is wrapped. The bound
            ``forward`` is captured here, before any reassignment. Setting
            ``k_diffusion_model.forward`` to the returned function therefore
            does not make it call itself.
        device_type: Device type passed to ``torch.autocast``. Defaults to
            ``"cuda"``, the previously hard-coded value.
        dtype: Autocast compute dtype. Defaults to ``torch.float16``.

    Returns:
        A function accepting the same arguments as the original ``forward``.
        It assumes the original ``forward`` returns a tensor, since
        ``.float()`` is called on the result.
    """
    from functools import wraps

    orig_forward = k_diffusion_model.forward

    @wraps(orig_forward)
    def forward(*args, **kwargs):
        with torch.autocast(device_type=device_type, dtype=dtype):
            result = orig_forward(*args, **kwargs)
        # Cast outside the autocast region so callers always receive float32,
        # whatever the autocast compute dtype was.
        return result.float()

    return forward


def load_model(model_id="KBlueLeaf/Kohaku-XL-Epsilon", device="cuda"):
    """Load the SDXL k-diffusion pipeline and install this module's sampling tweaks.

    Steps, in order:
      1. ``from_pretrained(model_id, torch_dtype=torch.float16)``, then
         ``.to(device)``.
      2. Rebind ``scheduler.set_timesteps`` to
         ``set_timesteps_polyexponential``, bound to the scheduler and to its
         sigmas as they are at this moment.
      3. Set ``pipe.sampler`` to ``sample_dpmpp_2m_sde`` with ``eta=0.35`` and
         ``solver_type="heun"``.
      4. Replace ``k_diffusion_model.forward`` with the wrapper returned by
         ``model_forward``.

    Returns:
        The configured pipeline.
    """
    pipeline_cls = StableDiffusionXLKDiffusionPipeline
    pipe: StableDiffusionXLKDiffusionPipeline = pipeline_cls.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to(device)

    scheduler = pipe.scheduler
    # Captured once here. set_timesteps_polyexponential only reads this bound
    # tensor and assigns a fresh one to scheduler.sigmas, so repeated calls do
    # not compound.
    initial_sigmas = scheduler.sigmas
    scheduler.set_timesteps = partial(
        set_timesteps_polyexponential, scheduler, initial_sigmas
    )

    pipe.sampler = partial(sample_dpmpp_2m_sde, eta=0.35, solver_type="heun")

    k_model = pipe.k_diffusion_model
    k_model.forward = model_forward(k_model)
    return pipe


def _tokenize_ids(tokenizer, text, device, pad_to=None):
    """Tokenize *text* and return its ``input_ids`` tensor moved to *device*.

    With ``pad_to=None`` no padding/truncation options are passed. Otherwise
    the ids are padded with ``padding="max_length"`` up to ``pad_to``.
    ``truncation=False`` leaves longer sequences intact.
    """
    options = {}
    if pad_to is not None:
        options.update(truncation=False, padding="max_length", max_length=pad_to)
    return tokenizer(text, return_tensors="pt", **options).input_ids.to(device)


def _encode_windows(text_encoder, ids, window):
    """Encode consecutive *window*-token slices of *ids*; concat ``output[0]`` on dim 1."""
    pieces = [
        text_encoder(ids[:, start : start + window])[0]
        for start in range(0, ids.shape[-1], window)
    ]
    return torch.cat(pieces, dim=1)


def _encode_windows_with_pool(text_encoder_2, ids, window):
    """Encode *window*-token slices of *ids*; return (hidden, pooled).

    ``hidden`` is each slice's penultimate hidden state concatenated on dim 1.
    ``pooled`` is the mean over slices of each call's ``output[0]``.
    """
    hidden, pooled = [], []
    for start in range(0, ids.shape[-1], window):
        out = text_encoder_2(
            ids[:, start : start + window], output_hidden_states=True
        )
        hidden.append(out.hidden_states[-2])
        pooled.append(out[0])
    return torch.cat(hidden, dim=1), torch.mean(torch.stack(pooled, dim=0), dim=0)


def encode_prompts(
    pipe: StableDiffusionXLKDiffusionPipeline, prompt, neg_prompt, device="cuda"
):
    """Encode a prompt / negative-prompt pair of arbitrary length for SDXL.

    Both texts are tokenized with ``pipe.tokenizer`` and ``pipe.tokenizer_2``
    without truncation. The shorter side, measured by the first tokenizer, is
    padded to the longer side's length. The ids are then cut into consecutive
    windows of ``pipe.tokenizer.model_max_length`` tokens, and each window is
    encoded separately.

    NOTE(review): windows are raw slices of one token sequence. Presumably only
    the first starts with a BOS token and only one contains the EOS token.
    NOTE(review): the unpadded first tokenization presumably requires
    ``prompt`` to be a single string, or a batch that tokenizes to equal
    lengths. TODO confirm.
    NOTE(review): the encoders are not called under ``torch.no_grad()``. Wrap
    the call site if gradients are not needed.

    Args:
        pipe: Pipeline providing ``tokenizer``, ``tokenizer_2``,
            ``text_encoder`` and ``text_encoder_2``.
        prompt: Positive prompt text.
        neg_prompt: Negative prompt text.
        device: Device the token ids are moved to before encoding. Defaults to
            ``"cuda"``, the previously hard-coded value.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds, pooled_embeds,
        neg_pooled_embeds)``. The first two concatenate, along the last axis,
        the windowed ``text_encoder`` outputs and the windowed
        ``text_encoder_2`` penultimate hidden states. The pooled tensors are
        the mean over windows of ``text_encoder_2``'s ``output[0]``.
    """
    window = pipe.tokenizer.model_max_length

    input_ids = _tokenize_ids(pipe.tokenizer, prompt, device)
    input_ids2 = _tokenize_ids(pipe.tokenizer_2, prompt, device)

    # Pad the negative prompt up to the prompt's length. It stays longer if it
    # already is, because truncation is disabled.
    target_len = input_ids.shape[-1]
    negative_ids = _tokenize_ids(pipe.tokenizer, neg_prompt, device, pad_to=target_len)
    negative_ids2 = _tokenize_ids(
        pipe.tokenizer_2, neg_prompt, device, pad_to=target_len
    )

    # If the negative prompt is the longer one, re-tokenize the prompt padded
    # to its length so both sides split into the same windows. Compare only
    # the sequence lengths. Comparing whole torch.Size tuples is decided by the
    # batch dimension first whenever batch sizes differ.
    if negative_ids.shape[-1] > input_ids.shape[-1]:
        target_len = negative_ids.shape[-1]
        input_ids = _tokenize_ids(pipe.tokenizer, prompt, device, pad_to=target_len)
        input_ids2 = _tokenize_ids(
            pipe.tokenizer_2, prompt, device, pad_to=target_len
        )

    prompt_embeds1 = _encode_windows(pipe.text_encoder, input_ids, window)
    negative_embeds1 = _encode_windows(pipe.text_encoder, negative_ids, window)
    prompt_embeds2, pooled_embeds2 = _encode_windows_with_pool(
        pipe.text_encoder_2, input_ids2, window
    )
    negative_embeds2, neg_pooled_embeds2 = _encode_windows_with_pool(
        pipe.text_encoder_2, negative_ids2, window
    )

    prompt_embeds = torch.cat([prompt_embeds1, prompt_embeds2], dim=-1)
    negative_prompt_embeds = torch.cat([negative_embeds1, negative_embeds2], dim=-1)

    return prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2