import os
import time
from datetime import datetime

import torch
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
from compel.prompt_parser import PromptParser
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
from spaces import GPU

from .config import Config
from .loader import Loader
from .logger import Logger
from .utils import (
    annotate_image,
    clear_cuda_cache,
    load_json,
    resize_image,
    safe_progress,
    timer,
)


# Inject prompts into style templates
def apply_style(positive_prompt, negative_prompt, style_id="none"):
    if not style_id or style_id.lower() == "none":
        return (positive_prompt, negative_prompt)

    styles = load_json("./data/styles.json")
    style = styles.get(style_id)
    if style is None:
        return (positive_prompt, negative_prompt)

    style_base = styles.get("_base", {})
    return (
        style.get("positive")
        .format(prompt=positive_prompt, _base=style_base.get("positive"))
        .strip(),
        style.get("negative")
        .format(prompt=negative_prompt, _base=style_base.get("negative"))
        .strip(),
    )

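# Illustrative shape of data/styles.json, inferred from the format() calls
# above (an assumption, not the actual file contents):
#
# {
#     "_base": {"positive": "best quality", "negative": "lowres"},
#     "cinematic": {
#         "positive": "cinematic still of {prompt}, {_base}",
#         "negative": "({prompt}), cartoon, {_base}"
#     }
# }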

# Estimate the GPU duration for the @GPU decorator from the request kwargs
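# Example: a 1024x1024 request (1,048,576 pixels > 500,000) with scale=4 and
# num_images=2 is allotted 20 + (10 + 5 + 5) * 2 = 60 seconds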
def gpu_duration(**kwargs):
    loading = 20
    duration = 10
    width = kwargs.get("width", 512)
    height = kwargs.get("height", 512)
    scale = kwargs.get("scale", 1)
    num_images = kwargs.get("num_images", 1)
    size = width * height
    if size > 500_000:
        duration += 5
    if scale == 4:
        duration += 5
    return loading + (duration * num_images)


# Request GPU when deployed to Hugging Face
@GPU(duration=gpu_duration)
def generate(
    positive_prompt,
    negative_prompt="",
    image_prompt=None,
    control_image_prompt=None,
    ip_image_prompt=None,
    lora_1=None,
    lora_1_weight=0.0,
    lora_2=None,
    lora_2_weight=0.0,
    embeddings=None,
    style=None,
    seed=None,
    model="Lykon/dreamshaper-8",
    scheduler="DDIM",
    annotator="canny",
    width=512,
    height=512,
    guidance_scale=7.5,
    inference_steps=40,
    denoising_strength=0.8,
    deepcache=1,
    scale=1,
    num_images=1,
    karras=False,
    taesd=False,
    freeu=False,
    clip_skip=False,
    ip_face=False,
    Error=Exception,
    Info=None,
    progress=None,
):
    # Avoid a mutable default argument for embeddings
    if embeddings is None:
        embeddings = []

    start = time.perf_counter()
    log = Logger("generate")
    log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}")

    if Config.ZERO_GPU:
        safe_progress(progress, 100, 100, "ZeroGPU init")

    if not torch.cuda.is_available():
        raise Error("CUDA not available")

    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
    if seed is None or seed < 0:
        seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)

    CURRENT_STEP = 0
    CURRENT_IMAGE = 1

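    # Pick the pipeline variant from the supplied images:
    # txt2img, img2img, controlnet_txt2img, or controlnet_img2img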
    KIND = "img2img" if image_prompt is not None else "txt2img"
    KIND = f"controlnet_{KIND}" if control_image_prompt is not None else KIND

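    # clip_skip uses the text encoder's penultimate hidden states (one layer early)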
    EMBEDDINGS_TYPE = (
        ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
        if clip_skip
        else ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
    )

    if ip_image_prompt:
        IP_ADAPTER = "full-face" if ip_face else "plus"
    else:
        IP_ADAPTER = ""

    # Custom progress bar for multiple images. Diffusers calls this as
    # (pipeline, step_index, timestep, callback_kwargs) and expects the
    # callback_kwargs dict to be returned
    def callback_on_step_end(pipeline, step, timestep, callback_kwargs):
        nonlocal CURRENT_STEP
        if progress is not None:
            # img2img runs only the last `strength` fraction of the schedule,
            # so scale the total step count accordingly
            strength = denoising_strength if KIND.endswith("img2img") else 1
            total_steps = min(int(inference_steps * strength), inference_steps)
            CURRENT_STEP = step + 1
            progress(
                (CURRENT_STEP, total_steps),
                desc=f"Generating image {CURRENT_IMAGE}/{num_images}",
            )
        return callback_kwargs

    loader = Loader()
    loader.load(
        KIND,
        IP_ADAPTER,
        model,
        scheduler,
        annotator,
        deepcache,
        scale,
        karras,
        taesd,
        freeu,
        progress,
    )

    if loader.pipe is None:
        raise Error(f"Error loading {model}")

    pipe = loader.pipe
    upscaler = loader.upscaler

    # load loras
    loras = []
    weights = []
    loras_and_weights = [(lora_1, lora_1_weight), (lora_2, lora_2_weight)]
    loras_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "loras"))
    total_loras = sum(1 for lora, _ in loras_and_weights if lora and lora.lower() != "none")
    desc_loras = "Loading LoRAs"
    if total_loras > 0:
        with timer(f"Loading {total_loras} LoRA{'s' if total_loras > 1 else ''}"):
            safe_progress(progress, 0, total_loras, desc_loras)
            for i, (lora, weight) in enumerate(loras_and_weights):
                if lora and lora.lower() != "none" and lora not in loras:
                    config = Config.CIVIT_LORAS.get(lora)
                    if config:
                        try:
                            pipe.load_lora_weights(
                                loras_dir,
                                adapter_name=lora,
                                weight_name=f"{lora}.{config['model_version_id']}.safetensors",
                            )
                            weights.append(weight)
                            loras.append(lora)
                            safe_progress(progress, i + 1, total_loras, desc_loras)
                        except Exception:
                            raise Error(f"Error loading {config['name']} LoRA")

    # Activate the loaded adapters; unload them if activation fails
    try:
        if loras:
            pipe.set_adapters(loras, adapter_weights=weights)
    except Exception:
        pipe.unload_lora_weights()
        raise Error("Error setting LoRA weights")

    # load embeddings
    embeddings_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "embeddings"))
    for embedding in embeddings:
        try:
            # wrap embeddings in angle brackets
            pipe.load_textual_inversion(
                pretrained_model_name_or_path=f"{embeddings_dir}/{embedding}.pt",
                token=f"<{embedding}>",
            )
        except (EnvironmentError, HFValidationError, RepositoryNotFoundError):
            raise Error(f"Invalid embedding: {embedding}")

    # Embed prompts with weights
    compel = Compel(
        device=pipe.device,
        tokenizer=pipe.tokenizer,
        truncate_long_prompts=False,
        text_encoder=pipe.text_encoder,
        returned_embeddings_type=EMBEDDINGS_TYPE,
        dtype_for_device_getter=lambda _: pipe.dtype,
        textual_inversion_manager=DiffusersTextualInversionManager(pipe),
    )
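    # Compel parses weighting syntax in the prompts, e.g. "(cozy)1.3" or "blurry++";
    # pad_conditioning_tensors_to_same_length (used below) keeps the positive and
    # negative embeddings the same shape when one prompt tokenizes longer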

    images = []
    current_seed = seed
    safe_progress(progress, 0, num_images, f"Generating image 0/{num_images}")
    # Unload the adapters in the finally block, after generating every image or
    # if an error interrupts the loop; unloading inside the loop would strip the
    # LoRAs and embeddings from every image after the first
    try:
        for _ in range(num_images):
            try:
                generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
                positive_styled, negative_styled = apply_style(positive_prompt, negative_prompt, style)

                # Styles wrap the negative prompt in parentheses; strip the
                # empty "(), " prefix when the user didn't provide one
                if negative_styled.startswith("(), "):
                    negative_styled = negative_styled[4:]

                # Append LoRA trigger words and negative embedding tokens
                for lora in loras:
                    positive_styled += f", {Config.CIVIT_LORAS[lora]['trigger']}"

                for embedding in embeddings:
                    negative_styled += f", <{embedding}>"

                positive_embeds, negative_embeds = compel.pad_conditioning_tensors_to_same_length(
                    [compel(positive_styled), compel(negative_styled)]
                )
            except PromptParser.ParsingException:
                raise Error("Invalid prompt")

            kwargs = {
                "width": width,
                "height": height,
                "generator": generator,
                "prompt_embeds": positive_embeds,
                "guidance_scale": guidance_scale,
                "num_inference_steps": inference_steps,
                "negative_prompt_embeds": negative_embeds,
                "output_type": "np" if scale > 1 else "pil",
            }

            if progress is not None:
                kwargs["callback_on_step_end"] = callback_on_step_end

            # Resize so the initial latents are the same size as the generated image
            if KIND == "img2img":
                kwargs["strength"] = denoising_strength
                kwargs["image"] = resize_image(image_prompt, (width, height))

            if KIND == "controlnet_txt2img":
                kwargs["image"] = annotate_image(control_image_prompt, annotator)

            if KIND == "controlnet_img2img":
                # ControlNet img2img takes both the init image and the control image
                kwargs["strength"] = denoising_strength
                kwargs["image"] = resize_image(image_prompt, (width, height))
                kwargs["control_image"] = annotate_image(control_image_prompt, annotator)

            if IP_ADAPTER:
                kwargs["ip_adapter_image"] = resize_image(ip_image_prompt)

            image = pipe(**kwargs).images[0]
            images.append((image, str(current_seed)))
            current_seed += 1
            CURRENT_STEP = 0
            CURRENT_IMAGE += 1
    finally:
        if embeddings:
            pipe.unload_textual_inversion()
        if loras:
            pipe.unload_lora_weights()

    # Upscale
    if scale > 1:
        msg = f"Upscaling {scale}x"
        with timer(msg, logger=log.info):
            safe_progress(progress, 0, num_images, desc=msg)
            for i, (image, image_seed) in enumerate(images):
                # Keep the (image, seed) tuple shape of the returned list
                images[i] = (upscaler.predict(image), image_seed)
                safe_progress(progress, i + 1, num_images, desc=msg)

    # Flush memory after generating
    clear_cuda_cache()

    end = time.perf_counter()
    msg = f"Generating {len(images)} image{'s' if len(images) > 1 else ''} took {end - start:.2f}s"
    log.info(msg)

    # Alert if notifier provided
    if Info:
        Info(msg)

    return images
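

# Minimal usage sketch (illustrative; the prompt and values are assumptions,
# and a CUDA device plus the bundled data files are required):
#
#     images = generate(
#         "a cozy cabin in the woods",
#         negative_prompt="blurry, low quality",
#         num_images=2,
#         seed=42,
#     )
#     for image, image_seed in images:
#         image.save(f"cabin_{image_seed}.png")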