gokaygokay commited on
Commit
4659d74
·
1 Parent(s): 6dc5714

kolorsplusplus

Browse files
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import torch
4
+ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, pipeline
5
+ import re
6
+ import random
7
+ import os
8
+ from huggingface_hub import snapshot_download
9
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
10
+ from kolors.models.modeling_chatglm import ChatGLMModel
11
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
12
+ from diffusers import UNet2DConditionModel, AutoencoderKL
13
+ from diffusers import EulerDiscreteScheduler
14
+
15
+ # Initialize models
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ dtype = torch.float16
18
+
19
+ # Download Kolors model
20
+ ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
21
+
22
+ # Load Kolors models
23
+ text_encoder = ChatGLMModel.from_pretrained(os.path.join(ckpt_dir, 'text_encoder'), torch_dtype=dtype).to(device)
24
+ tokenizer = ChatGLMTokenizer.from_pretrained(os.path.join(ckpt_dir, 'text_encoder'))
25
+ vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), revision=None).to(dtype).to(device)
26
+ scheduler = EulerDiscreteScheduler.from_pretrained(os.path.join(ckpt_dir, "scheduler"))
27
+ unet = UNet2DConditionModel.from_pretrained(os.path.join(ckpt_dir, "unet"), revision=None).to(dtype).to(device)
28
+
29
+ kolors_pipe = StableDiffusionXLPipeline(
30
+ vae=vae,
31
+ text_encoder=text_encoder,
32
+ tokenizer=tokenizer,
33
+ unet=unet,
34
+ scheduler=scheduler,
35
+ force_zeros_for_empty_prompt=False
36
+ ).to(device)
37
+ kolors_pipe.enable_model_cpu_offload()
38
+
39
+ # VLM Captioner
40
+ vlm_model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner-v2").to(device).eval()
41
+ vlm_processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner-v2")
42
+
43
+ # Prompt Enhancer
44
+ enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
45
+ enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
46
+
47
+ MAX_SEED = 2**32 - 1
48
+
49
+ # VLM Captioner function
50
+ def create_captions_rich(image):
51
+ prompt = "caption en"
52
+ model_inputs = vlm_processor(text=prompt, images=image, return_tensors="pt").to(device)
53
+ input_len = model_inputs["input_ids"].shape[-1]
54
+
55
+ with torch.inference_mode():
56
+ generation = vlm_model.generate(**model_inputs, repetition_penalty=1.10, max_new_tokens=256, do_sample=False)
57
+ generation = generation[0][input_len:]
58
+ decoded = vlm_processor.decode(generation, skip_special_tokens=True)
59
+
60
+ return modify_caption(decoded)
61
+
62
+ # Helper function for caption modification
63
+ def modify_caption(caption: str) -> str:
64
+ prefix_substrings = [
65
+ ('captured from ', ''),
66
+ ('captured at ', '')
67
+ ]
68
+ pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
69
+ replacers = {opening: replacer for opening, replacer in prefix_substrings}
70
+
71
+ def replace_fn(match):
72
+ return replacers[match.group(0)]
73
+
74
+ return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
75
+
76
+ # Prompt Enhancer function
77
+ def enhance_prompt(input_prompt, model_choice):
78
+ if model_choice == "Medium":
79
+ result = enhancer_medium("Enhance the description: " + input_prompt)
80
+ enhanced_text = result[0]['summary_text']
81
+
82
+ pattern = r'^.*?of\s+(.*?(?:\.|$))'
83
+ match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)
84
+
85
+ if match:
86
+ remaining_text = enhanced_text[match.end():].strip()
87
+ modified_sentence = match.group(1).capitalize()
88
+ enhanced_text = modified_sentence + ' ' + remaining_text
89
+ else: # Long
90
+ result = enhancer_long("Enhance the description: " + input_prompt)
91
+ enhanced_text = result[0]['summary_text']
92
+
93
+ return enhanced_text
94
+
95
+ def generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
96
+ if randomize_seed:
97
+ seed = random.randint(0, MAX_SEED)
98
+
99
+ generator = torch.Generator(device=device).manual_seed(seed)
100
+
101
+ image = kolors_pipe(
102
+ prompt=prompt,
103
+ negative_prompt=negative_prompt,
104
+ guidance_scale=guidance_scale,
105
+ num_inference_steps=num_inference_steps,
106
+ width=width,
107
+ height=height,
108
+ generator=generator
109
+ ).images[0]
110
+
111
+ return image, seed
112
+
113
+ # Gradio Interface
114
+ @spaces.GPU
115
+ def process_workflow(image, text_prompt, use_vlm, use_enhancer, model_choice, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
116
+ if use_vlm and image is not None:
117
+ prompt = create_captions_rich(image)
118
+ else:
119
+ prompt = text_prompt
120
+
121
+ if use_enhancer:
122
+ prompt = enhance_prompt(prompt, model_choice)
123
+
124
+ generated_image, used_seed = generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps)
125
+
126
+ return generated_image, prompt, used_seed
127
+
128
+ custom_css = """
129
+ .input-group, .output-group {
130
+ border: 1px solid #e0e0e0;
131
+ border-radius: 10px;
132
+ padding: 20px;
133
+ margin-bottom: 20px;
134
+ background-color: #f9f9f9;
135
+ }
136
+ .submit-btn {
137
+ background-color: #2980b9 !important;
138
+ color: white !important;
139
+ }
140
+ .submit-btn:hover {
141
+ background-color: #3498db !important;
142
+ }
143
+ """
144
+
145
+ title = """<h1 align="center">VLM Captioner + Prompt Enhancer + Kolors Image Generator</h1>
146
+ <p><center>
147
+ <a href="https://huggingface.co/spaces/gokaygokay/SD3-Long-Captioner-V2" target="_blank">[VLM Model]</a>
148
+ <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long" target="_blank">[Prompt Enhancer Long]</a>
149
+ <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance" target="_blank">[Prompt Enhancer Medium]</a>
150
+ <a href="https://huggingface.co/Kwai-Kolors/Kolors" target="_blank">[Kolors Model]</a>
151
+ <p align="center">Don't forget to click <b>Use VLM Captioner</b> or <b>Use Prompt Enhancer</b> Buttons!</p>
152
+ </center></p>
153
+ """
154
+
155
+ # Gradio Interface
156
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
157
+ gr.HTML(title)
158
+
159
+ with gr.Row():
160
+ with gr.Column(scale=1):
161
+ with gr.Group(elem_classes="input-group"):
162
+ input_image = gr.Image(label="Input Image for VLM")
163
+ use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)
164
+
165
+ with gr.Group(elem_classes="input-group"):
166
+ text_prompt = gr.Textbox(label="Text Prompt")
167
+ use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
168
+ model_choice = gr.Radio(["Medium", "Long"], label="Enhancer Model", value="Long")
169
+
170
+ with gr.Accordion("Advanced Settings", open=False):
171
+ negative_prompt = gr.Textbox(label="Negative Prompt")
172
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
173
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
174
+ width = gr.Slider(label="Width", minimum=512, maximum=2048, step=64, value=1024)
175
+ height = gr.Slider(label="Height", minimum=512, maximum=2048, step=64, value=1024)
176
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, step=0.5, value=5.0)
177
+ num_inference_steps = gr.Slider(label="Inference Steps", minimum=20, maximum=50, step=1, value=20)
178
+
179
+ generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
180
+
181
+ with gr.Column(scale=1):
182
+ with gr.Group(elem_classes="output-group"):
183
+ output_image = gr.Image(label="Generated Image")
184
+ final_prompt = gr.Textbox(label="Final Prompt Used")
185
+ used_seed = gr.Number(label="Seed Used")
186
+
187
+ generate_btn.click(
188
+ fn=process_workflow,
189
+ inputs=[
190
+ input_image, text_prompt, use_vlm, use_enhancer, model_choice,
191
+ negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps
192
+ ],
193
+ outputs=[output_image, final_prompt, used_seed]
194
+ )
195
+
196
+ demo.launch(debug=True)
kolors/__init__.py ADDED
File without changes
kolors/models/__init__.py ADDED
File without changes
kolors/models/configuration_chatglm.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class ChatGLMConfig(PretrainedConfig):
5
+ model_type = "chatglm"
6
+ def __init__(
7
+ self,
8
+ num_layers=28,
9
+ padded_vocab_size=65024,
10
+ hidden_size=4096,
11
+ ffn_hidden_size=13696,
12
+ kv_channels=128,
13
+ num_attention_heads=32,
14
+ seq_length=2048,
15
+ hidden_dropout=0.0,
16
+ classifier_dropout=None,
17
+ attention_dropout=0.0,
18
+ layernorm_epsilon=1e-5,
19
+ rmsnorm=True,
20
+ apply_residual_connection_post_layernorm=False,
21
+ post_layer_norm=True,
22
+ add_bias_linear=False,
23
+ add_qkv_bias=False,
24
+ bias_dropout_fusion=True,
25
+ multi_query_attention=False,
26
+ multi_query_group_num=1,
27
+ apply_query_key_layer_scaling=True,
28
+ attention_softmax_in_fp32=True,
29
+ fp32_residual_connection=False,
30
+ quantization_bit=0,
31
+ pre_seq_len=None,
32
+ prefix_projection=False,
33
+ **kwargs
34
+ ):
35
+ self.num_layers = num_layers
36
+ self.vocab_size = padded_vocab_size
37
+ self.padded_vocab_size = padded_vocab_size
38
+ self.hidden_size = hidden_size
39
+ self.ffn_hidden_size = ffn_hidden_size
40
+ self.kv_channels = kv_channels
41
+ self.num_attention_heads = num_attention_heads
42
+ self.seq_length = seq_length
43
+ self.hidden_dropout = hidden_dropout
44
+ self.classifier_dropout = classifier_dropout
45
+ self.attention_dropout = attention_dropout
46
+ self.layernorm_epsilon = layernorm_epsilon
47
+ self.rmsnorm = rmsnorm
48
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
49
+ self.post_layer_norm = post_layer_norm
50
+ self.add_bias_linear = add_bias_linear
51
+ self.add_qkv_bias = add_qkv_bias
52
+ self.bias_dropout_fusion = bias_dropout_fusion
53
+ self.multi_query_attention = multi_query_attention
54
+ self.multi_query_group_num = multi_query_group_num
55
+ self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
56
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
57
+ self.fp32_residual_connection = fp32_residual_connection
58
+ self.quantization_bit = quantization_bit
59
+ self.pre_seq_len = pre_seq_len
60
+ self.prefix_projection = prefix_projection
61
+ super().__init__(**kwargs)
kolors/models/modeling_chatglm.py ADDED
@@ -0,0 +1,1298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ PyTorch ChatGLM model. """
2
+
3
+ import math
4
+ import copy
5
+ import warnings
6
+ import re
7
+ import sys
8
+
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+ from torch.nn import CrossEntropyLoss, LayerNorm
14
+ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
15
+ from torch.nn.utils import skip_init
16
+ from typing import Optional, Tuple, Union, List, Callable, Dict, Any
17
+ from copy import deepcopy
18
+
19
+ from transformers.modeling_outputs import (
20
+ BaseModelOutputWithPast,
21
+ CausalLMOutputWithPast,
22
+ SequenceClassifierOutputWithPast,
23
+ )
24
+ from transformers.modeling_utils import PreTrainedModel
25
+ from transformers.utils import logging
26
+ from transformers.generation.logits_process import LogitsProcessor
27
+ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
28
+
29
+ try:
30
+ from .configuration_chatglm import ChatGLMConfig
31
+ except:
32
+ from configuration_chatglm import ChatGLMConfig
33
+
34
+
35
+ # flags required to enable jit fusion kernels
36
+
37
+ if sys.platform != 'darwin':
38
+ torch._C._jit_set_profiling_mode(False)
39
+ torch._C._jit_set_profiling_executor(False)
40
+ torch._C._jit_override_can_fuse_on_cpu(True)
41
+ torch._C._jit_override_can_fuse_on_gpu(True)
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
46
+ _CONFIG_FOR_DOC = "ChatGLM6BConfig"
47
+
48
+ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
49
+ "THUDM/chatglm3-6b-base",
50
+ # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
51
+ ]
52
+
53
+
54
+ def default_init(cls, *args, **kwargs):
55
+ return cls(*args, **kwargs)
56
+
57
+
58
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
59
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
60
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
61
+ scores.zero_()
62
+ scores[..., 5] = 5e4
63
+ return scores
64
+
65
+
66
+ class PrefixEncoder(torch.nn.Module):
67
+ """
68
+ The torch.nn model to encode the prefix
69
+ Input shape: (batch-size, prefix-length)
70
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
71
+ """
72
+
73
+ def __init__(self, config: ChatGLMConfig):
74
+ super().__init__()
75
+ self.prefix_projection = config.prefix_projection
76
+ if self.prefix_projection:
77
+ # Use a two-layer MLP to encode the prefix
78
+ kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
79
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
80
+ self.trans = torch.nn.Sequential(
81
+ torch.nn.Linear(kv_size, config.hidden_size),
82
+ torch.nn.Tanh(),
83
+ torch.nn.Linear(config.hidden_size, kv_size)
84
+ )
85
+ else:
86
+ self.embedding = torch.nn.Embedding(config.pre_seq_len,
87
+ config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
88
+
89
+ def forward(self, prefix: torch.Tensor):
90
+ if self.prefix_projection:
91
+ prefix_tokens = self.embedding(prefix)
92
+ past_key_values = self.trans(prefix_tokens)
93
+ else:
94
+ past_key_values = self.embedding(prefix)
95
+ return past_key_values
96
+
97
+
98
+ def split_tensor_along_last_dim(
99
+ tensor: torch.Tensor,
100
+ num_partitions: int,
101
+ contiguous_split_chunks: bool = False,
102
+ ) -> List[torch.Tensor]:
103
+ """Split a tensor along its last dimension.
104
+
105
+ Arguments:
106
+ tensor: input tensor.
107
+ num_partitions: number of partitions to split the tensor
108
+ contiguous_split_chunks: If True, make each chunk contiguous
109
+ in memory.
110
+
111
+ Returns:
112
+ A list of Tensors
113
+ """
114
+ # Get the size and dimension.
115
+ last_dim = tensor.dim() - 1
116
+ last_dim_size = tensor.size()[last_dim] // num_partitions
117
+ # Split.
118
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
119
+ # Note: torch.split does not create contiguous tensors by default.
120
+ if contiguous_split_chunks:
121
+ return tuple(chunk.contiguous() for chunk in tensor_list)
122
+
123
+ return tensor_list
124
+
125
+
126
+ class RotaryEmbedding(nn.Module):
127
+ def __init__(self, dim, original_impl=False, device=None, dtype=None):
128
+ super().__init__()
129
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
130
+ self.register_buffer("inv_freq", inv_freq)
131
+ self.dim = dim
132
+ self.original_impl = original_impl
133
+
134
+ def forward_impl(
135
+ self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
136
+ ):
137
+ """Enhanced Transformer with Rotary Position Embedding.
138
+
139
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
140
+ transformers/rope/__init__.py. MIT License:
141
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
142
+ """
143
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
144
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
145
+
146
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
147
+ seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
148
+
149
+ # Calculate the product of position index and $\theta_i$
150
+ idx_theta = torch.outer(seq_idx, theta).float()
151
+
152
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
153
+
154
+ # this is to mimic the behaviour of complex32, else we will get different results
155
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
156
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
157
+ return cache
158
+
159
+ def forward(self, max_seq_len, offset=0):
160
+ return self.forward_impl(
161
+ max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
162
+ )
163
+
164
+
165
+ @torch.jit.script
166
+ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
167
+ # x: [sq, b, np, hn]
168
+ sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
169
+ rot_dim = rope_cache.shape[-2] * 2
170
+ x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
171
+ # truncate to support variable sizes
172
+ rope_cache = rope_cache[:sq]
173
+ xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
174
+ rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
175
+ x_out2 = torch.stack(
176
+ [
177
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
178
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
179
+ ],
180
+ -1,
181
+ )
182
+ x_out2 = x_out2.flatten(3)
183
+ return torch.cat((x_out2, x_pass), dim=-1)
184
+
185
+
186
+ class RMSNorm(torch.nn.Module):
187
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
188
+ super().__init__()
189
+ self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
190
+ self.eps = eps
191
+
192
+ def forward(self, hidden_states: torch.Tensor):
193
+ input_dtype = hidden_states.dtype
194
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
195
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
196
+
197
+ return (self.weight * hidden_states).to(input_dtype)
198
+
199
+
200
+ class CoreAttention(torch.nn.Module):
201
+ def __init__(self, config: ChatGLMConfig, layer_number):
202
+ super(CoreAttention, self).__init__()
203
+
204
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
205
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
206
+ if self.apply_query_key_layer_scaling:
207
+ self.attention_softmax_in_fp32 = True
208
+ self.layer_number = max(1, layer_number)
209
+
210
+ projection_size = config.kv_channels * config.num_attention_heads
211
+
212
+ # Per attention head and per partition values.
213
+ self.hidden_size_per_partition = projection_size
214
+ self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
215
+ self.num_attention_heads_per_partition = config.num_attention_heads
216
+
217
+ coeff = None
218
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
219
+ if self.apply_query_key_layer_scaling:
220
+ coeff = self.layer_number
221
+ self.norm_factor *= coeff
222
+ self.coeff = coeff
223
+
224
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
225
+
226
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
227
+ pytorch_major_version = int(torch.__version__.split('.')[0])
228
+ if pytorch_major_version >= 2:
229
+ query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
230
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
231
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
232
+ is_causal=True)
233
+ else:
234
+ if attention_mask is not None:
235
+ attention_mask = ~attention_mask
236
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
237
+ attention_mask)
238
+ context_layer = context_layer.permute(2, 0, 1, 3)
239
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
240
+ context_layer = context_layer.reshape(*new_context_layer_shape)
241
+ else:
242
+ # Raw attention scores
243
+
244
+ # [b, np, sq, sk]
245
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
246
+
247
+ # [sq, b, np, hn] -> [sq, b * np, hn]
248
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
249
+ # [sk, b, np, hn] -> [sk, b * np, hn]
250
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
251
+
252
+ # preallocting input tensor: [b * np, sq, sk]
253
+ matmul_input_buffer = torch.empty(
254
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
255
+ device=query_layer.device
256
+ )
257
+
258
+ # Raw attention scores. [b * np, sq, sk]
259
+ matmul_result = torch.baddbmm(
260
+ matmul_input_buffer,
261
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
262
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
263
+ beta=0.0,
264
+ alpha=(1.0 / self.norm_factor),
265
+ )
266
+
267
+ # change view to [b, np, sq, sk]
268
+ attention_scores = matmul_result.view(*output_size)
269
+
270
+ # ===========================
271
+ # Attention probs and dropout
272
+ # ===========================
273
+
274
+ # attention scores and attention mask [b, np, sq, sk]
275
+ if self.attention_softmax_in_fp32:
276
+ attention_scores = attention_scores.float()
277
+ if self.coeff is not None:
278
+ attention_scores = attention_scores * self.coeff
279
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
280
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
281
+ device=attention_scores.device, dtype=torch.bool)
282
+ attention_mask.tril_()
283
+ attention_mask = ~attention_mask
284
+ if attention_mask is not None:
285
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
286
+ attention_probs = F.softmax(attention_scores, dim=-1)
287
+ attention_probs = attention_probs.type_as(value_layer)
288
+
289
+ # This is actually dropping out entire tokens to attend to, which might
290
+ # seem a bit unusual, but is taken from the original Transformer paper.
291
+ attention_probs = self.attention_dropout(attention_probs)
292
+ # =========================
293
+ # Context layer. [sq, b, hp]
294
+ # =========================
295
+
296
+ # value_layer -> context layer.
297
+ # [sk, b, np, hn] --> [b, np, sq, hn]
298
+
299
+ # context layer shape: [b, np, sq, hn]
300
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
301
+ # change view [sk, b * np, hn]
302
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
303
+ # change view [b * np, sq, sk]
304
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
305
+ # matmul: [b * np, sq, hn]
306
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
307
+ # change view [b, np, sq, hn]
308
+ context_layer = context_layer.view(*output_size)
309
+ # [b, np, sq, hn] --> [sq, b, np, hn]
310
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
311
+ # [sq, b, np, hn] --> [sq, b, hp]
312
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
313
+ context_layer = context_layer.view(*new_context_layer_shape)
314
+
315
+ return context_layer
316
+
317
+
318
+ class SelfAttention(torch.nn.Module):
319
+ """Parallel self-attention layer abstract class.
320
+
321
+ Self-attention layer takes input with size [s, b, h]
322
+ and returns output of the same size.
323
+ """
324
+
325
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
326
+ super(SelfAttention, self).__init__()
327
+ self.layer_number = max(1, layer_number)
328
+
329
+ self.projection_size = config.kv_channels * config.num_attention_heads
330
+
331
+ # Per attention head and per partition values.
332
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
333
+ self.num_attention_heads_per_partition = config.num_attention_heads
334
+
335
+ self.multi_query_attention = config.multi_query_attention
336
+ self.qkv_hidden_size = 3 * self.projection_size
337
+ if self.multi_query_attention:
338
+ self.num_multi_query_groups_per_partition = config.multi_query_group_num
339
+ self.qkv_hidden_size = (
340
+ self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
341
+ )
342
+ self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
343
+ bias=config.add_bias_linear or config.add_qkv_bias,
344
+ device=device, **_config_to_kwargs(config)
345
+ )
346
+
347
+ self.core_attention = CoreAttention(config, self.layer_number)
348
+
349
+ # Output.
350
+ self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
351
+ device=device, **_config_to_kwargs(config)
352
+ )
353
+
354
+ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
355
+ if self.multi_query_attention:
356
+ num_attention_heads = self.num_multi_query_groups_per_partition
357
+ else:
358
+ num_attention_heads = self.num_attention_heads_per_partition
359
+ return torch.empty(
360
+ inference_max_sequence_len,
361
+ batch_size,
362
+ num_attention_heads,
363
+ self.hidden_size_per_attention_head,
364
+ dtype=dtype,
365
+ device=device,
366
+ )
367
+
368
+ def forward(
369
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
370
+ ):
371
+ # hidden_states: [sq, b, h]
372
+
373
+ # =================================================
374
+ # Pre-allocate memory for key-values for inference.
375
+ # =================================================
376
+ # =====================
377
+ # Query, Key, and Value
378
+ # =====================
379
+
380
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
381
+ mixed_x_layer = self.query_key_value(hidden_states)
382
+
383
+ if self.multi_query_attention:
384
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
385
+ [
386
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
387
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
388
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
389
+ ],
390
+ dim=-1,
391
+ )
392
+ query_layer = query_layer.view(
393
+ query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
394
+ )
395
+ key_layer = key_layer.view(
396
+ key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
397
+ )
398
+ value_layer = value_layer.view(
399
+ value_layer.size()[:-1]
400
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
401
+ )
402
+ else:
403
+ new_tensor_shape = mixed_x_layer.size()[:-1] + \
404
+ (self.num_attention_heads_per_partition,
405
+ 3 * self.hidden_size_per_attention_head)
406
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
407
+
408
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
409
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
410
+
411
+ # apply relative positional encoding (rotary embedding)
412
+ if rotary_pos_emb is not None:
413
+ query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
414
+ key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
415
+
416
+ # adjust key and value for inference
417
+ if kv_cache is not None:
418
+ cache_k, cache_v = kv_cache
419
+ key_layer = torch.cat((cache_k, key_layer), dim=0)
420
+ value_layer = torch.cat((cache_v, value_layer), dim=0)
421
+ if use_cache:
422
+ kv_cache = (key_layer, value_layer)
423
+ else:
424
+ kv_cache = None
425
+
426
+ if self.multi_query_attention:
427
+ key_layer = key_layer.unsqueeze(-2)
428
+ key_layer = key_layer.expand(
429
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
430
+ )
431
+ key_layer = key_layer.contiguous().view(
432
+ key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
433
+ )
434
+ value_layer = value_layer.unsqueeze(-2)
435
+ value_layer = value_layer.expand(
436
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
437
+ )
438
+ value_layer = value_layer.contiguous().view(
439
+ value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
440
+ )
441
+
442
+ # ==================================
443
+ # core attention computation
444
+ # ==================================
445
+
446
+ context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
447
+
448
+ # =================
449
+ # Output. [sq, b, h]
450
+ # =================
451
+
452
+ output = self.dense(context_layer)
453
+
454
+ return output, kv_cache
455
+
456
+
457
+ def _config_to_kwargs(args):
458
+ common_kwargs = {
459
+ "dtype": args.torch_dtype,
460
+ }
461
+ return common_kwargs
462
+
463
+
464
+ class MLP(torch.nn.Module):
465
+ """MLP.
466
+
467
+ MLP will take the input with h hidden state, project it to 4*h
468
+ hidden dimension, perform nonlinear transformation, and project the
469
+ state back into h hidden dimension.
470
+ """
471
+
472
+ def __init__(self, config: ChatGLMConfig, device=None):
473
+ super(MLP, self).__init__()
474
+
475
+ self.add_bias = config.add_bias_linear
476
+
477
+ # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
478
+ self.dense_h_to_4h = nn.Linear(
479
+ config.hidden_size,
480
+ config.ffn_hidden_size * 2,
481
+ bias=self.add_bias,
482
+ device=device,
483
+ **_config_to_kwargs(config)
484
+ )
485
+
486
+ def swiglu(x):
487
+ x = torch.chunk(x, 2, dim=-1)
488
+ return F.silu(x[0]) * x[1]
489
+
490
+ self.activation_func = swiglu
491
+
492
+ # Project back to h.
493
+ self.dense_4h_to_h = nn.Linear(
494
+ config.ffn_hidden_size,
495
+ config.hidden_size,
496
+ bias=self.add_bias,
497
+ device=device,
498
+ **_config_to_kwargs(config)
499
+ )
500
+
501
+ def forward(self, hidden_states):
502
+ # [s, b, 4hp]
503
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
504
+ intermediate_parallel = self.activation_func(intermediate_parallel)
505
+ # [s, b, h]
506
+ output = self.dense_4h_to_h(intermediate_parallel)
507
+ return output
508
+
509
+
510
+ class GLMBlock(torch.nn.Module):
511
+ """A single transformer layer.
512
+
513
+ Transformer layer takes input with size [s, b, h] and returns an
514
+ output of the same size.
515
+ """
516
+
517
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
518
+ super(GLMBlock, self).__init__()
519
+ self.layer_number = layer_number
520
+
521
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
522
+
523
+ self.fp32_residual_connection = config.fp32_residual_connection
524
+
525
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
526
+ # Layernorm on the input data.
527
+ self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
528
+ dtype=config.torch_dtype)
529
+
530
+ # Self attention.
531
+ self.self_attention = SelfAttention(config, layer_number, device=device)
532
+ self.hidden_dropout = config.hidden_dropout
533
+
534
+ # Layernorm on the attention output
535
+ self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
536
+ dtype=config.torch_dtype)
537
+
538
+ # MLP
539
+ self.mlp = MLP(config, device=device)
540
+
541
+ def forward(
542
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
543
+ ):
544
+ # hidden_states: [s, b, h]
545
+
546
+ # Layer norm at the beginning of the transformer layer.
547
+ layernorm_output = self.input_layernorm(hidden_states)
548
+ # Self attention.
549
+ attention_output, kv_cache = self.self_attention(
550
+ layernorm_output,
551
+ attention_mask,
552
+ rotary_pos_emb,
553
+ kv_cache=kv_cache,
554
+ use_cache=use_cache
555
+ )
556
+
557
+ # Residual connection.
558
+ if self.apply_residual_connection_post_layernorm:
559
+ residual = layernorm_output
560
+ else:
561
+ residual = hidden_states
562
+
563
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
564
+ layernorm_input = residual + layernorm_input
565
+
566
+ # Layer norm post the self attention.
567
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
568
+
569
+ # MLP.
570
+ mlp_output = self.mlp(layernorm_output)
571
+
572
+ # Second residual connection.
573
+ if self.apply_residual_connection_post_layernorm:
574
+ residual = layernorm_output
575
+ else:
576
+ residual = layernorm_input
577
+
578
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
579
+ output = residual + output
580
+
581
+ return output, kv_cache
582
+
583
+
584
+ class GLMTransformer(torch.nn.Module):
585
+ """Transformer class."""
586
+
587
+ def __init__(self, config: ChatGLMConfig, device=None):
588
+ super(GLMTransformer, self).__init__()
589
+
590
+ self.fp32_residual_connection = config.fp32_residual_connection
591
+ self.post_layer_norm = config.post_layer_norm
592
+
593
+ # Number of layers.
594
+ self.num_layers = config.num_layers
595
+
596
+ # Transformer layers.
597
+ def build_layer(layer_number):
598
+ return GLMBlock(config, layer_number, device=device)
599
+
600
+ self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
601
+
602
+ if self.post_layer_norm:
603
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
604
+ # Final layer norm before output.
605
+ self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
606
+ dtype=config.torch_dtype)
607
+
608
+ self.gradient_checkpointing = False
609
+
610
+ def _get_layer(self, layer_number):
611
+ return self.layers[layer_number]
612
+
613
+ def forward(
614
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
615
+ use_cache: Optional[bool] = True,
616
+ output_hidden_states: Optional[bool] = False,
617
+ ):
618
+ if not kv_caches:
619
+ kv_caches = [None for _ in range(self.num_layers)]
620
+ presents = () if use_cache else None
621
+ if self.gradient_checkpointing and self.training:
622
+ if use_cache:
623
+ logger.warning_once(
624
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
625
+ )
626
+ use_cache = False
627
+
628
+ all_self_attentions = None
629
+ all_hidden_states = () if output_hidden_states else None
630
+ for index in range(self.num_layers):
631
+ if output_hidden_states:
632
+ all_hidden_states = all_hidden_states + (hidden_states,)
633
+
634
+ layer = self._get_layer(index)
635
+ if self.gradient_checkpointing and self.training:
636
+ layer_ret = torch.utils.checkpoint.checkpoint(
637
+ layer,
638
+ hidden_states,
639
+ attention_mask,
640
+ rotary_pos_emb,
641
+ kv_caches[index],
642
+ use_cache
643
+ )
644
+ else:
645
+ layer_ret = layer(
646
+ hidden_states,
647
+ attention_mask,
648
+ rotary_pos_emb,
649
+ kv_cache=kv_caches[index],
650
+ use_cache=use_cache
651
+ )
652
+ hidden_states, kv_cache = layer_ret
653
+ if use_cache:
654
+ presents = presents + (kv_cache,)
655
+
656
+ if output_hidden_states:
657
+ all_hidden_states = all_hidden_states + (hidden_states,)
658
+
659
+ # Final layer norm.
660
+ if self.post_layer_norm:
661
+ hidden_states = self.final_layernorm(hidden_states)
662
+
663
+ return hidden_states, presents, all_hidden_states, all_self_attentions
664
+
665
+
666
+ class ChatGLMPreTrainedModel(PreTrainedModel):
667
+ """
668
+ An abstract class to handle weights initialization and
669
+ a simple interface for downloading and loading pretrained models.
670
+ """
671
+
672
+ is_parallelizable = False
673
+ supports_gradient_checkpointing = True
674
+ config_class = ChatGLMConfig
675
+ base_model_prefix = "transformer"
676
+ _no_split_modules = ["GLMBlock"]
677
+
678
+ def _init_weights(self, module: nn.Module):
679
+ """Initialize the weights."""
680
+ return
681
+
682
+ def get_masks(self, input_ids, past_key_values, padding_mask=None):
683
+ batch_size, seq_length = input_ids.shape
684
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
685
+ full_attention_mask.tril_()
686
+ past_length = 0
687
+ if past_key_values:
688
+ past_length = past_key_values[0][0].shape[0]
689
+ if past_length:
690
+ full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
691
+ device=input_ids.device), full_attention_mask), dim=-1)
692
+ if padding_mask is not None:
693
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
694
+ if not past_length and padding_mask is not None:
695
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
696
+ full_attention_mask = (full_attention_mask < 0.5).bool()
697
+ full_attention_mask.unsqueeze_(1)
698
+ return full_attention_mask
699
+
700
+ def get_position_ids(self, input_ids, device):
701
+ batch_size, seq_length = input_ids.shape
702
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
703
+ return position_ids
704
+
705
+ def _set_gradient_checkpointing(self, module, value=False):
706
+ if isinstance(module, GLMTransformer):
707
+ module.gradient_checkpointing = value
708
+
709
+
710
+ class Embedding(torch.nn.Module):
711
+ """Language model embeddings."""
712
+
713
+ def __init__(self, config: ChatGLMConfig, device=None):
714
+ super(Embedding, self).__init__()
715
+
716
+ self.hidden_size = config.hidden_size
717
+ # Word embeddings (parallel).
718
+ self.word_embeddings = nn.Embedding(
719
+ config.padded_vocab_size,
720
+ self.hidden_size,
721
+ dtype=config.torch_dtype,
722
+ device=device
723
+ )
724
+ self.fp32_residual_connection = config.fp32_residual_connection
725
+
726
+ def forward(self, input_ids):
727
+ # Embeddings.
728
+ words_embeddings = self.word_embeddings(input_ids)
729
+ embeddings = words_embeddings
730
+ # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
731
+ embeddings = embeddings.transpose(0, 1).contiguous()
732
+ # If the input flag for fp32 residual connection is set, convert for float.
733
+ if self.fp32_residual_connection:
734
+ embeddings = embeddings.float()
735
+ return embeddings
736
+
737
+
738
+ class ChatGLMModel(ChatGLMPreTrainedModel):
739
+ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
740
+ super().__init__(config)
741
+ if empty_init:
742
+ init_method = skip_init
743
+ else:
744
+ init_method = default_init
745
+ init_kwargs = {}
746
+ if device is not None:
747
+ init_kwargs["device"] = device
748
+ self.embedding = init_method(Embedding, config, **init_kwargs)
749
+ self.num_layers = config.num_layers
750
+ self.multi_query_group_num = config.multi_query_group_num
751
+ self.kv_channels = config.kv_channels
752
+
753
+ # Rotary positional embeddings
754
+ self.seq_length = config.seq_length
755
+ rotary_dim = (
756
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
757
+ )
758
+
759
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
760
+ dtype=config.torch_dtype)
761
+ self.encoder = init_method(GLMTransformer, config, **init_kwargs)
762
+ self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
763
+ dtype=config.torch_dtype, **init_kwargs)
764
+ self.pre_seq_len = config.pre_seq_len
765
+ self.prefix_projection = config.prefix_projection
766
+ if self.pre_seq_len is not None:
767
+ for param in self.parameters():
768
+ param.requires_grad = False
769
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
770
+ self.prefix_encoder = PrefixEncoder(config)
771
+ self.dropout = torch.nn.Dropout(0.1)
772
+
773
+ def get_input_embeddings(self):
774
+ return self.embedding.word_embeddings
775
+
776
+ def get_prompt(self, batch_size, device, dtype=torch.half):
777
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
778
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
779
+ past_key_values = past_key_values.view(
780
+ batch_size,
781
+ self.pre_seq_len,
782
+ self.num_layers * 2,
783
+ self.multi_query_group_num,
784
+ self.kv_channels
785
+ )
786
+ # seq_len, b, nh, hidden_size
787
+ past_key_values = self.dropout(past_key_values)
788
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
789
+ return past_key_values
790
+
791
+ def forward(
792
+ self,
793
+ input_ids,
794
+ position_ids: Optional[torch.Tensor] = None,
795
+ attention_mask: Optional[torch.BoolTensor] = None,
796
+ full_attention_mask: Optional[torch.BoolTensor] = None,
797
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
798
+ inputs_embeds: Optional[torch.Tensor] = None,
799
+ use_cache: Optional[bool] = None,
800
+ output_hidden_states: Optional[bool] = None,
801
+ return_dict: Optional[bool] = None,
802
+ ):
803
+ output_hidden_states = (
804
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
805
+ )
806
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
807
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
808
+
809
+ batch_size, seq_length = input_ids.shape
810
+
811
+ if inputs_embeds is None:
812
+ inputs_embeds = self.embedding(input_ids)
813
+
814
+ if self.pre_seq_len is not None:
815
+ if past_key_values is None:
816
+ past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
817
+ dtype=inputs_embeds.dtype)
818
+ if attention_mask is not None:
819
+ attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
820
+ attention_mask], dim=-1)
821
+
822
+ if full_attention_mask is None:
823
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
824
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
825
+
826
+ # Rotary positional embeddings
827
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
828
+ if position_ids is not None:
829
+ rotary_pos_emb = rotary_pos_emb[position_ids]
830
+ else:
831
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
832
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
833
+
834
+ # Run encoder.
835
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
836
+ inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
837
+ kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
838
+ )
839
+
840
+ if not return_dict:
841
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
842
+
843
+ return BaseModelOutputWithPast(
844
+ last_hidden_state=hidden_states,
845
+ past_key_values=presents,
846
+ hidden_states=all_hidden_states,
847
+ attentions=all_self_attentions,
848
+ )
849
+
850
+ def quantize(self, weight_bit_width: int):
851
+ from .quantization import quantize
852
+ quantize(self.encoder, weight_bit_width)
853
+ return self
854
+
855
+
856
+ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
857
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
858
+ super().__init__(config)
859
+
860
+ self.max_sequence_length = config.max_length
861
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
862
+ self.config = config
863
+ self.quantized = False
864
+
865
+ if self.config.quantization_bit:
866
+ self.quantize(self.config.quantization_bit, empty_init=True)
867
+
868
+ def _update_model_kwargs_for_generation(
869
+ self,
870
+ outputs: ModelOutput,
871
+ model_kwargs: Dict[str, Any],
872
+ is_encoder_decoder: bool = False,
873
+ standardize_cache_format: bool = False,
874
+ ) -> Dict[str, Any]:
875
+ # update past_key_values
876
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
877
+ outputs, standardize_cache_format=standardize_cache_format
878
+ )
879
+
880
+ # update attention mask
881
+ if "attention_mask" in model_kwargs:
882
+ attention_mask = model_kwargs["attention_mask"]
883
+ model_kwargs["attention_mask"] = torch.cat(
884
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
885
+ )
886
+
887
+ # update position ids
888
+ if "position_ids" in model_kwargs:
889
+ position_ids = model_kwargs["position_ids"]
890
+ new_position_id = position_ids[..., -1:].clone()
891
+ new_position_id += 1
892
+ model_kwargs["position_ids"] = torch.cat(
893
+ [position_ids, new_position_id], dim=-1
894
+ )
895
+
896
+ model_kwargs["is_first_forward"] = False
897
+ return model_kwargs
898
+
899
+ def prepare_inputs_for_generation(
900
+ self,
901
+ input_ids: torch.LongTensor,
902
+ past_key_values: Optional[torch.Tensor] = None,
903
+ attention_mask: Optional[torch.Tensor] = None,
904
+ position_ids: Optional[torch.Tensor] = None,
905
+ use_cache: Optional[bool] = None,
906
+ is_first_forward: bool = True,
907
+ **kwargs
908
+ ) -> dict:
909
+ # only last token for input_ids if past is not None
910
+ if position_ids is None:
911
+ position_ids = self.get_position_ids(input_ids, device=input_ids.device)
912
+ if not is_first_forward:
913
+ if past_key_values is not None:
914
+ position_ids = position_ids[..., -1:]
915
+ input_ids = input_ids[:, -1:]
916
+ return {
917
+ "input_ids": input_ids,
918
+ "past_key_values": past_key_values,
919
+ "position_ids": position_ids,
920
+ "attention_mask": attention_mask,
921
+ "return_last_logit": True,
922
+ "use_cache": use_cache
923
+ }
924
+
925
+ def forward(
926
+ self,
927
+ input_ids: Optional[torch.Tensor] = None,
928
+ position_ids: Optional[torch.Tensor] = None,
929
+ attention_mask: Optional[torch.Tensor] = None,
930
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
931
+ inputs_embeds: Optional[torch.Tensor] = None,
932
+ labels: Optional[torch.Tensor] = None,
933
+ use_cache: Optional[bool] = None,
934
+ output_attentions: Optional[bool] = None,
935
+ output_hidden_states: Optional[bool] = None,
936
+ return_dict: Optional[bool] = None,
937
+ return_last_logit: Optional[bool] = False,
938
+ ):
939
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
940
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
941
+
942
+ transformer_outputs = self.transformer(
943
+ input_ids=input_ids,
944
+ position_ids=position_ids,
945
+ attention_mask=attention_mask,
946
+ past_key_values=past_key_values,
947
+ inputs_embeds=inputs_embeds,
948
+ use_cache=use_cache,
949
+ output_hidden_states=output_hidden_states,
950
+ return_dict=return_dict,
951
+ )
952
+
953
+ hidden_states = transformer_outputs[0]
954
+ if return_last_logit:
955
+ hidden_states = hidden_states[-1:]
956
+ lm_logits = self.transformer.output_layer(hidden_states)
957
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
958
+
959
+ loss = None
960
+ if labels is not None:
961
+ lm_logits = lm_logits.to(torch.float32)
962
+
963
+ # Shift so that tokens < n predict n
964
+ shift_logits = lm_logits[..., :-1, :].contiguous()
965
+ shift_labels = labels[..., 1:].contiguous()
966
+ # Flatten the tokens
967
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
968
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
969
+
970
+ lm_logits = lm_logits.to(hidden_states.dtype)
971
+ loss = loss.to(hidden_states.dtype)
972
+
973
+ if not return_dict:
974
+ output = (lm_logits,) + transformer_outputs[1:]
975
+ return ((loss,) + output) if loss is not None else output
976
+
977
+ return CausalLMOutputWithPast(
978
+ loss=loss,
979
+ logits=lm_logits,
980
+ past_key_values=transformer_outputs.past_key_values,
981
+ hidden_states=transformer_outputs.hidden_states,
982
+ attentions=transformer_outputs.attentions,
983
+ )
984
+
985
+ @staticmethod
986
+ def _reorder_cache(
987
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
988
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
989
+ """
990
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
991
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
992
+ beam_idx at every generation step.
993
+
994
+ Output shares the same memory storage as `past`.
995
+ """
996
+ return tuple(
997
+ (
998
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
999
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
1000
+ )
1001
+ for layer_past in past
1002
+ )
1003
+
1004
+ def process_response(self, output, history):
1005
+ content = ""
1006
+ history = deepcopy(history)
1007
+ for response in output.split("<|assistant|>"):
1008
+ metadata, content = response.split("\n", maxsplit=1)
1009
+ if not metadata.strip():
1010
+ content = content.strip()
1011
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
1012
+ content = content.replace("[[训练时间]]", "2023年")
1013
+ else:
1014
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
1015
+ if history[0]["role"] == "system" and "tools" in history[0]:
1016
+ content = "\n".join(content.split("\n")[1:-1])
1017
+ def tool_call(**kwargs):
1018
+ return kwargs
1019
+ parameters = eval(content)
1020
+ content = {"name": metadata.strip(), "parameters": parameters}
1021
+ else:
1022
+ content = {"name": metadata.strip(), "content": content}
1023
+ return content, history
1024
+
1025
+ @torch.inference_mode()
1026
+ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
1027
+ max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
1028
+ **kwargs):
1029
+ if history is None:
1030
+ history = []
1031
+ if logits_processor is None:
1032
+ logits_processor = LogitsProcessorList()
1033
+ logits_processor.append(InvalidScoreLogitsProcessor())
1034
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1035
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1036
+ inputs = tokenizer.build_chat_input(query, history=history, role=role)
1037
+ inputs = inputs.to(self.device)
1038
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1039
+ tokenizer.get_command("<|observation|>")]
1040
+ outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
1041
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1042
+ response = tokenizer.decode(outputs)
1043
+ history.append({"role": role, "content": query})
1044
+ response, history = self.process_response(response, history)
1045
+ return response, history
1046
+
1047
+ @torch.inference_mode()
1048
+ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
1049
+ past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
1050
+ logits_processor=None, return_past_key_values=False, **kwargs):
1051
+ if history is None:
1052
+ history = []
1053
+ if logits_processor is None:
1054
+ logits_processor = LogitsProcessorList()
1055
+ logits_processor.append(InvalidScoreLogitsProcessor())
1056
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1057
+ tokenizer.get_command("<|observation|>")]
1058
+ gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
1059
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1060
+ if past_key_values is None:
1061
+ inputs = tokenizer.build_chat_input(query, history=history, role=role)
1062
+ else:
1063
+ inputs = tokenizer.build_chat_input(query, role=role)
1064
+ inputs = inputs.to(self.device)
1065
+ if past_key_values is not None:
1066
+ past_length = past_key_values[0][0].shape[0]
1067
+ if self.transformer.pre_seq_len is not None:
1068
+ past_length -= self.transformer.pre_seq_len
1069
+ inputs.position_ids += past_length
1070
+ attention_mask = inputs.attention_mask
1071
+ attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
1072
+ inputs['attention_mask'] = attention_mask
1073
+ history.append({"role": role, "content": query})
1074
+ for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
1075
+ eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
1076
+ **gen_kwargs):
1077
+ if return_past_key_values:
1078
+ outputs, past_key_values = outputs
1079
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1080
+ response = tokenizer.decode(outputs)
1081
+ if response and response[-1] != "�":
1082
+ response, new_history = self.process_response(response, history)
1083
+ if return_past_key_values:
1084
+ yield response, new_history, past_key_values
1085
+ else:
1086
+ yield response, new_history
1087
+
1088
+ @torch.inference_mode()
1089
+ def stream_generate(
1090
+ self,
1091
+ input_ids,
1092
+ generation_config: Optional[GenerationConfig] = None,
1093
+ logits_processor: Optional[LogitsProcessorList] = None,
1094
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1095
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
1096
+ return_past_key_values=False,
1097
+ **kwargs,
1098
+ ):
1099
+ batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
1100
+
1101
+ if generation_config is None:
1102
+ generation_config = self.generation_config
1103
+ generation_config = copy.deepcopy(generation_config)
1104
+ model_kwargs = generation_config.update(**kwargs)
1105
+ model_kwargs["use_cache"] = generation_config.use_cache
1106
+ bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1107
+
1108
+ if isinstance(eos_token_id, int):
1109
+ eos_token_id = [eos_token_id]
1110
+ eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
1111
+
1112
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1113
+ if has_default_max_length and generation_config.max_new_tokens is None:
1114
+ warnings.warn(
1115
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
1116
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
1117
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
1118
+ UserWarning,
1119
+ )
1120
+ elif generation_config.max_new_tokens is not None:
1121
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
1122
+ if not has_default_max_length:
1123
+ logger.warn(
1124
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
1125
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
1126
+ "Please refer to the documentation for more information. "
1127
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
1128
+ UserWarning,
1129
+ )
1130
+
1131
+ if input_ids_seq_length >= generation_config.max_length:
1132
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
1133
+ logger.warning(
1134
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
1135
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
1136
+ " increasing `max_new_tokens`."
1137
+ )
1138
+
1139
+ # 2. Set generation parameters if not already defined
1140
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1141
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1142
+
1143
+ logits_processor = self._get_logits_processor(
1144
+ generation_config=generation_config,
1145
+ input_ids_seq_length=input_ids_seq_length,
1146
+ encoder_input_ids=input_ids,
1147
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1148
+ logits_processor=logits_processor,
1149
+ )
1150
+
1151
+ stopping_criteria = self._get_stopping_criteria(
1152
+ generation_config=generation_config, stopping_criteria=stopping_criteria
1153
+ )
1154
+ logits_warper = self._get_logits_warper(generation_config)
1155
+
1156
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1157
+ scores = None
1158
+ while True:
1159
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1160
+ # forward pass to get next token
1161
+ outputs = self(
1162
+ **model_inputs,
1163
+ return_dict=True,
1164
+ output_attentions=False,
1165
+ output_hidden_states=False,
1166
+ )
1167
+
1168
+ next_token_logits = outputs.logits[:, -1, :]
1169
+
1170
+ # pre-process distribution
1171
+ next_token_scores = logits_processor(input_ids, next_token_logits)
1172
+ next_token_scores = logits_warper(input_ids, next_token_scores)
1173
+
1174
+ # sample
1175
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
1176
+ if generation_config.do_sample:
1177
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1178
+ else:
1179
+ next_tokens = torch.argmax(probs, dim=-1)
1180
+ # update generated ids, model inputs, and length for next step
1181
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1182
+ model_kwargs = self._update_model_kwargs_for_generation(
1183
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1184
+ )
1185
+ unfinished_sequences = unfinished_sequences.mul(
1186
+ next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
1187
+ )
1188
+ if return_past_key_values:
1189
+ yield input_ids, outputs.past_key_values
1190
+ else:
1191
+ yield input_ids
1192
+ # stop when each sentence is finished, or if we exceed the maximum length
1193
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
1194
+ break
1195
+
1196
+ def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
1197
+ if bits == 0:
1198
+ return
1199
+
1200
+ from .quantization import quantize
1201
+
1202
+ if self.quantized:
1203
+ logger.info("Already quantized.")
1204
+ return self
1205
+
1206
+ self.quantized = True
1207
+
1208
+ self.config.quantization_bit = bits
1209
+
1210
+ self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
1211
+ **kwargs)
1212
+ return self
1213
+
1214
+
1215
+ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1216
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
1217
+ super().__init__(config)
1218
+
1219
+ self.num_labels = config.num_labels
1220
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1221
+
1222
+ self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
1223
+ if config.classifier_dropout is not None:
1224
+ self.dropout = nn.Dropout(config.classifier_dropout)
1225
+ else:
1226
+ self.dropout = None
1227
+ self.config = config
1228
+
1229
+ if self.config.quantization_bit:
1230
+ self.quantize(self.config.quantization_bit, empty_init=True)
1231
+
1232
+ def forward(
1233
+ self,
1234
+ input_ids: Optional[torch.LongTensor] = None,
1235
+ position_ids: Optional[torch.LongTensor] = None,
1236
+ attention_mask: Optional[torch.Tensor] = None,
1237
+ full_attention_mask: Optional[torch.Tensor] = None,
1238
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1239
+ inputs_embeds: Optional[torch.LongTensor] = None,
1240
+ labels: Optional[torch.LongTensor] = None,
1241
+ use_cache: Optional[bool] = None,
1242
+ output_hidden_states: Optional[bool] = None,
1243
+ return_dict: Optional[bool] = None,
1244
+ ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
1245
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1246
+
1247
+ transformer_outputs = self.transformer(
1248
+ input_ids=input_ids,
1249
+ position_ids=position_ids,
1250
+ attention_mask=attention_mask,
1251
+ full_attention_mask=full_attention_mask,
1252
+ past_key_values=past_key_values,
1253
+ inputs_embeds=inputs_embeds,
1254
+ use_cache=use_cache,
1255
+ output_hidden_states=output_hidden_states,
1256
+ return_dict=return_dict,
1257
+ )
1258
+
1259
+ hidden_states = transformer_outputs[0]
1260
+ pooled_hidden_states = hidden_states[-1]
1261
+ if self.dropout is not None:
1262
+ pooled_hidden_states = self.dropout(pooled_hidden_states)
1263
+ logits = self.classifier_head(pooled_hidden_states)
1264
+
1265
+ loss = None
1266
+ if labels is not None:
1267
+ if self.config.problem_type is None:
1268
+ if self.num_labels == 1:
1269
+ self.config.problem_type = "regression"
1270
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1271
+ self.config.problem_type = "single_label_classification"
1272
+ else:
1273
+ self.config.problem_type = "multi_label_classification"
1274
+
1275
+ if self.config.problem_type == "regression":
1276
+ loss_fct = MSELoss()
1277
+ if self.num_labels == 1:
1278
+ loss = loss_fct(logits.squeeze().float(), labels.squeeze())
1279
+ else:
1280
+ loss = loss_fct(logits.float(), labels)
1281
+ elif self.config.problem_type == "single_label_classification":
1282
+ loss_fct = CrossEntropyLoss()
1283
+ loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
1284
+ elif self.config.problem_type == "multi_label_classification":
1285
+ loss_fct = BCEWithLogitsLoss()
1286
+ loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
1287
+
1288
+ if not return_dict:
1289
+ output = (logits,) + transformer_outputs[1:]
1290
+ return ((loss,) + output) if loss is not None else output
1291
+
1292
+ return SequenceClassifierOutputWithPast(
1293
+ loss=loss,
1294
+ logits=logits,
1295
+ past_key_values=transformer_outputs.past_key_values,
1296
+ hidden_states=transformer_outputs.hidden_states,
1297
+ attentions=transformer_outputs.attentions,
1298
+ )
kolors/models/tokenization_chatglm.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import re
4
+ from typing import List, Optional, Union, Dict
5
+ from sentencepiece import SentencePieceProcessor
6
+ from transformers import PreTrainedTokenizer
7
+ from transformers.utils import logging, PaddingStrategy
8
+ from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
9
+
10
+
11
+ class SPTokenizer:
12
+ def __init__(self, model_path: str):
13
+ # reload tokenizer
14
+ assert os.path.isfile(model_path), model_path
15
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
16
+
17
+ # BOS / EOS token IDs
18
+ self.n_words: int = self.sp_model.vocab_size()
19
+ self.bos_id: int = self.sp_model.bos_id()
20
+ self.eos_id: int = self.sp_model.eos_id()
21
+ self.pad_id: int = self.sp_model.unk_id()
22
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
23
+
24
+ role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
25
+ special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
26
+ self.special_tokens = {}
27
+ self.index_special_tokens = {}
28
+ for token in special_tokens:
29
+ self.special_tokens[token] = self.n_words
30
+ self.index_special_tokens[self.n_words] = token
31
+ self.n_words += 1
32
+ self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
33
+
34
+ def tokenize(self, s: str, encode_special_tokens=False):
35
+ if encode_special_tokens:
36
+ last_index = 0
37
+ t = []
38
+ for match in re.finditer(self.role_special_token_expression, s):
39
+ if last_index < match.start():
40
+ t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
41
+ t.append(s[match.start():match.end()])
42
+ last_index = match.end()
43
+ if last_index < len(s):
44
+ t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
45
+ return t
46
+ else:
47
+ return self.sp_model.EncodeAsPieces(s)
48
+
49
+ def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
50
+ assert type(s) is str
51
+ t = self.sp_model.encode(s)
52
+ if bos:
53
+ t = [self.bos_id] + t
54
+ if eos:
55
+ t = t + [self.eos_id]
56
+ return t
57
+
58
+ def decode(self, t: List[int]) -> str:
59
+ text, buffer = "", []
60
+ for token in t:
61
+ if token in self.index_special_tokens:
62
+ if buffer:
63
+ text += self.sp_model.decode(buffer)
64
+ buffer = []
65
+ text += self.index_special_tokens[token]
66
+ else:
67
+ buffer.append(token)
68
+ if buffer:
69
+ text += self.sp_model.decode(buffer)
70
+ return text
71
+
72
+ def decode_tokens(self, tokens: List[str]) -> str:
73
+ text = self.sp_model.DecodePieces(tokens)
74
+ return text
75
+
76
+ def convert_token_to_id(self, token):
77
+ """ Converts a token (str) in an id using the vocab. """
78
+ if token in self.special_tokens:
79
+ return self.special_tokens[token]
80
+ return self.sp_model.PieceToId(token)
81
+
82
+ def convert_id_to_token(self, index):
83
+ """Converts an index (integer) in a token (str) using the vocab."""
84
+ if index in self.index_special_tokens:
85
+ return self.index_special_tokens[index]
86
+ if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
87
+ return ""
88
+ return self.sp_model.IdToPiece(index)
89
+
90
+
91
+ class ChatGLMTokenizer(PreTrainedTokenizer):
92
+ vocab_files_names = {"vocab_file": "tokenizer.model"}
93
+
94
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
95
+
96
+ def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
97
+ **kwargs):
98
+ self.name = "GLMTokenizer"
99
+
100
+ self.vocab_file = vocab_file
101
+ self.tokenizer = SPTokenizer(vocab_file)
102
+ self.special_tokens = {
103
+ "<bos>": self.tokenizer.bos_id,
104
+ "<eos>": self.tokenizer.eos_id,
105
+ "<pad>": self.tokenizer.pad_id
106
+ }
107
+ self.encode_special_tokens = encode_special_tokens
108
+ super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
109
+ encode_special_tokens=encode_special_tokens,
110
+ **kwargs)
111
+
112
+ def get_command(self, token):
113
+ if token in self.special_tokens:
114
+ return self.special_tokens[token]
115
+ assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
116
+ return self.tokenizer.special_tokens[token]
117
+
118
+ @property
119
+ def unk_token(self) -> str:
120
+ return "<unk>"
121
+
122
+ @property
123
+ def pad_token(self) -> str:
124
+ return "<unk>"
125
+
126
+ @property
127
+ def pad_token_id(self):
128
+ return self.get_command("<pad>")
129
+
130
+ @property
131
+ def eos_token(self) -> str:
132
+ return "</s>"
133
+
134
+ @property
135
+ def eos_token_id(self):
136
+ return self.get_command("<eos>")
137
+
138
+ @property
139
+ def vocab_size(self):
140
+ return self.tokenizer.n_words
141
+
142
+ def get_vocab(self):
143
+ """ Returns vocab as a dict """
144
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
145
+ vocab.update(self.added_tokens_encoder)
146
+ return vocab
147
+
148
+ def _tokenize(self, text, **kwargs):
149
+ return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
150
+
151
+ def _convert_token_to_id(self, token):
152
+ """ Converts a token (str) in an id using the vocab. """
153
+ return self.tokenizer.convert_token_to_id(token)
154
+
155
+ def _convert_id_to_token(self, index):
156
+ """Converts an index (integer) in a token (str) using the vocab."""
157
+ return self.tokenizer.convert_id_to_token(index)
158
+
159
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
160
+ return self.tokenizer.decode_tokens(tokens)
161
+
162
+ def save_vocabulary(self, save_directory, filename_prefix=None):
163
+ """
164
+ Save the vocabulary and special tokens file to a directory.
165
+
166
+ Args:
167
+ save_directory (`str`):
168
+ The directory in which to save the vocabulary.
169
+ filename_prefix (`str`, *optional*):
170
+ An optional prefix to add to the named of the saved files.
171
+
172
+ Returns:
173
+ `Tuple(str)`: Paths to the files saved.
174
+ """
175
+ if os.path.isdir(save_directory):
176
+ vocab_file = os.path.join(
177
+ save_directory, self.vocab_files_names["vocab_file"]
178
+ )
179
+ else:
180
+ vocab_file = save_directory
181
+
182
+ with open(self.vocab_file, 'rb') as fin:
183
+ proto_str = fin.read()
184
+
185
+ with open(vocab_file, "wb") as writer:
186
+ writer.write(proto_str)
187
+
188
+ return (vocab_file,)
189
+
190
+ def get_prefix_tokens(self):
191
+ prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
192
+ return prefix_tokens
193
+
194
+ def build_single_message(self, role, metadata, message):
195
+ assert role in ["system", "user", "assistant", "observation"], role
196
+ role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
197
+ message_tokens = self.tokenizer.encode(message)
198
+ tokens = role_tokens + message_tokens
199
+ return tokens
200
+
201
+ def build_chat_input(self, query, history=None, role="user"):
202
+ if history is None:
203
+ history = []
204
+ input_ids = []
205
+ for item in history:
206
+ content = item["content"]
207
+ if item["role"] == "system" and "tools" in item:
208
+ content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
209
+ input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
210
+ input_ids.extend(self.build_single_message(role, "", query))
211
+ input_ids.extend([self.get_command("<|assistant|>")])
212
+ return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
213
+
214
+ def build_inputs_with_special_tokens(
215
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
216
+ ) -> List[int]:
217
+ """
218
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
219
+ adding special tokens. A BERT sequence has the following format:
220
+
221
+ - single sequence: `[CLS] X [SEP]`
222
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
223
+
224
+ Args:
225
+ token_ids_0 (`List[int]`):
226
+ List of IDs to which the special tokens will be added.
227
+ token_ids_1 (`List[int]`, *optional*):
228
+ Optional second list of IDs for sequence pairs.
229
+
230
+ Returns:
231
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
232
+ """
233
+ prefix_tokens = self.get_prefix_tokens()
234
+ token_ids_0 = prefix_tokens + token_ids_0
235
+ if token_ids_1 is not None:
236
+ token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
237
+ return token_ids_0
238
+
239
+ def _pad(
240
+ self,
241
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
242
+ max_length: Optional[int] = None,
243
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
244
+ pad_to_multiple_of: Optional[int] = None,
245
+ return_attention_mask: Optional[bool] = None,
246
+ ) -> dict:
247
+ """
248
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
249
+
250
+ Args:
251
+ encoded_inputs:
252
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
253
+ max_length: maximum length of the returned list and optionally padding length (see below).
254
+ Will truncate by taking into account the special tokens.
255
+ padding_strategy: PaddingStrategy to use for padding.
256
+
257
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
258
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
259
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
260
+ The tokenizer padding sides are defined in self.padding_side:
261
+
262
+ - 'left': pads on the left of the sequences
263
+ - 'right': pads on the right of the sequences
264
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
265
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
266
+ `>= 7.5` (Volta).
267
+ return_attention_mask:
268
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
269
+ """
270
+ # Load from model defaults
271
+ assert self.padding_side == "left"
272
+
273
+ required_input = encoded_inputs[self.model_input_names[0]]
274
+ seq_length = len(required_input)
275
+
276
+ if padding_strategy == PaddingStrategy.LONGEST:
277
+ max_length = len(required_input)
278
+
279
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
280
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
281
+
282
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
283
+
284
+ # Initialize attention mask if not present.
285
+ if "attention_mask" not in encoded_inputs:
286
+ encoded_inputs["attention_mask"] = [1] * seq_length
287
+
288
+ if "position_ids" not in encoded_inputs:
289
+ encoded_inputs["position_ids"] = list(range(seq_length))
290
+
291
+ if needs_to_be_padded:
292
+ difference = max_length - len(required_input)
293
+
294
+ if "attention_mask" in encoded_inputs:
295
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
296
+ if "position_ids" in encoded_inputs:
297
+ encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
298
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
299
+
300
+ return encoded_inputs
kolors/pipelines/__init__.py ADDED
File without changes
kolors/pipelines/pipeline_stable_diffusion_xl_chatglm_256.py ADDED
@@ -0,0 +1,840 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import sys
15
+ import os
16
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
17
+ from kolors.models.modeling_chatglm import ChatGLMModel
18
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+ import torch
22
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
23
+ from transformers import XLMRobertaModel, ChineseCLIPTextModel
24
+
25
+ from diffusers.image_processor import VaeImageProcessor
26
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
27
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
28
+ from diffusers.models.attention_processor import (
29
+ AttnProcessor2_0,
30
+ LoRAAttnProcessor2_0,
31
+ LoRAXFormersAttnProcessor,
32
+ XFormersAttnProcessor,
33
+ )
34
+ from diffusers.schedulers import KarrasDiffusionSchedulers
35
+ from diffusers.utils import (
36
+ is_accelerate_available,
37
+ is_accelerate_version,
38
+ logging,
39
+ replace_example_docstring,
40
+ )
41
+ try:
42
+ from diffusers.utils import randn_tensor
43
+ except:
44
+ from diffusers.utils.torch_utils import randn_tensor
45
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
46
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
47
+
48
+
49
+
50
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
+
52
+ EXAMPLE_DOC_STRING = """
53
+ Examples:
54
+ ```py
55
+ >>> import torch
56
+ >>> from diffusers import StableDiffusionXLPipeline
57
+
58
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
59
+ ... "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16
60
+ ... )
61
+ >>> pipe = pipe.to("cuda")
62
+
63
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
64
+ >>> image = pipe(prompt).images[0]
65
+ ```
66
+ """
67
+
68
+
69
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
70
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
71
+ """
72
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
73
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
74
+ """
75
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
76
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
77
+ # rescale the results from guidance (fixes overexposure)
78
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
79
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
80
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
81
+ return noise_cfg
82
+
83
+
84
+ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
85
+ r"""
86
+ Pipeline for text-to-image generation using Stable Diffusion XL.
87
+
88
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
89
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
90
+
91
+ In addition the pipeline inherits the following loading methods:
92
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
93
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
94
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
95
+
96
+ as well as the following saving methods:
97
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
98
+
99
+ Args:
100
+ vae ([`AutoencoderKL`]):
101
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
102
+ text_encoder ([`CLIPTextModel`]):
103
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
104
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
105
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
106
+
107
+ tokenizer (`CLIPTokenizer`):
108
+ Tokenizer of class
109
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
110
+
111
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
112
+ scheduler ([`SchedulerMixin`]):
113
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
114
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
115
+ """
116
+
117
+ def __init__(
118
+ self,
119
+ vae: AutoencoderKL,
120
+ text_encoder: ChatGLMModel,
121
+ tokenizer: ChatGLMTokenizer,
122
+ unet: UNet2DConditionModel,
123
+ scheduler: KarrasDiffusionSchedulers,
124
+ force_zeros_for_empty_prompt: bool = True,
125
+ ):
126
+ super().__init__()
127
+
128
+ self.register_modules(
129
+ vae=vae,
130
+ text_encoder=text_encoder,
131
+ tokenizer=tokenizer,
132
+ unet=unet,
133
+ scheduler=scheduler,
134
+ )
135
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
136
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
137
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
138
+ self.default_sample_size = self.unet.config.sample_size
139
+
140
+ # self.watermark = StableDiffusionXLWatermarker()
141
+
142
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
143
+ def enable_vae_slicing(self):
144
+ r"""
145
+ Enable sliced VAE decoding.
146
+
147
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
148
+ steps. This is useful to save some memory and allow larger batch sizes.
149
+ """
150
+ self.vae.enable_slicing()
151
+
152
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
153
+ def disable_vae_slicing(self):
154
+ r"""
155
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
156
+ computing decoding in one step.
157
+ """
158
+ self.vae.disable_slicing()
159
+
160
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
161
+ def enable_vae_tiling(self):
162
+ r"""
163
+ Enable tiled VAE decoding.
164
+
165
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
166
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
167
+ """
168
+ self.vae.enable_tiling()
169
+
170
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
171
+ def disable_vae_tiling(self):
172
+ r"""
173
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
174
+ computing decoding in one step.
175
+ """
176
+ self.vae.disable_tiling()
177
+
178
+ def enable_sequential_cpu_offload(self, gpu_id=0):
179
+ r"""
180
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
181
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
182
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
183
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
184
+ `enable_model_cpu_offload`, but performance is lower.
185
+ """
186
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
187
+ from accelerate import cpu_offload
188
+ else:
189
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
190
+
191
+ device = torch.device(f"cuda:{gpu_id}")
192
+
193
+ if self.device.type != "cpu":
194
+ self.to("cpu", silence_dtype_warnings=True)
195
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
196
+
197
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
198
+ cpu_offload(cpu_offloaded_model, device)
199
+
200
+ def enable_model_cpu_offload(self, gpu_id=0):
201
+ r"""
202
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
203
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
204
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
205
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
206
+ """
207
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
208
+ from accelerate import cpu_offload_with_hook
209
+ else:
210
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
211
+
212
+ device = torch.device(f"cuda:{gpu_id}")
213
+
214
+ if self.device.type != "cpu":
215
+ self.to("cpu", silence_dtype_warnings=True)
216
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
217
+
218
+ model_sequence = (
219
+ [self.text_encoder]
220
+ )
221
+ model_sequence.extend([self.unet, self.vae])
222
+
223
+ hook = None
224
+ for cpu_offloaded_model in model_sequence:
225
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
226
+
227
+ # We'll offload the last model manually.
228
+ self.final_offload_hook = hook
229
+
230
+ @property
231
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
232
+ def _execution_device(self):
233
+ r"""
234
+ Returns the device on which the pipeline's models will be executed. After calling
235
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
236
+ hooks.
237
+ """
238
+ if not hasattr(self.unet, "_hf_hook"):
239
+ return self.device
240
+ for module in self.unet.modules():
241
+ if (
242
+ hasattr(module, "_hf_hook")
243
+ and hasattr(module._hf_hook, "execution_device")
244
+ and module._hf_hook.execution_device is not None
245
+ ):
246
+ return torch.device(module._hf_hook.execution_device)
247
+ return self.device
248
+
249
+ def encode_prompt(
250
+ self,
251
+ prompt,
252
+ device: Optional[torch.device] = None,
253
+ num_images_per_prompt: int = 1,
254
+ do_classifier_free_guidance: bool = True,
255
+ negative_prompt=None,
256
+ prompt_embeds: Optional[torch.FloatTensor] = None,
257
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
258
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
259
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
260
+ lora_scale: Optional[float] = None,
261
+ ):
262
+ r"""
263
+ Encodes the prompt into text encoder hidden states.
264
+
265
+ Args:
266
+ prompt (`str` or `List[str]`, *optional*):
267
+ prompt to be encoded
268
+ device: (`torch.device`):
269
+ torch device
270
+ num_images_per_prompt (`int`):
271
+ number of images that should be generated per prompt
272
+ do_classifier_free_guidance (`bool`):
273
+ whether to use classifier free guidance or not
274
+ negative_prompt (`str` or `List[str]`, *optional*):
275
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
276
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
277
+ less than `1`).
278
+ prompt_embeds (`torch.FloatTensor`, *optional*):
279
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
280
+ provided, text embeddings will be generated from `prompt` input argument.
281
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
282
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
283
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
284
+ argument.
285
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
286
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
287
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
288
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
289
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
290
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
291
+ input argument.
292
+ lora_scale (`float`, *optional*):
293
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
294
+ """
295
+ # from IPython import embed; embed(); exit()
296
+ device = device or self._execution_device
297
+
298
+ # set lora scale so that monkey patched LoRA
299
+ # function of text encoder can correctly access it
300
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
301
+ self._lora_scale = lora_scale
302
+
303
+ if prompt is not None and isinstance(prompt, str):
304
+ batch_size = 1
305
+ elif prompt is not None and isinstance(prompt, list):
306
+ batch_size = len(prompt)
307
+ else:
308
+ batch_size = prompt_embeds.shape[0]
309
+
310
+ # Define tokenizers and text encoders
311
+ tokenizers = [self.tokenizer]
312
+ text_encoders = [self.text_encoder]
313
+
314
+ if prompt_embeds is None:
315
+ # textual inversion: procecss multi-vector tokens if necessary
316
+ prompt_embeds_list = []
317
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
318
+ if isinstance(self, TextualInversionLoaderMixin):
319
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
320
+
321
+ text_inputs = tokenizer(
322
+ prompt,
323
+ padding="max_length",
324
+ max_length=256,
325
+ truncation=True,
326
+ return_tensors="pt",
327
+ ).to('cuda')
328
+ output = text_encoder(
329
+ input_ids=text_inputs['input_ids'] ,
330
+ attention_mask=text_inputs['attention_mask'],
331
+ position_ids=text_inputs['position_ids'],
332
+ output_hidden_states=True)
333
+ prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() # [batch_size, 77, 4096]
334
+ text_proj = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
335
+ bs_embed, seq_len, _ = prompt_embeds.shape
336
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
337
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
338
+
339
+ prompt_embeds_list.append(prompt_embeds)
340
+
341
+ # prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
342
+ prompt_embeds = prompt_embeds_list[0]
343
+
344
+ # get unconditional embeddings for classifier free guidance
345
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
346
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
347
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
348
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
349
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
350
+ # negative_prompt = negative_prompt or ""
351
+ uncond_tokens: List[str]
352
+ if negative_prompt is None:
353
+ uncond_tokens = [""] * batch_size
354
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
355
+ raise TypeError(
356
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
357
+ f" {type(prompt)}."
358
+ )
359
+ elif isinstance(negative_prompt, str):
360
+ uncond_tokens = [negative_prompt]
361
+ elif batch_size != len(negative_prompt):
362
+ raise ValueError(
363
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
364
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
365
+ " the batch size of `prompt`."
366
+ )
367
+ else:
368
+ uncond_tokens = negative_prompt
369
+
370
+ negative_prompt_embeds_list = []
371
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
372
+ # textual inversion: procecss multi-vector tokens if necessary
373
+ if isinstance(self, TextualInversionLoaderMixin):
374
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
375
+
376
+ max_length = prompt_embeds.shape[1]
377
+ uncond_input = tokenizer(
378
+ uncond_tokens,
379
+ padding="max_length",
380
+ max_length=max_length,
381
+ truncation=True,
382
+ return_tensors="pt",
383
+ ).to('cuda')
384
+ output = text_encoder(
385
+ input_ids=uncond_input['input_ids'] ,
386
+ attention_mask=uncond_input['attention_mask'],
387
+ position_ids=uncond_input['position_ids'],
388
+ output_hidden_states=True)
389
+ negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() # [batch_size, 77, 4096]
390
+ negative_text_proj = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
391
+
392
+ if do_classifier_free_guidance:
393
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
394
+ seq_len = negative_prompt_embeds.shape[1]
395
+
396
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
397
+
398
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
399
+ negative_prompt_embeds = negative_prompt_embeds.view(
400
+ batch_size * num_images_per_prompt, seq_len, -1
401
+ )
402
+
403
+ # For classifier free guidance, we need to do two forward passes.
404
+ # Here we concatenate the unconditional and text embeddings into a single batch
405
+ # to avoid doing two forward passes
406
+
407
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
408
+
409
+ # negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
410
+ negative_prompt_embeds = negative_prompt_embeds_list[0]
411
+
412
+ bs_embed = text_proj.shape[0]
413
+ text_proj = text_proj.repeat(1, num_images_per_prompt).view(
414
+ bs_embed * num_images_per_prompt, -1
415
+ )
416
+ negative_text_proj = negative_text_proj.repeat(1, num_images_per_prompt).view(
417
+ bs_embed * num_images_per_prompt, -1
418
+ )
419
+
420
+ return prompt_embeds, negative_prompt_embeds, text_proj, negative_text_proj
421
+
422
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
423
+ def prepare_extra_step_kwargs(self, generator, eta):
424
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
425
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
426
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
427
+ # and should be between [0, 1]
428
+
429
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
430
+ extra_step_kwargs = {}
431
+ if accepts_eta:
432
+ extra_step_kwargs["eta"] = eta
433
+
434
+ # check if the scheduler accepts generator
435
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
436
+ if accepts_generator:
437
+ extra_step_kwargs["generator"] = generator
438
+ return extra_step_kwargs
439
+
440
+ def check_inputs(
441
+ self,
442
+ prompt,
443
+ height,
444
+ width,
445
+ callback_steps,
446
+ negative_prompt=None,
447
+ prompt_embeds=None,
448
+ negative_prompt_embeds=None,
449
+ pooled_prompt_embeds=None,
450
+ negative_pooled_prompt_embeds=None,
451
+ ):
452
+ if height % 8 != 0 or width % 8 != 0:
453
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
454
+
455
+ if (callback_steps is None) or (
456
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
457
+ ):
458
+ raise ValueError(
459
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
460
+ f" {type(callback_steps)}."
461
+ )
462
+
463
+ if prompt is not None and prompt_embeds is not None:
464
+ raise ValueError(
465
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
466
+ " only forward one of the two."
467
+ )
468
+ elif prompt is None and prompt_embeds is None:
469
+ raise ValueError(
470
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
471
+ )
472
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
473
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
474
+
475
+ if negative_prompt is not None and negative_prompt_embeds is not None:
476
+ raise ValueError(
477
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
478
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
479
+ )
480
+
481
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
482
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
483
+ raise ValueError(
484
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
485
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
486
+ f" {negative_prompt_embeds.shape}."
487
+ )
488
+
489
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
490
+ raise ValueError(
491
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
492
+ )
493
+
494
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
495
+ raise ValueError(
496
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
497
+ )
498
+
499
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
500
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
501
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
502
+ if isinstance(generator, list) and len(generator) != batch_size:
503
+ raise ValueError(
504
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
505
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
506
+ )
507
+
508
+ if latents is None:
509
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
510
+ else:
511
+ latents = latents.to(device)
512
+
513
+ # scale the initial noise by the standard deviation required by the scheduler
514
+ latents = latents * self.scheduler.init_noise_sigma
515
+ return latents
516
+
517
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
518
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
519
+
520
+ passed_add_embed_dim = (
521
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + 4096
522
+ )
523
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
524
+
525
+ if expected_add_embed_dim != passed_add_embed_dim:
526
+ raise ValueError(
527
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
528
+ )
529
+
530
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
531
+ return add_time_ids
532
+
533
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
534
+ def upcast_vae(self):
535
+ dtype = self.vae.dtype
536
+ self.vae.to(dtype=torch.float32)
537
+ use_torch_2_0_or_xformers = isinstance(
538
+ self.vae.decoder.mid_block.attentions[0].processor,
539
+ (
540
+ AttnProcessor2_0,
541
+ XFormersAttnProcessor,
542
+ LoRAXFormersAttnProcessor,
543
+ LoRAAttnProcessor2_0,
544
+ ),
545
+ )
546
+ # if xformers or torch_2_0 is used attention block does not need
547
+ # to be in float32 which can save lots of memory
548
+ if use_torch_2_0_or_xformers:
549
+ self.vae.post_quant_conv.to(dtype)
550
+ self.vae.decoder.conv_in.to(dtype)
551
+ self.vae.decoder.mid_block.to(dtype)
552
+
553
+ @torch.no_grad()
554
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
555
+ def __call__(
556
+ self,
557
+ prompt: Union[str, List[str]] = None,
558
+ height: Optional[int] = None,
559
+ width: Optional[int] = None,
560
+ num_inference_steps: int = 50,
561
+ denoising_end: Optional[float] = None,
562
+ guidance_scale: float = 5.0,
563
+ negative_prompt: Optional[Union[str, List[str]]] = None,
564
+ num_images_per_prompt: Optional[int] = 1,
565
+ eta: float = 0.0,
566
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
567
+ latents: Optional[torch.FloatTensor] = None,
568
+ prompt_embeds: Optional[torch.FloatTensor] = None,
569
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
570
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
571
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
572
+ output_type: Optional[str] = "pil",
573
+ return_dict: bool = True,
574
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
575
+ callback_steps: int = 1,
576
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
577
+ guidance_rescale: float = 0.0,
578
+ original_size: Optional[Tuple[int, int]] = None,
579
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
580
+ target_size: Optional[Tuple[int, int]] = None,
581
+ use_dynamic_threshold: Optional[bool] = False,
582
+ ):
583
+ r"""
584
+ Function invoked when calling the pipeline for generation.
585
+
586
+ Args:
587
+ prompt (`str` or `List[str]`, *optional*):
588
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
589
+ instead.
590
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
591
+ The height in pixels of the generated image.
592
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
593
+ The width in pixels of the generated image.
594
+ num_inference_steps (`int`, *optional*, defaults to 50):
595
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
596
+ expense of slower inference.
597
+ denoising_end (`float`, *optional*):
598
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
599
+ completed before it is intentionally prematurely terminated. For instance, if denoising_end is set to
600
+ 0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50)
601
+ Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
602
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
603
+ guidance_scale (`float`, *optional*, defaults to 7.5):
604
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
605
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
606
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
607
+ negative_prompt (`str` or `List[str]`, *optional*):
608
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
609
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
610
+ less than `1`).
611
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
612
+ The number of images to generate per prompt.
613
+ eta (`float`, *optional*, defaults to 0.0):
614
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
615
+ [`schedulers.DDIMScheduler`], will be ignored for others.
616
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
617
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
618
+ to make generation deterministic.
619
+ latents (`torch.FloatTensor`, *optional*):
620
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
621
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
622
+ tensor will ge generated by sampling using the supplied random `generator`.
623
+ prompt_embeds (`torch.FloatTensor`, *optional*):
624
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
625
+ provided, text embeddings will be generated from `prompt` input argument.
626
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
627
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
628
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
629
+ argument.
630
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
631
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
632
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
633
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
634
+ output_type (`str`, *optional*, defaults to `"pil"`):
635
+ The output format of the generate image. Choose between
636
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
637
+ return_dict (`bool`, *optional*, defaults to `True`):
638
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
639
+ callback (`Callable`, *optional*):
640
+ A function that will be called every `callback_steps` steps during inference. The function will be
641
+ callback_steps (`int`, *optional*, defaults to 1):
642
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
643
+ called at every step.
644
+ cross_attention_kwargs (`dict`, *optional*):
645
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
646
+ `self.processor` in
647
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
648
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
649
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
650
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
651
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
652
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
653
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
654
+ TODO
655
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
656
+ TODO
657
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
658
+ TODO
659
+
660
+ Examples:
661
+
662
+ Returns:
663
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
664
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
665
+ `tuple. When returning a tuple, the first element is a list with the generated images, and the second
666
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
667
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
668
+ """
669
+ # 0. Default height and width to unet
670
+ height = height or self.default_sample_size * self.vae_scale_factor
671
+ width = width or self.default_sample_size * self.vae_scale_factor
672
+
673
+ original_size = original_size or (height, width)
674
+ target_size = target_size or (height, width)
675
+
676
+ # 1. Check inputs. Raise error if not correct
677
+ self.check_inputs(
678
+ prompt,
679
+ height,
680
+ width,
681
+ callback_steps,
682
+ negative_prompt,
683
+ prompt_embeds,
684
+ negative_prompt_embeds,
685
+ pooled_prompt_embeds,
686
+ negative_pooled_prompt_embeds,
687
+ )
688
+
689
+ # 2. Define call parameters
690
+ if prompt is not None and isinstance(prompt, str):
691
+ batch_size = 1
692
+ elif prompt is not None and isinstance(prompt, list):
693
+ batch_size = len(prompt)
694
+ else:
695
+ batch_size = prompt_embeds.shape[0]
696
+
697
+ device = self._execution_device
698
+
699
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
700
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
701
+ # corresponds to doing no classifier free guidance.
702
+ do_classifier_free_guidance = guidance_scale > 1.0
703
+
704
+ # 3. Encode input prompt
705
+ text_encoder_lora_scale = (
706
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
707
+ )
708
+ (
709
+ prompt_embeds,
710
+ negative_prompt_embeds,
711
+ pooled_prompt_embeds,
712
+ negative_pooled_prompt_embeds,
713
+ ) = self.encode_prompt(
714
+ prompt,
715
+ device,
716
+ num_images_per_prompt,
717
+ do_classifier_free_guidance,
718
+ negative_prompt,
719
+ prompt_embeds=prompt_embeds,
720
+ negative_prompt_embeds=negative_prompt_embeds,
721
+ pooled_prompt_embeds=pooled_prompt_embeds,
722
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
723
+ lora_scale=text_encoder_lora_scale,
724
+ )
725
+
726
+ # 4. Prepare timesteps
727
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
728
+
729
+ timesteps = self.scheduler.timesteps
730
+
731
+ # 5. Prepare latent variables
732
+ num_channels_latents = self.unet.config.in_channels
733
+ latents = self.prepare_latents(
734
+ batch_size * num_images_per_prompt,
735
+ num_channels_latents,
736
+ height,
737
+ width,
738
+ prompt_embeds.dtype,
739
+ device,
740
+ generator,
741
+ latents,
742
+ )
743
+
744
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
745
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
746
+
747
+ # 7. Prepare added time ids & embeddings
748
+ add_text_embeds = pooled_prompt_embeds
749
+ add_time_ids = self._get_add_time_ids(
750
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
751
+ )
752
+
753
+ if do_classifier_free_guidance:
754
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
755
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
756
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
757
+
758
+ prompt_embeds = prompt_embeds.to(device)
759
+ add_text_embeds = add_text_embeds.to(device)
760
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
761
+
762
+ # 8. Denoising loop
763
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
764
+
765
+ # 7.1 Apply denoising_end
766
+ if denoising_end is not None:
767
+ num_inference_steps = int(round(denoising_end * num_inference_steps))
768
+ timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
769
+
770
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
771
+ for i, t in enumerate(timesteps):
772
+ # expand the latents if we are doing classifier free guidance
773
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
774
+
775
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
776
+
777
+ # predict the noise residual
778
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
779
+ noise_pred = self.unet(
780
+ latent_model_input,
781
+ t,
782
+ encoder_hidden_states=prompt_embeds,
783
+ cross_attention_kwargs=cross_attention_kwargs,
784
+ added_cond_kwargs=added_cond_kwargs,
785
+ return_dict=False,
786
+ )[0]
787
+
788
+ # perform guidance
789
+ if do_classifier_free_guidance:
790
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
791
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
792
+ if use_dynamic_threshold:
793
+ DynamicThresh = DynThresh(maxSteps=num_inference_steps, experiment_mode=0)
794
+ noise_pred = DynamicThresh.dynthresh(noise_pred_text,
795
+ noise_pred_uncond,
796
+ guidance_scale,
797
+ None)
798
+
799
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
800
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
801
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
802
+
803
+ # compute the previous noisy sample x_t -> x_t-1
804
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
805
+
806
+ # call the callback, if provided
807
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
808
+ progress_bar.update()
809
+ if callback is not None and i % callback_steps == 0:
810
+ callback(i, t, latents)
811
+
812
+ # make sureo the VAE is in float32 mode, as it overflows in float16
813
+ # torch.cuda.empty_cache()
814
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
815
+ self.upcast_vae()
816
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
817
+
818
+
819
+ if not output_type == "latent":
820
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
821
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
822
+ else:
823
+ image = latents
824
+ return StableDiffusionXLPipelineOutput(images=image)
825
+
826
+ # image = self.watermark.apply_watermark(image)
827
+ image = self.image_processor.postprocess(image, output_type=output_type)
828
+
829
+ # Offload last model to CPU
830
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
831
+ self.final_offload_hook.offload()
832
+
833
+ if not return_dict:
834
+ return (image,)
835
+
836
+ return StableDiffusionXLPipelineOutput(images=image)
837
+
838
+
839
+ if __name__ == "__main__":
840
+ pass