import re
import warnings

import torch
from diffusers.utils.peft_utils import set_weights_and_activate_adapters

from S2I.modules.models import PrimaryModel

warnings.filterwarnings("ignore")


class Sketch2ImagePipeline(PrimaryModel):
    """One-step sketch-to-image pipeline built on PrimaryModel's shared
    VAE, UNet, tokenizer, text encoder, and scheduler."""

    def __init__(self):
        super().__init__()
        # Single fixed timestep (999): generation runs one scheduler step.
        self.timestep = torch.tensor([999], device="cuda").long()

    def generate(self, c_t, prompt=None, prompt_quality=None, prompt_template=None, prompt_tokens=None, r=1.0, noise_map=None, half_model=None, model_name=None):
        # Exactly one of `prompt` / `prompt_tokens` must be provided; validate
        # before the prompt is used below.
        assert (prompt is None) != (prompt_tokens is None), "Either prompt or prompt_tokens should be provided"
        self.from_pretrained(model_name=model_name, r=r)

        if prompt is not None:
            prompt_enhanced = self.automatic_enhance_prompt(prompt, prompt_quality)
            prompt_enhanced = prompt_template.replace("{prompt}", prompt_enhanced)
        else:
            prompt_enhanced = None

        if half_model == 'float16':
            output_image = self._generate_fp16(c_t, prompt_enhanced, prompt_tokens, r, noise_map)
        else:
            output_image = self._generate_full_precision(c_t, prompt_enhanced, prompt_tokens, r, noise_map)

        return output_image

    def _generate_fp16(self, c_t, prompt, prompt_tokens, r, noise_map):
        # Same path as full precision, run under CUDA autocast in float16.
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            return self._denoise(c_t, prompt, prompt_tokens, r, noise_map)

    def _generate_full_precision(self, c_t, prompt, prompt_tokens, r, noise_map):
        return self._denoise(c_t, prompt, prompt_tokens, r, noise_map)

    def _denoise(self, c_t, prompt, prompt_tokens, r, noise_map):
        caption_enc = self._get_caption_enc(prompt, prompt_tokens)

        self._set_weights_and_activate_adapters(r)
        # Encode the sketch into VAE latents, scaled by the VAE scaling factor.
        encoded_control = self.global_vae.encode(c_t).latent_dist.sample() * self.global_vae.config.scaling_factor

        # Interpolate between sketch latents and noise: r=1 follows the sketch
        # exactly, r=0 starts from pure noise.
        unet_input = encoded_control * r + noise_map * (1 - r)
        unet_output = self.global_unet(unet_input, self.timestep, encoder_hidden_states=caption_enc).sample
        x_denoise = self.global_scheduler.step(unet_output, self.timestep, unet_input, return_dict=True).prev_sample

        # Feed the encoder's down-block activations to the decoder as skip
        # connections, weighted by r.
        self.global_vae.decoder.incoming_skip_acts = self.global_vae.encoder.current_down_blocks
        self.global_vae.decoder.gamma = r

        return self.global_vae.decode(x_denoise / self.global_vae.config.scaling_factor).sample.clamp(-1, 1)

    def _get_caption_enc(self, prompt, prompt_tokens):
        # Tokenize the prompt when given; otherwise use the precomputed ids.
        if prompt is not None:
            caption_tokens = self.global_tokenizer(prompt, max_length=self.global_tokenizer.model_max_length,
                                                   padding="max_length", truncation=True,
                                                   return_tensors="pt").input_ids.cuda()
        else:
            caption_tokens = prompt_tokens.cuda()

        return self.global_text_encoder(caption_tokens)[0]

    def _set_weights_and_activate_adapters(self, r):
        # Scale the UNet's "default" adapter and the VAE's "vae_skip" adapter by r.
        self.global_unet.set_adapters(["default"], weights=[r])
        set_weights_and_activate_adapters(self.global_vae, ["vae_skip"], [r])

    def automatic_enhance_prompt(self, input_prompt, prompt_quality):
        if prompt_quality:
            # Expand the user prompt with the prompt-enhancement model.
            result = self.global_medium_prompt("Enhance the description: " + input_prompt)
            enhanced_text = result[0]['summary_text']

            # Drop any leading "... of" preamble (e.g. "A description of ...")
            # and keep the first sentence, capitalized, plus whatever follows.
            pattern = r'^.*?of\s+(.*?(?:\.|$))'
            match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)

            if match:
                remaining_text = enhanced_text[match.end():].strip()
                modified_sentence = match.group(1).capitalize()
                enhanced_text = modified_sentence + ' ' + remaining_text
        else:
            enhanced_text = input_prompt
        return enhanced_text

    def _move_to_cpu(self, module):
        module.to("cpu")

    def _move_to_gpu(self, module):
        module.to("cuda")