multimodalart HF staff commited on
Commit
8a09a62
·
verified ·
1 Parent(s): 9de62ed

Upload 57 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. README.md +1 -1
  3. app/hydit_app.py +170 -0
  4. app/lang/en.csv +22 -0
  5. app/lang/zh.csv +22 -0
  6. asset/Hunyuan_DiT_Tech_Report_05140553.pdf +3 -0
  7. asset/chinese elements understanding.png +3 -0
  8. asset/cover.png +0 -0
  9. asset/framework.png +0 -0
  10. asset/logo.png +0 -0
  11. asset/long text understanding.png +3 -0
  12. asset/mllm.png +0 -0
  13. asset/radar.png +0 -0
  14. dialoggen/dialoggen_demo.py +172 -0
  15. dialoggen/images/demo1.jpeg +0 -0
  16. dialoggen/images/demo2.jpeg +0 -0
  17. dialoggen/llava/__init__.py +1 -0
  18. dialoggen/llava/constants.py +13 -0
  19. dialoggen/llava/conversation.py +396 -0
  20. dialoggen/llava/mm_utils.py +247 -0
  21. dialoggen/llava/model/__init__.py +6 -0
  22. dialoggen/llava/model/apply_delta.py +48 -0
  23. dialoggen/llava/model/builder.py +167 -0
  24. dialoggen/llava/model/consolidate.py +29 -0
  25. dialoggen/llava/model/language_model/llava_llama.py +158 -0
  26. dialoggen/llava/model/language_model/llava_mistral.py +158 -0
  27. dialoggen/llava/model/language_model/llava_mpt.py +97 -0
  28. dialoggen/llava/model/llava_arch.py +368 -0
  29. dialoggen/llava/model/make_delta.py +52 -0
  30. dialoggen/llava/model/multimodal_encoder/builder.py +11 -0
  31. dialoggen/llava/model/multimodal_encoder/clip_encoder.py +88 -0
  32. dialoggen/llava/model/multimodal_projector/builder.py +51 -0
  33. dialoggen/llava/model/utils.py +20 -0
  34. dialoggen/llava/utils.py +126 -0
  35. en.csv +22 -0
  36. environment.yml +8 -0
  37. example_prompts.txt +28 -0
  38. hydit/__init__.py +0 -0
  39. hydit/config.py +67 -0
  40. hydit/constants.py +62 -0
  41. hydit/diffusion/__init__.py +0 -0
  42. hydit/diffusion/pipeline.py +830 -0
  43. hydit/inference.py +389 -0
  44. hydit/modules/__init__.py +0 -0
  45. hydit/modules/attn_layers.py +377 -0
  46. hydit/modules/embedders.py +111 -0
  47. hydit/modules/models.py +409 -0
  48. hydit/modules/norm_layers.py +68 -0
  49. hydit/modules/poolers.py +39 -0
  50. hydit/modules/posemb_layers.py +225 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ asset/chinese[[:space:]]elements[[:space:]]understanding.png filter=lfs diff=lfs merge=lfs -text
37
+ asset/Hunyuan_DiT_Tech_Report_05140553.pdf filter=lfs diff=lfs merge=lfs -text
38
+ asset/long[[:space:]]text[[:space:]]understanding.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.31.1
8
- app_file: app.py
9
  pinned: false
10
  ---
11
 
 
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.31.1
8
+ app_file: app/hydit_app.py
9
  pinned: false
10
  ---
11
 
app/hydit_app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from pathlib import Path
4
+ from PIL import Image
5
+ import sys
6
+ sys.path.insert(0, str(Path(__file__).parent.parent))
7
+
8
+ from hydit.constants import SAMPLER_FACTORY
9
+ from sample_t2i import inferencer
10
+
11
+ ROOT = Path(__file__).parent.parent
12
+ SAMPLERS = list(SAMPLER_FACTORY.keys())
13
+ SIZES = {
14
+ "square": (1024, 1024),
15
+ "landscape": (768, 1280),
16
+ "portrait": (1280, 768),
17
+ }
18
+
19
+ def get_strings(lang):
20
+ lang_file = Path(f"app/lang/{lang}.csv")
21
+ strings = pd.read_csv(lang_file, header=0)
22
+ strings = strings.set_index("key")['value'].to_dict()
23
+ return strings
24
+
25
+
26
+ args, gen, enhancer = inferencer()
27
+ strings = get_strings("en")
28
+
29
+
30
+ def infer(
31
+ prompt,
32
+ negative_prompt,
33
+ seed,
34
+ cfg_scale,
35
+ infer_steps,
36
+ oriW, oriH,
37
+ sampler,
38
+ size,
39
+ enhance
40
+ ):
41
+ if enhance and enhancer is not None:
42
+ success, enhanced_prompt = enhancer(prompt)
43
+ if not success:
44
+ fail_image = Image.open(ROOT / 'app/fail.png')
45
+ return fail_image
46
+ else:
47
+ enhanced_prompt = None
48
+
49
+ height, width = SIZES[size]
50
+ results = gen.predict(prompt,
51
+ height=height,
52
+ width=width,
53
+ seed=seed,
54
+ enhanced_prompt=enhanced_prompt,
55
+ negative_prompt=negative_prompt,
56
+ infer_steps=infer_steps,
57
+ guidance_scale=cfg_scale,
58
+ batch_size=1,
59
+ src_size_cond=(oriW, oriH),
60
+ sampler=sampler,
61
+ )
62
+ image = results['images'][0]
63
+ return image
64
+
65
+
66
+ def ui():
67
+ block = gr.Blocks()
68
+
69
+ description = f"""
70
+ # {strings['title']}
71
+
72
+ ## {strings['desc']}
73
+
74
+ """
75
+
76
+ with block:
77
+ with gr.Row():
78
+ gr.Markdown(description)
79
+ with gr.Row():
80
+ with gr.Column():
81
+ with gr.Row():
82
+ size = gr.Radio(
83
+ label=strings['size'], choices=[
84
+ (strings['square'], 'square'),
85
+ (strings['landscape'], 'landscape'),
86
+ (strings['portrait'], 'portrait'),
87
+ ],
88
+ value="square"
89
+ )
90
+ prompt = gr.Textbox(label=strings['prompt'], value=strings['default prompt'], lines=3)
91
+ with gr.Row():
92
+ infer_steps = gr.Slider(
93
+ label=strings['infer steps'], minimum=1, maximum=200, value=100, step=1,
94
+ )
95
+ seed = gr.Number(
96
+ label=strings['seed'], minimum=-1, maximum=1_000_000_000, value=1, step=1, precision=0,
97
+ )
98
+ enhance = gr.Checkbox(
99
+ label=strings['enhance'], value=enhancer is not None, interactive=True,
100
+ )
101
+
102
+ with gr.Accordion(
103
+ strings['accordion'], open=False
104
+ ):
105
+ with gr.Row():
106
+ negative_prompt = gr.Textbox(label=strings['negative_prompt'],
107
+ value=gen.default_negative_prompt,
108
+ lines=2,
109
+ )
110
+ with gr.Row():
111
+ sampler = gr.Dropdown(SAMPLERS, label=strings['sampler'], value="ddpm")
112
+ cfg_scale = gr.Slider(
113
+ label=strings['cfg'], minimum=1.0, maximum=16.0, value=6.0, step=1
114
+ )
115
+ oriW = gr.Number(
116
+ label=strings['width cond'], minimum=1024, maximum=4096, value=1024, step=64, precision=0,
117
+ min_width=80,
118
+ )
119
+ oriH = gr.Number(
120
+ label=strings['height cond'], minimum=1024, maximum=4096, value=1024, step=64, precision=0,
121
+ min_width=80,
122
+ )
123
+ with gr.Row():
124
+ advanced_button = gr.Button(strings['run'])
125
+ with gr.Column():
126
+ default_img = Image.open(ROOT / 'app/default.png')
127
+ output_img = gr.Image(
128
+ label=strings['generated image'],
129
+ interactive=False,
130
+ format='png',
131
+ value=default_img,
132
+ )
133
+ advanced_button.click(
134
+ fn=infer,
135
+ inputs=[
136
+ prompt, negative_prompt, seed, cfg_scale, infer_steps,
137
+ oriW, oriH, sampler, size, enhance,
138
+ ],
139
+ outputs=output_img,
140
+ )
141
+
142
+ with gr.Row():
143
+ gr.Examples([
144
+ ['一只小猫'],
145
+ ['现实主义风格,画面主要描述一个巴洛克风格的花瓶,带有金色的装饰边框,花瓶上盛开着各种色彩鲜艳的花,白色背景'],
146
+ ['一只聪明的狐狸走在阔叶树林里, 旁边是一条小溪, 细节真实, 摄影'],
147
+ ['飞流直下三千尺,疑是银河落九天'],
148
+ ['一只长靴猫手持亮银色的宝剑,身着铠甲,眼神坚毅,站在一堆金币上,背景是暗色调的洞穴,图像上有金币的光影点缀。'],
149
+ ['麻婆豆腐'],
150
+ ['苏州园林'],
151
+ ['一颗新鲜的草莓特写,红色的外表,表面布满许多种子,背景是淡绿色的叶子'],
152
+ ['请画出“忽如一夜春风来 千树万树梨花开”'],
153
+ ['请将“杞人忧天”的样子画出来'],
154
+ ['枯藤老树昏鸦,小桥流水人家'],
155
+ ['湖水清澈,天空湛蓝,阳光灿烂。一只优雅的白天鹅在湖边游泳。它周围有几只小鸭子,看起来非常可爱,整个画面给人一种宁静祥和的感觉。'],
156
+ ['一朵鲜艳的红色玫瑰花,花瓣撒有一些水珠,晶莹剔透,特写镜头'],
157
+ ['臭豆腐'],
158
+ ['九寨沟'],
159
+ ['俗语“鲤鱼跃龙门”'],
160
+ ['风格是写实,画面主要描述一个亚洲戏曲艺术家正在表演,她穿着华丽的戏服,脸上戴着精致的面具,身姿优雅,背景是古色古香的舞台,镜头是近景'],
161
+ ],
162
+ [prompt],
163
+ label=strings['examples']
164
+ )
165
+ return block
166
+
167
+
168
+ if __name__ == "__main__":
169
+ interface = ui()
170
+ interface.launch()
app/lang/en.csv ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ key,value
2
+ size,Size
3
+ sampler,Sampler
4
+ prompt,Prompt
5
+ default prompt,"A cute cat"
6
+ negative_prompt,Negative Prompt
7
+ seed,Seed
8
+ cfg,CFG Scale
9
+ infer steps,Sampling Steps
10
+ batch size,Batch Size
11
+ width cond,Width Cond
12
+ height cond,Height Cond
13
+ enhance,Prompt Enhancement
14
+ run,Submit
15
+ square,Square(1024x1024)
16
+ landscape,Landscape(1280x768)
17
+ portrait,Portrait(768x1280)
18
+ accordion,Advanced Options
19
+ generated image,HunYuanDiT Generated Image
20
+ examples,More Examples
21
+ title,Hunyuan-DiT
22
+ desc,A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding
app/lang/zh.csv ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ key,value
2
+ size,尺寸
3
+ sampler,采样器
4
+ prompt,文本描述
5
+ default prompt,"一只可爱的猫"
6
+ negative_prompt,负向词
7
+ seed,种子
8
+ cfg,CFG系数
9
+ infer steps,采样步数
10
+ batch size,批大小
11
+ width cond,宽度条件
12
+ height cond,高度条件
13
+ enhance,文本增强
14
+ run,提交生成
15
+ square,方形(1024x1024)
16
+ portrait,竖屏(1280x768)
17
+ landscape,横屏(768x1280)
18
+ accordion,高级设置
19
+ generated image,HunYuanDiT 生成
20
+ examples,更多示例
21
+ title,混元-DiT
22
+ desc,具有细粒度中文理解的高性能多分辨率 Diffusion Transformer 模型
asset/Hunyuan_DiT_Tech_Report_05140553.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f8514b002ba3bb4704575096683f65e09df06693a54bf3004f0b351138ab1e5
3
+ size 42132252
asset/chinese elements understanding.png ADDED

Git LFS Details

  • SHA256: c5761413a7c2b15adb83dcad04c3b56c6358debd3a354dfd559919b611c9fb52
  • Pointer size: 132 Bytes
  • Size of remote file: 6.06 MB
asset/cover.png ADDED
asset/framework.png ADDED
asset/logo.png ADDED
asset/long text understanding.png ADDED

Git LFS Details

  • SHA256: 8060c105db0cc40a83a89443096c8b95b2838da57fd04d4ddf828328dce8811e
  • Pointer size: 132 Bytes
  • Size of remote file: 5.15 MB
asset/mllm.png ADDED
asset/radar.png ADDED
dialoggen/dialoggen_demo.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import sys
4
+ import os
5
+ # 添加当前命令行运行的目录到 sys.path
6
+ sys.path.append(os.getcwd()+"/dialoggen")
7
+
8
+
9
+ from llava.constants import (
10
+ IMAGE_TOKEN_INDEX,
11
+ DEFAULT_IMAGE_TOKEN,
12
+ DEFAULT_IM_START_TOKEN,
13
+ DEFAULT_IM_END_TOKEN,
14
+ IMAGE_PLACEHOLDER,
15
+ )
16
+ from llava.conversation import conv_templates, SeparatorStyle
17
+ from llava.model.builder import load_pretrained_model
18
+ from llava.utils import disable_torch_init
19
+ from llava.mm_utils import (
20
+ process_images,
21
+ tokenizer_image_token,
22
+ get_model_name_from_path,
23
+ )
24
+
25
+ import requests
26
+ from PIL import Image
27
+ from io import BytesIO
28
+ import re
29
+
30
+
31
+ def image_parser(image_file, sep=','):
32
+ out = image_file.split(sep)
33
+ return out
34
+
35
+
36
+ def load_image(image_file):
37
+ if image_file.startswith("http") or image_file.startswith("https"):
38
+ response = requests.get(image_file)
39
+ image = Image.open(BytesIO(response.content)).convert("RGB")
40
+ else:
41
+ image = Image.open(image_file).convert("RGB")
42
+ return image
43
+
44
+
45
+ def load_images(image_files):
46
+ out = []
47
+ for image_file in image_files:
48
+ image = load_image(image_file)
49
+ out.append(image)
50
+ return out
51
+
52
+
53
+ def init_dialoggen_model(model_path, model_base=None):
54
+ model_name = get_model_name_from_path(model_path)
55
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
56
+ model_path, model_base, model_name, llava_type_model=True)
57
+ return {"tokenizer": tokenizer,
58
+ "model": model,
59
+ "image_processor": image_processor}
60
+
61
+
62
+ def eval_model(models,
63
+ query='详细描述一下这张图片',
64
+ image_file=None,
65
+ sep=',',
66
+ temperature=0.2,
67
+ top_p=None,
68
+ num_beams=1,
69
+ max_new_tokens=512,
70
+ ):
71
+ # Model
72
+ disable_torch_init()
73
+
74
+ qs = query
75
+ image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
76
+ if IMAGE_PLACEHOLDER in qs:
77
+ if models["model"].config.mm_use_im_start_end:
78
+ qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
79
+ else:
80
+ qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
81
+ else:
82
+ if models["model"].config.mm_use_im_start_end:
83
+ qs = image_token_se + "\n" + qs
84
+ else:
85
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
86
+
87
+ conv = conv_templates['llava_v1'].copy()
88
+ conv.append_message(conv.roles[0], qs)
89
+ conv.append_message(conv.roles[1], None)
90
+ prompt = conv.get_prompt()
91
+
92
+ if image_file is not None:
93
+ image_files = image_parser(image_file, sep=sep)
94
+ images = load_images(image_files)
95
+ image_sizes = [x.size for x in images]
96
+ images_tensor = process_images(
97
+ images,
98
+ models["image_processor"],
99
+ models["model"].config
100
+ ).to(models["model"].device, dtype=torch.float16)
101
+ else:
102
+ # fomatted input as training data
103
+ image_sizes = [(1024, 1024)]
104
+ images_tensor = torch.zeros(1, 5, 3, models["image_processor"].crop_size["height"], models["image_processor"].crop_size["width"])
105
+ images_tensor = images_tensor.to(models["model"].device, dtype=torch.float16)
106
+
107
+ input_ids = (
108
+ tokenizer_image_token(prompt, models["tokenizer"], IMAGE_TOKEN_INDEX, return_tensors="pt")
109
+ .unsqueeze(0)
110
+ .cuda()
111
+ )
112
+ with torch.inference_mode():
113
+ output_ids = models["model"].generate(
114
+ input_ids,
115
+ images=images_tensor,
116
+ image_sizes=image_sizes,
117
+ do_sample=True if temperature > 0 else False,
118
+ temperature=temperature,
119
+ top_p=top_p,
120
+ num_beams=num_beams,
121
+ max_new_tokens=max_new_tokens,
122
+ use_cache=True,
123
+ )
124
+
125
+ outputs = models["tokenizer"].batch_decode(output_ids, skip_special_tokens=True)[0].strip()
126
+ return outputs
127
+
128
+
129
+ def remove_prefix(text):
130
+ if text.startswith("<画图>"):
131
+ return text[len("<画图>"):], True
132
+ elif text.startswith("对不起"):
133
+ # 拒绝画图
134
+ return "", False
135
+ else:
136
+ return text, True
137
+
138
+
139
+ class DialogGen(object):
140
+ def __init__(self, model_path):
141
+ self.models = init_dialoggen_model(model_path)
142
+ self.query_template = "请先判断用户的意图,若为画图则在输出前加入<画图>:{}"
143
+
144
+ def __call__(self, prompt):
145
+ enhanced_prompt = eval_model(
146
+ models=self.models,
147
+ query=self.query_template.format(prompt),
148
+ image_file=None,
149
+ )
150
+
151
+ enhanced_prompt, compliance = remove_prefix(enhanced_prompt)
152
+ if not compliance:
153
+ return False, ""
154
+ return True, enhanced_prompt
155
+
156
+
157
+ if __name__ == "__main__":
158
+ parser = argparse.ArgumentParser()
159
+ parser.add_argument('--model_path', type=str, default='./ckpts/dialoggen')
160
+ parser.add_argument('--prompt', type=str, default='画一只小猫')
161
+ parser.add_argument('--image_file', type=str, default=None) # 'images/demo1.jpeg'
162
+ args = parser.parse_args()
163
+
164
+ query = f"请先判断用户的意图,若为画图则在输出前加入<画图>:{args.prompt}"
165
+
166
+ models = init_dialoggen_model(args.model_path)
167
+
168
+ res = eval_model(models,
169
+ query=query,
170
+ image_file=args.image_file,
171
+ )
172
+ print(res)
dialoggen/images/demo1.jpeg ADDED
dialoggen/images/demo2.jpeg ADDED
dialoggen/llava/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import LlavaLlamaForCausalLM
dialoggen/llava/constants.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
13
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
dialoggen/llava/conversation.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+ import base64
5
+ from io import BytesIO
6
+ from PIL import Image
7
+
8
+
9
+ class SeparatorStyle(Enum):
10
+ """Different separator style."""
11
+ SINGLE = auto()
12
+ TWO = auto()
13
+ MPT = auto()
14
+ PLAIN = auto()
15
+ LLAMA_2 = auto()
16
+
17
+
18
+ @dataclasses.dataclass
19
+ class Conversation:
20
+ """A class that keeps all conversation history."""
21
+ system: str
22
+ roles: List[str]
23
+ messages: List[List[str]]
24
+ offset: int
25
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
26
+ sep: str = "###"
27
+ sep2: str = None
28
+ version: str = "Unknown"
29
+
30
+ skip_next: bool = False
31
+
32
+ def get_prompt(self):
33
+ messages = self.messages
34
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
35
+ messages = self.messages.copy()
36
+ init_role, init_msg = messages[0].copy()
37
+ init_msg = init_msg[0].replace("<image>", "").strip()
38
+ if 'mmtag' in self.version:
39
+ messages[0] = (init_role, init_msg)
40
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
41
+ messages.insert(1, (self.roles[1], "Received."))
42
+ else:
43
+ messages[0] = (init_role, "<image>\n" + init_msg)
44
+
45
+ if self.sep_style == SeparatorStyle.SINGLE:
46
+ ret = self.system + self.sep
47
+ for role, message in messages:
48
+ if message:
49
+ if type(message) is tuple:
50
+ message, _, _ = message
51
+ ret += role + ": " + message + self.sep
52
+ else:
53
+ ret += role + ":"
54
+ elif self.sep_style == SeparatorStyle.TWO:
55
+ seps = [self.sep, self.sep2]
56
+ ret = self.system + seps[0]
57
+ for i, (role, message) in enumerate(messages):
58
+ if message:
59
+ if type(message) is tuple:
60
+ message, _, _ = message
61
+ ret += role + ": " + message + seps[i % 2]
62
+ else:
63
+ ret += role + ":"
64
+ elif self.sep_style == SeparatorStyle.MPT:
65
+ ret = self.system + self.sep
66
+ for role, message in messages:
67
+ if message:
68
+ if type(message) is tuple:
69
+ message, _, _ = message
70
+ ret += role + message + self.sep
71
+ else:
72
+ ret += role
73
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
74
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
75
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
76
+ ret = ""
77
+
78
+ for i, (role, message) in enumerate(messages):
79
+ if i == 0:
80
+ assert message, "first message should not be none"
81
+ assert role == self.roles[0], "first message should come from user"
82
+ if message:
83
+ if type(message) is tuple:
84
+ message, _, _ = message
85
+ if i == 0: message = wrap_sys(self.system) + message
86
+ if i % 2 == 0:
87
+ message = wrap_inst(message)
88
+ ret += self.sep + message
89
+ else:
90
+ ret += " " + message + " " + self.sep2
91
+ else:
92
+ ret += ""
93
+ ret = ret.lstrip(self.sep)
94
+ elif self.sep_style == SeparatorStyle.PLAIN:
95
+ seps = [self.sep, self.sep2]
96
+ ret = self.system
97
+ for i, (role, message) in enumerate(messages):
98
+ if message:
99
+ if type(message) is tuple:
100
+ message, _, _ = message
101
+ ret += message + seps[i % 2]
102
+ else:
103
+ ret += ""
104
+ else:
105
+ raise ValueError(f"Invalid style: {self.sep_style}")
106
+
107
+ return ret
108
+
109
+ def append_message(self, role, message):
110
+ self.messages.append([role, message])
111
+
112
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
113
+ if image_process_mode == "Pad":
114
+ def expand2square(pil_img, background_color=(122, 116, 104)):
115
+ width, height = pil_img.size
116
+ if width == height:
117
+ return pil_img
118
+ elif width > height:
119
+ result = Image.new(pil_img.mode, (width, width), background_color)
120
+ result.paste(pil_img, (0, (width - height) // 2))
121
+ return result
122
+ else:
123
+ result = Image.new(pil_img.mode, (height, height), background_color)
124
+ result.paste(pil_img, ((height - width) // 2, 0))
125
+ return result
126
+ image = expand2square(image)
127
+ elif image_process_mode in ["Default", "Crop"]:
128
+ pass
129
+ elif image_process_mode == "Resize":
130
+ image = image.resize((336, 336))
131
+ else:
132
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
133
+ if max(image.size) > max_len:
134
+ max_hw, min_hw = max(image.size), min(image.size)
135
+ aspect_ratio = max_hw / min_hw
136
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
137
+ longest_edge = int(shortest_edge * aspect_ratio)
138
+ W, H = image.size
139
+ if H > W:
140
+ H, W = longest_edge, shortest_edge
141
+ else:
142
+ H, W = shortest_edge, longest_edge
143
+ image = image.resize((W, H))
144
+ if return_pil:
145
+ return image
146
+ else:
147
+ buffered = BytesIO()
148
+ image.save(buffered, format=image_format)
149
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
150
+ return img_b64_str
151
+
152
+ def get_images(self, return_pil=False):
153
+ images = []
154
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
155
+ if i % 2 == 0:
156
+ if type(msg) is tuple:
157
+ msg, image, image_process_mode = msg
158
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
159
+ images.append(image)
160
+ return images
161
+
162
+ def to_gradio_chatbot(self):
163
+ ret = []
164
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
165
+ if i % 2 == 0:
166
+ if type(msg) is tuple:
167
+ msg, image, image_process_mode = msg
168
+ img_b64_str = self.process_image(
169
+ image, "Default", return_pil=False,
170
+ image_format='JPEG')
171
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
172
+ msg = img_str + msg.replace('<image>', '').strip()
173
+ ret.append([msg, None])
174
+ else:
175
+ ret.append([msg, None])
176
+ else:
177
+ ret[-1][-1] = msg
178
+ return ret
179
+
180
+ def copy(self):
181
+ return Conversation(
182
+ system=self.system,
183
+ roles=self.roles,
184
+ messages=[[x, y] for x, y in self.messages],
185
+ offset=self.offset,
186
+ sep_style=self.sep_style,
187
+ sep=self.sep,
188
+ sep2=self.sep2,
189
+ version=self.version)
190
+
191
+ def dict(self):
192
+ if len(self.get_images()) > 0:
193
+ return {
194
+ "system": self.system,
195
+ "roles": self.roles,
196
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
197
+ "offset": self.offset,
198
+ "sep": self.sep,
199
+ "sep2": self.sep2,
200
+ }
201
+ return {
202
+ "system": self.system,
203
+ "roles": self.roles,
204
+ "messages": self.messages,
205
+ "offset": self.offset,
206
+ "sep": self.sep,
207
+ "sep2": self.sep2,
208
+ }
209
+
210
+
211
+ conv_vicuna_v0 = Conversation(
212
+ system="A chat between a curious human and an artificial intelligence assistant. "
213
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
214
+ roles=("Human", "Assistant"),
215
+ messages=(
216
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
217
+ ("Assistant",
218
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
219
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
220
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
221
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
222
+ "renewable and non-renewable energy sources:\n"
223
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
224
+ "energy sources are finite and will eventually run out.\n"
225
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
226
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
227
+ "and other negative effects.\n"
228
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
229
+ "have lower operational costs than non-renewable sources.\n"
230
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
231
+ "locations than non-renewable sources.\n"
232
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
233
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
234
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
235
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
236
+ ),
237
+ offset=2,
238
+ sep_style=SeparatorStyle.SINGLE,
239
+ sep="###",
240
+ )
241
+
242
+ conv_vicuna_v1 = Conversation(
243
+ system="A chat between a curious user and an artificial intelligence assistant. "
244
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
245
+ roles=("USER", "ASSISTANT"),
246
+ version="v1",
247
+ messages=(),
248
+ offset=0,
249
+ sep_style=SeparatorStyle.TWO,
250
+ sep=" ",
251
+ sep2="</s>",
252
+ )
253
+
254
+ conv_llama_2 = Conversation(
255
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
256
+
257
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
258
+ roles=("USER", "ASSISTANT"),
259
+ version="llama_v2",
260
+ messages=(),
261
+ offset=0,
262
+ sep_style=SeparatorStyle.LLAMA_2,
263
+ sep="<s>",
264
+ sep2="</s>",
265
+ )
266
+
267
+ conv_llava_llama_2 = Conversation(
268
+ system="You are a helpful language and vision assistant. "
269
+ "You are able to understand the visual content that the user provides, "
270
+ "and assist the user with a variety of tasks using natural language.",
271
+ roles=("USER", "ASSISTANT"),
272
+ version="llama_v2",
273
+ messages=(),
274
+ offset=0,
275
+ sep_style=SeparatorStyle.LLAMA_2,
276
+ sep="<s>",
277
+ sep2="</s>",
278
+ )
279
+
280
+ conv_mpt = Conversation(
281
+ system="""<|im_start|>system
282
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
283
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
284
+ version="mpt",
285
+ messages=(),
286
+ offset=0,
287
+ sep_style=SeparatorStyle.MPT,
288
+ sep="<|im_end|>",
289
+ )
290
+
291
+ conv_llava_plain = Conversation(
292
+ system="",
293
+ roles=("", ""),
294
+ messages=(
295
+ ),
296
+ offset=0,
297
+ sep_style=SeparatorStyle.PLAIN,
298
+ sep="\n",
299
+ )
300
+
301
+ conv_llava_v0 = Conversation(
302
+ system="A chat between a curious human and an artificial intelligence assistant. "
303
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
304
+ roles=("Human", "Assistant"),
305
+ messages=(
306
+ ),
307
+ offset=0,
308
+ sep_style=SeparatorStyle.SINGLE,
309
+ sep="###",
310
+ )
311
+
312
+ conv_llava_v0_mmtag = Conversation(
313
+ system="A chat between a curious user and an artificial intelligence assistant. "
314
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
315
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
316
+ roles=("Human", "Assistant"),
317
+ messages=(
318
+ ),
319
+ offset=0,
320
+ sep_style=SeparatorStyle.SINGLE,
321
+ sep="###",
322
+ version="v0_mmtag",
323
+ )
324
+
325
+ conv_llava_v1 = Conversation(
326
+ system="A chat between a curious human and an artificial intelligence assistant. "
327
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
328
+ roles=("USER", "ASSISTANT"),
329
+ version="v1",
330
+ messages=(),
331
+ offset=0,
332
+ sep_style=SeparatorStyle.TWO,
333
+ sep=" ",
334
+ sep2="</s>",
335
+ )
336
+
337
+ conv_llava_v1_mmtag = Conversation(
338
+ system="A chat between a curious user and an artificial intelligence assistant. "
339
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
340
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
341
+ roles=("USER", "ASSISTANT"),
342
+ messages=(),
343
+ offset=0,
344
+ sep_style=SeparatorStyle.TWO,
345
+ sep=" ",
346
+ sep2="</s>",
347
+ version="v1_mmtag",
348
+ )
349
+
350
+ conv_mistral_instruct = Conversation(
351
+ system="",
352
+ roles=("USER", "ASSISTANT"),
353
+ version="llama_v2",
354
+ messages=(),
355
+ offset=0,
356
+ sep_style=SeparatorStyle.LLAMA_2,
357
+ sep="",
358
+ sep2="</s>",
359
+ )
360
+
361
+ conv_chatml_direct = Conversation(
362
+ system="""<|im_start|>system
363
+ Answer the questions.""",
364
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
365
+ version="mpt",
366
+ messages=(),
367
+ offset=0,
368
+ sep_style=SeparatorStyle.MPT,
369
+ sep="<|im_end|>",
370
+ )
371
+
372
+ default_conversation = conv_vicuna_v1
373
+ conv_templates = {
374
+ "default": conv_vicuna_v0,
375
+ "v0": conv_vicuna_v0,
376
+ "v1": conv_vicuna_v1,
377
+ "vicuna_v1": conv_vicuna_v1,
378
+ "llama_2": conv_llama_2,
379
+ "mistral_instruct": conv_mistral_instruct,
380
+ "chatml_direct": conv_chatml_direct,
381
+ "mistral_direct": conv_chatml_direct,
382
+
383
+ "plain": conv_llava_plain,
384
+ "v0_plain": conv_llava_plain,
385
+ "llava_v0": conv_llava_v0,
386
+ "v0_mmtag": conv_llava_v0_mmtag,
387
+ "llava_v1": conv_llava_v1,
388
+ "v1_mmtag": conv_llava_v1_mmtag,
389
+ "llava_llama_2": conv_llava_llama_2,
390
+
391
+ "mpt": conv_mpt,
392
+ }
393
+
394
+
395
+ if __name__ == "__main__":
396
+ print(default_conversation.get_prompt())
dialoggen/llava/mm_utils.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import torch
5
+ import math
6
+ import ast
7
+
8
+ from transformers import StoppingCriteria
9
+ from llava.constants import IMAGE_TOKEN_INDEX
10
+
11
+
12
+ def select_best_resolution(original_size, possible_resolutions):
13
+ """
14
+ Selects the best resolution from a list of possible resolutions based on the original size.
15
+
16
+ Args:
17
+ original_size (tuple): The original size of the image in the format (width, height).
18
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
19
+
20
+ Returns:
21
+ tuple: The best fit resolution in the format (width, height).
22
+ """
23
+ original_width, original_height = original_size
24
+ best_fit = None
25
+ max_effective_resolution = 0
26
+ min_wasted_resolution = float('inf')
27
+
28
+ for width, height in possible_resolutions:
29
+ scale = min(width / original_width, height / original_height)
30
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
31
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
32
+ wasted_resolution = (width * height) - effective_resolution
33
+
34
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
35
+ max_effective_resolution = effective_resolution
36
+ min_wasted_resolution = wasted_resolution
37
+ best_fit = (width, height)
38
+
39
+ return best_fit
40
+
41
+
42
+ def resize_and_pad_image(image, target_resolution):
43
+ """
44
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
45
+
46
+ Args:
47
+ image (PIL.Image.Image): The input image.
48
+ target_resolution (tuple): The target resolution (width, height) of the image.
49
+
50
+ Returns:
51
+ PIL.Image.Image: The resized and padded image.
52
+ """
53
+ original_width, original_height = image.size
54
+ target_width, target_height = target_resolution
55
+
56
+ scale_w = target_width / original_width
57
+ scale_h = target_height / original_height
58
+
59
+ if scale_w < scale_h:
60
+ new_width = target_width
61
+ new_height = min(math.ceil(original_height * scale_w), target_height)
62
+ else:
63
+ new_height = target_height
64
+ new_width = min(math.ceil(original_width * scale_h), target_width)
65
+
66
+ # Resize the image
67
+ resized_image = image.resize((new_width, new_height))
68
+
69
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
70
+ paste_x = (target_width - new_width) // 2
71
+ paste_y = (target_height - new_height) // 2
72
+ new_image.paste(resized_image, (paste_x, paste_y))
73
+
74
+ return new_image
75
+
76
+
77
+ def divide_to_patches(image, patch_size):
78
+ """
79
+ Divides an image into patches of a specified size.
80
+
81
+ Args:
82
+ image (PIL.Image.Image): The input image.
83
+ patch_size (int): The size of each patch.
84
+
85
+ Returns:
86
+ list: A list of PIL.Image.Image objects representing the patches.
87
+ """
88
+ patches = []
89
+ width, height = image.size
90
+ for i in range(0, height, patch_size):
91
+ for j in range(0, width, patch_size):
92
+ box = (j, i, j + patch_size, i + patch_size)
93
+ patch = image.crop(box)
94
+ patches.append(patch)
95
+
96
+ return patches
97
+
98
+
99
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
100
+ """
101
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
102
+
103
+ Args:
104
+ image_size (tuple): The size of the input image in the format (width, height).
105
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
106
+ patch_size (int): The size of each image patch.
107
+
108
+ Returns:
109
+ tuple: The shape of the image patch grid in the format (width, height).
110
+ """
111
+ if type(grid_pinpoints) is list:
112
+ possible_resolutions = grid_pinpoints
113
+ else:
114
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
115
+ width, height = select_best_resolution(image_size, possible_resolutions)
116
+ return width // patch_size, height // patch_size
117
+
118
+
119
+ def process_anyres_image(image, processor, grid_pinpoints):
120
+ """
121
+ Process an image with variable resolutions.
122
+
123
+ Args:
124
+ image (PIL.Image.Image): The input image to be processed.
125
+ processor: The image processor object.
126
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
127
+
128
+ Returns:
129
+ torch.Tensor: A tensor containing the processed image patches.
130
+ """
131
+ if type(grid_pinpoints) is list:
132
+ possible_resolutions = grid_pinpoints
133
+ else:
134
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
135
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
136
+ image_padded = resize_and_pad_image(image, best_resolution)
137
+
138
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
139
+
140
+ image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
141
+
142
+ image_patches = [image_original_resize] + patches
143
+ image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
144
+ for image_patch in image_patches]
145
+ return torch.stack(image_patches, dim=0)
146
+
147
+
148
+ def load_image_from_base64(image):
149
+ return Image.open(BytesIO(base64.b64decode(image)))
150
+
151
+
152
+ def expand2square(pil_img, background_color):
153
+ width, height = pil_img.size
154
+ if width == height:
155
+ return pil_img
156
+ elif width > height:
157
+ result = Image.new(pil_img.mode, (width, width), background_color)
158
+ result.paste(pil_img, (0, (width - height) // 2))
159
+ return result
160
+ else:
161
+ result = Image.new(pil_img.mode, (height, height), background_color)
162
+ result.paste(pil_img, ((height - width) // 2, 0))
163
+ return result
164
+
165
+
166
+ def process_images(images, image_processor, model_cfg):
167
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
168
+ new_images = []
169
+ if image_aspect_ratio == 'pad':
170
+ for image in images:
171
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
172
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
173
+ new_images.append(image)
174
+ elif image_aspect_ratio == "anyres":
175
+ for image in images:
176
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
177
+ new_images.append(image)
178
+ else:
179
+ return image_processor(images, return_tensors='pt')['pixel_values']
180
+ if all(x.shape == new_images[0].shape for x in new_images):
181
+ new_images = torch.stack(new_images, dim=0)
182
+ return new_images
183
+
184
+
185
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
186
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
187
+
188
+ def insert_separator(X, sep):
189
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
190
+
191
+ input_ids = []
192
+ offset = 0
193
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
194
+ offset = 1
195
+ input_ids.append(prompt_chunks[0][0])
196
+
197
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
198
+ input_ids.extend(x[offset:])
199
+
200
+ if return_tensors is not None:
201
+ if return_tensors == 'pt':
202
+ return torch.tensor(input_ids, dtype=torch.long)
203
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
204
+ return input_ids
205
+
206
+
207
+ def get_model_name_from_path(model_path):
208
+ model_path = model_path.strip("/")
209
+ model_paths = model_path.split("/")
210
+ if model_paths[-1].startswith('checkpoint-'):
211
+ return model_paths[-2] + "_" + model_paths[-1]
212
+ else:
213
+ return model_paths[-1]
214
+
215
+ class KeywordsStoppingCriteria(StoppingCriteria):
216
+ def __init__(self, keywords, tokenizer, input_ids):
217
+ self.keywords = keywords
218
+ self.keyword_ids = []
219
+ self.max_keyword_len = 0
220
+ for keyword in keywords:
221
+ cur_keyword_ids = tokenizer(keyword).input_ids
222
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
223
+ cur_keyword_ids = cur_keyword_ids[1:]
224
+ if len(cur_keyword_ids) > self.max_keyword_len:
225
+ self.max_keyword_len = len(cur_keyword_ids)
226
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
227
+ self.tokenizer = tokenizer
228
+ self.start_len = input_ids.shape[1]
229
+
230
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
231
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
232
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
233
+ for keyword_id in self.keyword_ids:
234
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
235
+ if torch.equal(truncated_output_ids, keyword_id):
236
+ return True
237
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
238
+ for keyword in self.keywords:
239
+ if keyword in outputs:
240
+ return True
241
+ return False
242
+
243
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
244
+ outputs = []
245
+ for i in range(output_ids.shape[0]):
246
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
247
+ return all(outputs)
dialoggen/llava/model/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ try:
2
+ from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
3
+ from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
4
+ from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
5
+ except:
6
+ pass
dialoggen/llava/model/apply_delta.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from llava import LlavaLlamaForCausalLM
11
+
12
+
13
+ def apply_delta(base_model_path, target_model_path, delta_path):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading delta")
19
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21
+
22
+ print("Applying delta")
23
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data += base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31
+ f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32
+ bparam = base.state_dict()[name]
33
+ param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34
+
35
+ print("Saving target model")
36
+ delta.save_pretrained(target_model_path)
37
+ delta_tokenizer.save_pretrained(target_model_path)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--base-model-path", type=str, required=True)
43
+ parser.add_argument("--target-model-path", type=str, required=True)
44
+ parser.add_argument("--delta-path", type=str, required=True)
45
+
46
+ args = parser.parse_args()
47
+
48
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
dialoggen/llava/model/builder.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from llava.model import *
23
+ from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+
25
+
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, llava_type_model=True, **kwargs):
27
+ kwargs = {"device_map": device_map, **kwargs}
28
+
29
+ if device != "cuda":
30
+ kwargs['device_map'] = {"": device}
31
+
32
+ if load_8bit:
33
+ kwargs['load_in_8bit'] = True
34
+ elif load_4bit:
35
+ kwargs['load_in_4bit'] = True
36
+ kwargs['quantization_config'] = BitsAndBytesConfig(
37
+ load_in_4bit=True,
38
+ bnb_4bit_compute_dtype=torch.float16,
39
+ bnb_4bit_use_double_quant=True,
40
+ bnb_4bit_quant_type='nf4'
41
+ )
42
+ else:
43
+ kwargs['torch_dtype'] = torch.float16
44
+
45
+ if use_flash_attn:
46
+ kwargs['attn_implementation'] = 'flash_attention_2'
47
+
48
+ if 'llava' in model_name.lower():
49
+ # Load LLaVA model
50
+ if 'lora' in model_name.lower() and model_base is None:
51
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
52
+ if 'lora' in model_name.lower() and model_base is not None:
53
+ from llava.model.language_model.llava_llama import LlavaConfig
54
+ lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
55
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
56
+ print('Loading LLaVA from base model...')
57
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
58
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
59
+ if model.lm_head.weight.shape[0] != token_num:
60
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
61
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
62
+
63
+ print('Loading additional LLaVA weights...')
64
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
65
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
66
+ else:
67
+ # this is probably from HF Hub
68
+ from huggingface_hub import hf_hub_download
69
+ def load_from_hf(repo_id, filename, subfolder=None):
70
+ cache_file = hf_hub_download(
71
+ repo_id=repo_id,
72
+ filename=filename,
73
+ subfolder=subfolder)
74
+ return torch.load(cache_file, map_location='cpu')
75
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
76
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
77
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
78
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
79
+ model.load_state_dict(non_lora_trainables, strict=False)
80
+
81
+ from peft import PeftModel
82
+ print('Loading LoRA weights...')
83
+ model = PeftModel.from_pretrained(model, model_path)
84
+ print('Merging LoRA weights...')
85
+ model = model.merge_and_unload()
86
+ print('Model is loaded...')
87
+ elif model_base is not None:
88
+ # this may be mm projector only
89
+ print('Loading LLaVA from base model...')
90
+ if 'mpt' in model_name.lower():
91
+ if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
92
+ shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
93
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
94
+ cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
95
+ model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
96
+ else:
97
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
98
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
99
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
100
+
101
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
102
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
103
+ model.load_state_dict(mm_projector_weights, strict=False)
104
+ else:
105
+ if 'mpt' in model_name.lower():
106
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
107
+ model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
108
+ elif 'mistral' in model_name.lower():
109
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
110
+ model = LlavaMistralForCausalLM.from_pretrained(
111
+ model_path,
112
+ low_cpu_mem_usage=True,
113
+ **kwargs
114
+ )
115
+ else:
116
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
117
+ model = LlavaLlamaForCausalLM.from_pretrained(
118
+ model_path,
119
+ low_cpu_mem_usage=True,
120
+ **kwargs
121
+ )
122
+ else:
123
+ # Load language model
124
+ if model_base is not None:
125
+ # PEFT model
126
+ from peft import PeftModel
127
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
128
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
129
+ print(f"Loading LoRA weights from {model_path}")
130
+ model = PeftModel.from_pretrained(model, model_path)
131
+ print(f"Merging weights")
132
+ model = model.merge_and_unload()
133
+ print('Convert to FP16...')
134
+ model.to(torch.float16)
135
+ else:
136
+ use_fast = False
137
+ if 'mpt' in model_name.lower():
138
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
139
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
140
+ else:
141
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
142
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
143
+
144
+ image_processor = None
145
+
146
+ if llava_type_model:
147
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
148
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
149
+ if mm_use_im_patch_token:
150
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
151
+ if mm_use_im_start_end:
152
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
153
+ model.resize_token_embeddings(len(tokenizer))
154
+
155
+ vision_tower = model.get_vision_tower()
156
+ if not vision_tower.is_loaded:
157
+ vision_tower.load_model(device_map=device_map)
158
+ if device_map != 'auto':
159
+ vision_tower.to(device=device_map, dtype=torch.float16)
160
+ image_processor = vision_tower.image_processor
161
+
162
+ if hasattr(model.config, "max_sequence_length"):
163
+ context_len = model.config.max_sequence_length
164
+ else:
165
+ context_len = 2048
166
+
167
+ return tokenizer, model, image_processor, context_len
dialoggen/llava/model/consolidate.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from llava.model import *
10
+ from llava.model.utils import auto_upgrade
11
+
12
+
13
+ def consolidate_ckpt(src_path, dst_path):
14
+ print("Loading model")
15
+ auto_upgrade(src_path)
16
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18
+ src_model.save_pretrained(dst_path)
19
+ src_tokenizer.save_pretrained(dst_path)
20
+
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--src", type=str, required=True)
25
+ parser.add_argument("--dst", type=str, required=True)
26
+
27
+ args = parser.parse_args()
28
+
29
+ consolidate_ckpt(args.src, args.dst)
dialoggen/llava/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from transformers import AutoConfig, AutoModelForCausalLM, \
22
+ LlamaConfig, LlamaModel, LlamaForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaConfig(LlamaConfig):
31
+ model_type = "llava_llama"
32
+
33
+
34
+ class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
35
+ config_class = LlavaConfig
36
+
37
+ def __init__(self, config: LlamaConfig):
38
+ super(LlavaLlamaModel, self).__init__(config)
39
+
40
+
41
+ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
42
+ config_class = LlavaConfig
43
+
44
+ def __init__(self, config):
45
+ super(LlamaForCausalLM, self).__init__(config)
46
+ self.model = LlavaLlamaModel(config)
47
+ self.pretraining_tp = config.pretraining_tp
48
+ self.vocab_size = config.vocab_size
49
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50
+
51
+ # Initialize weights and apply final processing
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
72
+
73
+ if inputs_embeds is None:
74
+ (
75
+ input_ids,
76
+ position_ids,
77
+ attention_mask,
78
+ past_key_values,
79
+ inputs_embeds,
80
+ labels
81
+ ) = self.prepare_inputs_labels_for_multimodal(
82
+ input_ids,
83
+ position_ids,
84
+ attention_mask,
85
+ past_key_values,
86
+ labels,
87
+ images,
88
+ image_sizes
89
+ )
90
+
91
+ return super().forward(
92
+ input_ids=input_ids,
93
+ attention_mask=attention_mask,
94
+ position_ids=position_ids,
95
+ past_key_values=past_key_values,
96
+ inputs_embeds=inputs_embeds,
97
+ labels=labels,
98
+ use_cache=use_cache,
99
+ output_attentions=output_attentions,
100
+ output_hidden_states=output_hidden_states,
101
+ return_dict=return_dict
102
+ )
103
+
104
+ @torch.no_grad()
105
+ def generate(
106
+ self,
107
+ inputs: Optional[torch.Tensor] = None,
108
+ images: Optional[torch.Tensor] = None,
109
+ image_sizes: Optional[torch.Tensor] = None,
110
+ **kwargs,
111
+ ) -> Union[GenerateOutput, torch.LongTensor]:
112
+ position_ids = kwargs.pop("position_ids", None)
113
+ attention_mask = kwargs.pop("attention_mask", None)
114
+ if "inputs_embeds" in kwargs:
115
+ raise NotImplementedError("`inputs_embeds` is not supported")
116
+
117
+ if images is not None:
118
+ (
119
+ inputs,
120
+ position_ids,
121
+ attention_mask,
122
+ _,
123
+ inputs_embeds,
124
+ _
125
+ ) = self.prepare_inputs_labels_for_multimodal(
126
+ inputs,
127
+ position_ids,
128
+ attention_mask,
129
+ None,
130
+ None,
131
+ images,
132
+ image_sizes=image_sizes
133
+ )
134
+ else:
135
+ inputs_embeds = self.get_model().embed_tokens(inputs)
136
+
137
+ return super().generate(
138
+ position_ids=position_ids,
139
+ attention_mask=attention_mask,
140
+ inputs_embeds=inputs_embeds,
141
+ **kwargs
142
+ )
143
+
144
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145
+ inputs_embeds=None, **kwargs):
146
+ images = kwargs.pop("images", None)
147
+ image_sizes = kwargs.pop("image_sizes", None)
148
+ inputs = super().prepare_inputs_for_generation(
149
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150
+ )
151
+ if images is not None:
152
+ inputs['images'] = images
153
+ if image_sizes is not None:
154
+ inputs['image_sizes'] = image_sizes
155
+ return inputs
156
+
157
+ AutoConfig.register("llava_llama", LlavaConfig)
158
+ AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
dialoggen/llava/model/language_model/llava_mistral.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, \
23
+ MistralConfig, MistralModel, MistralForCausalLM
24
+
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+ from transformers.generation.utils import GenerateOutput
27
+
28
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29
+
30
+
31
+ class LlavaMistralConfig(MistralConfig):
32
+ model_type = "llava_mistral"
33
+
34
+
35
+ class LlavaMistralModel(LlavaMetaModel, MistralModel):
36
+ config_class = LlavaMistralConfig
37
+
38
+ def __init__(self, config: MistralConfig):
39
+ super(LlavaMistralModel, self).__init__(config)
40
+
41
+
42
+ class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
43
+ config_class = LlavaMistralConfig
44
+
45
+ def __init__(self, config):
46
+ super(MistralForCausalLM, self).__init__(config)
47
+ self.model = LlavaMistralModel(config)
48
+
49
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50
+
51
+ # Initialize weights and apply final processing
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
72
+
73
+ if inputs_embeds is None:
74
+ (
75
+ input_ids,
76
+ position_ids,
77
+ attention_mask,
78
+ past_key_values,
79
+ inputs_embeds,
80
+ labels
81
+ ) = self.prepare_inputs_labels_for_multimodal(
82
+ input_ids,
83
+ position_ids,
84
+ attention_mask,
85
+ past_key_values,
86
+ labels,
87
+ images,
88
+ image_sizes
89
+ )
90
+
91
+ return super().forward(
92
+ input_ids=input_ids,
93
+ attention_mask=attention_mask,
94
+ position_ids=position_ids,
95
+ past_key_values=past_key_values,
96
+ inputs_embeds=inputs_embeds,
97
+ labels=labels,
98
+ use_cache=use_cache,
99
+ output_attentions=output_attentions,
100
+ output_hidden_states=output_hidden_states,
101
+ return_dict=return_dict
102
+ )
103
+
104
+ @torch.no_grad()
105
+ def generate(
106
+ self,
107
+ inputs: Optional[torch.Tensor] = None,
108
+ images: Optional[torch.Tensor] = None,
109
+ image_sizes: Optional[torch.Tensor] = None,
110
+ **kwargs,
111
+ ) -> Union[GenerateOutput, torch.LongTensor]:
112
+ position_ids = kwargs.pop("position_ids", None)
113
+ attention_mask = kwargs.pop("attention_mask", None)
114
+ if "inputs_embeds" in kwargs:
115
+ raise NotImplementedError("`inputs_embeds` is not supported")
116
+
117
+ if images is not None:
118
+ (
119
+ inputs,
120
+ position_ids,
121
+ attention_mask,
122
+ _,
123
+ inputs_embeds,
124
+ _
125
+ ) = self.prepare_inputs_labels_for_multimodal(
126
+ inputs,
127
+ position_ids,
128
+ attention_mask,
129
+ None,
130
+ None,
131
+ images,
132
+ image_sizes=image_sizes
133
+ )
134
+ else:
135
+ inputs_embeds = self.get_model().embed_tokens(inputs)
136
+
137
+ return super().generate(
138
+ position_ids=position_ids,
139
+ attention_mask=attention_mask,
140
+ inputs_embeds=inputs_embeds,
141
+ **kwargs
142
+ )
143
+
144
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145
+ inputs_embeds=None, **kwargs):
146
+ images = kwargs.pop("images", None)
147
+ image_sizes = kwargs.pop("image_sizes", None)
148
+ inputs = super().prepare_inputs_for_generation(
149
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150
+ )
151
+ if images is not None:
152
+ inputs['images'] = images
153
+ if image_sizes is not None:
154
+ inputs['image_sizes'] = image_sizes
155
+ return inputs
156
+
157
+ AutoConfig.register("llava_mistral", LlavaMistralConfig)
158
+ AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
dialoggen/llava/model/language_model/llava_mpt.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+
20
+ from transformers import AutoConfig, AutoModelForCausalLM, \
21
+ MptConfig, MptForCausalLM, MptModel
22
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
23
+
24
+
25
+ class LlavaMptConfig(MptConfig):
26
+ model_type = "llava_mpt"
27
+
28
+
29
+ class LlavaMptModel(LlavaMetaModel, MptModel):
30
+ config_class = LlavaMptConfig
31
+
32
+ def __init__(self, config: MptConfig):
33
+ config.hidden_size = config.d_model
34
+ super(LlavaMptModel, self).__init__(config)
35
+
36
+ def embed_tokens(self, x):
37
+ return self.wte(x)
38
+
39
+
40
+ class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
41
+ config_class = LlavaMptConfig
42
+ supports_gradient_checkpointing = True
43
+
44
+ def __init__(self, config):
45
+ super(MptForCausalLM, self).__init__(config)
46
+
47
+ self.transformer = LlavaMptModel(config)
48
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49
+
50
+ # Initialize weights and apply final processing
51
+ self.post_init()
52
+
53
+ def get_model(self):
54
+ return self.transformer
55
+
56
+ def _set_gradient_checkpointing(self, module, value=False):
57
+ if isinstance(module, LlavaMptModel):
58
+ module.gradient_checkpointing = value
59
+
60
+ def forward(
61
+ self,
62
+ input_ids: Optional[torch.LongTensor] = None,
63
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
64
+ attention_mask: Optional[torch.Tensor] = None,
65
+ inputs_embeds: Optional[torch.Tensor] = None,
66
+ labels: Optional[torch.Tensor] = None,
67
+ use_cache: Optional[bool] = None,
68
+ output_attentions: Optional[bool] = None,
69
+ output_hidden_states: Optional[bool] = None,
70
+ return_dict: Optional[bool] = None,
71
+ images=None):
72
+
73
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
74
+
75
+ return super().forward(
76
+ input_ids,
77
+ past_key_values=past_key_values,
78
+ attention_mask=attention_mask,
79
+ inputs_embeds=inputs_embeds,
80
+ labels=labels,
81
+ use_cache=use_cache,
82
+ output_attentions=output_attentions,
83
+ output_hidden_states=output_hidden_states,
84
+ return_dict=return_dict,
85
+ )
86
+
87
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
88
+ images = kwargs.pop("images", None)
89
+ _inputs = super().prepare_inputs_for_generation(
90
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
91
+ )
92
+ _inputs['images'] = images
93
+ return _inputs
94
+
95
+
96
+ AutoConfig.register("llava_mpt", LlavaMptConfig)
97
+ AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
dialoggen/llava/model/llava_arch.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .multimodal_encoder.builder import build_vision_tower
22
+ from .multimodal_projector.builder import build_vision_projector
23
+
24
+ from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
+
26
+ from llava.mm_utils import get_anyres_image_grid_shape
27
+
28
+
29
+ class LlavaMetaModel:
30
+
31
+ def __init__(self, config):
32
+ super(LlavaMetaModel, self).__init__(config)
33
+
34
+ if hasattr(config, "mm_vision_tower"):
35
+ self.vision_tower = build_vision_tower(config, delay_load=True)
36
+ self.mm_projector = build_vision_projector(config)
37
+
38
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
39
+ self.image_newline = nn.Parameter(
40
+ torch.empty(config.hidden_size, dtype=self.dtype)
41
+ )
42
+
43
+ def get_vision_tower(self):
44
+ vision_tower = getattr(self, 'vision_tower', None)
45
+ if type(vision_tower) is list:
46
+ vision_tower = vision_tower[0]
47
+ return vision_tower
48
+
49
+ def initialize_vision_modules(self, model_args, fsdp=None):
50
+ vision_tower = model_args.vision_tower
51
+ mm_vision_select_layer = model_args.mm_vision_select_layer
52
+ mm_vision_select_feature = model_args.mm_vision_select_feature
53
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
54
+ mm_patch_merge_type = model_args.mm_patch_merge_type
55
+
56
+ self.config.mm_vision_tower = vision_tower
57
+
58
+ if self.get_vision_tower() is None:
59
+ vision_tower = build_vision_tower(model_args)
60
+
61
+ if fsdp is not None and len(fsdp) > 0:
62
+ self.vision_tower = [vision_tower]
63
+ else:
64
+ self.vision_tower = vision_tower
65
+ else:
66
+ if fsdp is not None and len(fsdp) > 0:
67
+ vision_tower = self.vision_tower[0]
68
+ else:
69
+ vision_tower = self.vision_tower
70
+ vision_tower.load_model()
71
+
72
+ self.config.use_mm_proj = True
73
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
74
+ self.config.mm_hidden_size = vision_tower.hidden_size
75
+ self.config.mm_vision_select_layer = mm_vision_select_layer
76
+ self.config.mm_vision_select_feature = mm_vision_select_feature
77
+ self.config.mm_patch_merge_type = mm_patch_merge_type
78
+
79
+ if getattr(self, 'mm_projector', None) is None:
80
+ self.mm_projector = build_vision_projector(self.config)
81
+
82
+ if 'unpad' in mm_patch_merge_type:
83
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
84
+ self.image_newline = nn.Parameter(
85
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
86
+ )
87
+ else:
88
+ # In case it is frozen by LoRA
89
+ for p in self.mm_projector.parameters():
90
+ p.requires_grad = True
91
+
92
+ if pretrain_mm_mlp_adapter is not None:
93
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
94
+ def get_w(weights, keyword):
95
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
96
+
97
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
98
+
99
+
100
+ def unpad_image(tensor, original_size):
101
+ """
102
+ Unpads a PyTorch tensor of a padded and resized image.
103
+
104
+ Args:
105
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
106
+ original_size (tuple): The original size of the image (height, width).
107
+
108
+ Returns:
109
+ torch.Tensor: The unpadded image tensor.
110
+ """
111
+ original_width, original_height = original_size
112
+ current_height, current_width = tensor.shape[1:]
113
+
114
+ original_aspect_ratio = original_width / original_height
115
+ current_aspect_ratio = current_width / current_height
116
+
117
+ if original_aspect_ratio > current_aspect_ratio:
118
+ scale_factor = current_width / original_width
119
+ new_height = int(original_height * scale_factor)
120
+ padding = (current_height - new_height) // 2
121
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
122
+ else:
123
+ scale_factor = current_height / original_height
124
+ new_width = int(original_width * scale_factor)
125
+ padding = (current_width - new_width) // 2
126
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
127
+
128
+ return unpadded_tensor
129
+
130
+
131
+ class LlavaMetaForCausalLM(ABC):
132
+
133
+ @abstractmethod
134
+ def get_model(self):
135
+ pass
136
+
137
+ def get_vision_tower(self):
138
+ return self.get_model().get_vision_tower()
139
+
140
+ def encode_images(self, images):
141
+ image_features = self.get_model().get_vision_tower()(images)
142
+ image_features = self.get_model().mm_projector(image_features)
143
+ return image_features
144
+
145
+ def prepare_inputs_labels_for_multimodal(
146
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
147
+ images, image_sizes=None
148
+ ):
149
+ vision_tower = self.get_vision_tower()
150
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
151
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
152
+
153
+ if type(images) is list or images.ndim == 5:
154
+ if type(images) is list:
155
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
156
+ concat_images = torch.cat([image for image in images], dim=0)
157
+ image_features = self.encode_images(concat_images)
158
+ split_sizes = [image.shape[0] for image in images]
159
+ image_features = torch.split(image_features, split_sizes, dim=0)
160
+ mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
161
+ image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
162
+ if mm_patch_merge_type == 'flat':
163
+ image_features = [x.flatten(0, 1) for x in image_features]
164
+ elif mm_patch_merge_type.startswith('spatial'):
165
+ new_image_features = []
166
+ for image_idx, image_feature in enumerate(image_features):
167
+ if image_feature.shape[0] > 1:
168
+ base_image_feature = image_feature[0]
169
+ image_feature = image_feature[1:]
170
+ height = width = self.get_vision_tower().num_patches_per_side
171
+ assert height * width == base_image_feature.shape[0]
172
+ if image_aspect_ratio == 'anyres':
173
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
174
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
175
+ else:
176
+ raise NotImplementedError
177
+ if 'unpad' in mm_patch_merge_type:
178
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
179
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
180
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
181
+ image_feature = torch.cat((
182
+ image_feature,
183
+ self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
184
+ ), dim=-1)
185
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
186
+ else:
187
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
188
+ image_feature = image_feature.flatten(0, 3)
189
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
190
+ else:
191
+ image_feature = image_feature[0]
192
+ if 'unpad' in mm_patch_merge_type:
193
+ image_feature = torch.cat((
194
+ image_feature,
195
+ self.model.image_newline[None].to(image_feature.device)
196
+ ), dim=0)
197
+ new_image_features.append(image_feature)
198
+ image_features = new_image_features
199
+ else:
200
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
201
+ else:
202
+ image_features = self.encode_images(images)
203
+
204
+ # TODO: image start / end is not implemented here to support pretraining.
205
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
206
+ raise NotImplementedError
207
+
208
+ # Let's just add dummy tensors if they do not exist,
209
+ # it is a headache to deal with None all the time.
210
+ # But it is not ideal, and if you have a better idea,
211
+ # please open an issue / submit a PR, thanks.
212
+ _labels = labels
213
+ _position_ids = position_ids
214
+ _attention_mask = attention_mask
215
+ if attention_mask is None:
216
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
217
+ else:
218
+ attention_mask = attention_mask.bool()
219
+ if position_ids is None:
220
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
221
+ if labels is None:
222
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
223
+
224
+ # remove the padding using attention_mask -- FIXME
225
+ _input_ids = input_ids
226
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
227
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
228
+
229
+ new_input_embeds = []
230
+ new_labels = []
231
+ cur_image_idx = 0
232
+ for batch_idx, cur_input_ids in enumerate(input_ids):
233
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
234
+ if num_images == 0:
235
+ cur_image_features = image_features[cur_image_idx]
236
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
237
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
238
+ new_input_embeds.append(cur_input_embeds)
239
+ new_labels.append(labels[batch_idx])
240
+ cur_image_idx += 1
241
+ continue
242
+
243
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
244
+ cur_input_ids_noim = []
245
+ cur_labels = labels[batch_idx]
246
+ cur_labels_noim = []
247
+ for i in range(len(image_token_indices) - 1):
248
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
249
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
250
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
251
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
252
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
253
+ cur_new_input_embeds = []
254
+ cur_new_labels = []
255
+
256
+ for i in range(num_images + 1):
257
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
258
+ cur_new_labels.append(cur_labels_noim[i])
259
+ if i < num_images:
260
+ cur_image_features = image_features[cur_image_idx]
261
+ cur_image_idx += 1
262
+ cur_new_input_embeds.append(cur_image_features)
263
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
264
+
265
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
266
+
267
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
268
+ cur_new_labels = torch.cat(cur_new_labels)
269
+
270
+ new_input_embeds.append(cur_new_input_embeds)
271
+ new_labels.append(cur_new_labels)
272
+
273
+ # Truncate sequences to max length as image embeddings can make the sequence longer
274
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
275
+ if tokenizer_model_max_length is not None:
276
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
277
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
278
+
279
+ # Combine them
280
+ max_len = max(x.shape[0] for x in new_input_embeds)
281
+ batch_size = len(new_input_embeds)
282
+
283
+ new_input_embeds_padded = []
284
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
285
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
286
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
287
+
288
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
289
+ cur_len = cur_new_embed.shape[0]
290
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
291
+ new_input_embeds_padded.append(torch.cat((
292
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
293
+ cur_new_embed
294
+ ), dim=0))
295
+ if cur_len > 0:
296
+ new_labels_padded[i, -cur_len:] = cur_new_labels
297
+ attention_mask[i, -cur_len:] = True
298
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
299
+ else:
300
+ new_input_embeds_padded.append(torch.cat((
301
+ cur_new_embed,
302
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
303
+ ), dim=0))
304
+ if cur_len > 0:
305
+ new_labels_padded[i, :cur_len] = cur_new_labels
306
+ attention_mask[i, :cur_len] = True
307
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
308
+
309
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
310
+
311
+ if _labels is None:
312
+ new_labels = None
313
+ else:
314
+ new_labels = new_labels_padded
315
+
316
+ if _attention_mask is None:
317
+ attention_mask = None
318
+ else:
319
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
320
+
321
+ if _position_ids is None:
322
+ position_ids = None
323
+
324
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
325
+
326
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
327
+ if model_args.mm_use_im_patch_token:
328
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
329
+ self.resize_token_embeddings(len(tokenizer))
330
+
331
+ if model_args.mm_use_im_start_end:
332
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
333
+ self.resize_token_embeddings(len(tokenizer))
334
+
335
+ if num_new_tokens > 0:
336
+ input_embeddings = self.get_input_embeddings().weight.data
337
+ output_embeddings = self.get_output_embeddings().weight.data
338
+
339
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
340
+ dim=0, keepdim=True)
341
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
342
+ dim=0, keepdim=True)
343
+
344
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
345
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
346
+
347
+ if model_args.tune_mm_mlp_adapter:
348
+ for p in self.get_input_embeddings().parameters():
349
+ p.requires_grad = True
350
+ for p in self.get_output_embeddings().parameters():
351
+ p.requires_grad = False
352
+
353
+ if model_args.pretrain_mm_mlp_adapter:
354
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
355
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
356
+ assert num_new_tokens == 2
357
+ if input_embeddings.shape == embed_tokens_weight.shape:
358
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
359
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
360
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
361
+ else:
362
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
363
+ elif model_args.mm_use_im_patch_token:
364
+ if model_args.tune_mm_mlp_adapter:
365
+ for p in self.get_input_embeddings().parameters():
366
+ p.requires_grad = False
367
+ for p in self.get_output_embeddings().parameters():
368
+ p.requires_grad = False
dialoggen/llava/model/make_delta.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from llava.model.utils import auto_upgrade
11
+
12
+
13
+ def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading target model")
19
+ auto_upgrade(target_model_path)
20
+ target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21
+
22
+ print("Calculating delta")
23
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data -= base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31
+ bparam = base.state_dict()[name]
32
+ param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33
+
34
+ print("Saving delta")
35
+ if hub_repo_id:
36
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37
+ else:
38
+ kwargs = {}
39
+ target.save_pretrained(delta_path, **kwargs)
40
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ parser = argparse.ArgumentParser()
46
+ parser.add_argument("--base-model-path", type=str, required=True)
47
+ parser.add_argument("--target-model-path", type=str, required=True)
48
+ parser.add_argument("--delta-path", type=str, required=True)
49
+ parser.add_argument("--hub-repo-id", type=str, default=None)
50
+ args = parser.parse_args()
51
+
52
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
dialoggen/llava/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .clip_encoder import CLIPVisionTower
3
+
4
+
5
+ def build_vision_tower(vision_tower_cfg, **kwargs):
6
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7
+ is_absolute_path_exists = os.path.exists(vision_tower)
8
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
9
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
10
+
11
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
dialoggen/llava/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+
7
+ class CLIPVisionTower(nn.Module):
8
+ def __init__(self, vision_tower, args, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower
14
+ self.select_layer = args.mm_vision_select_layer
15
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16
+
17
+ if not delay_load:
18
+ self.load_model()
19
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
20
+ self.load_model()
21
+ else:
22
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
23
+
24
+ def load_model(self, device_map=None):
25
+ if self.is_loaded:
26
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
27
+ return
28
+
29
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
30
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
31
+ self.vision_tower.requires_grad_(False)
32
+
33
+ self.is_loaded = True
34
+
35
+ def feature_select(self, image_forward_outs):
36
+ image_features = image_forward_outs.hidden_states[self.select_layer]
37
+ if self.select_feature == 'patch':
38
+ image_features = image_features[:, 1:]
39
+ elif self.select_feature == 'cls_patch':
40
+ image_features = image_features
41
+ else:
42
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
43
+ return image_features
44
+
45
+ @torch.no_grad()
46
+ def forward(self, images):
47
+ if type(images) is list:
48
+ image_features = []
49
+ for image in images:
50
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
51
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
52
+ image_features.append(image_feature)
53
+ else:
54
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
55
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
56
+
57
+ return image_features
58
+
59
+ @property
60
+ def dummy_feature(self):
61
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
62
+
63
+ @property
64
+ def dtype(self):
65
+ return self.vision_tower.dtype
66
+
67
+ @property
68
+ def device(self):
69
+ return self.vision_tower.device
70
+
71
+ @property
72
+ def config(self):
73
+ if self.is_loaded:
74
+ return self.vision_tower.config
75
+ else:
76
+ return self.cfg_only
77
+
78
+ @property
79
+ def hidden_size(self):
80
+ return self.config.hidden_size
81
+
82
+ @property
83
+ def num_patches_per_side(self):
84
+ return self.config.image_size // self.config.patch_size
85
+
86
+ @property
87
+ def num_patches(self):
88
+ return (self.config.image_size // self.config.patch_size) ** 2
dialoggen/llava/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+
5
+
6
+ class IdentityMap(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+
10
+ def forward(self, x, *args, **kwargs):
11
+ return x
12
+
13
+ @property
14
+ def config(self):
15
+ return {"mm_projector_type": 'identity'}
16
+
17
+
18
+ class SimpleResBlock(nn.Module):
19
+ def __init__(self, channels):
20
+ super().__init__()
21
+ self.pre_norm = nn.LayerNorm(channels)
22
+
23
+ self.proj = nn.Sequential(
24
+ nn.Linear(channels, channels),
25
+ nn.GELU(),
26
+ nn.Linear(channels, channels)
27
+ )
28
+ def forward(self, x):
29
+ x = self.pre_norm(x)
30
+ return x + self.proj(x)
31
+
32
+
33
+ def build_vision_projector(config, delay_load=False, **kwargs):
34
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
35
+
36
+ if projector_type == 'linear':
37
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
38
+
39
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40
+ if mlp_gelu_match:
41
+ mlp_depth = int(mlp_gelu_match.group(1))
42
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43
+ for _ in range(1, mlp_depth):
44
+ modules.append(nn.GELU())
45
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46
+ return nn.Sequential(*modules)
47
+
48
+ if projector_type == 'identity':
49
+ return IdentityMap()
50
+
51
+ raise ValueError(f'Unknown projector type: {projector_type}')
dialoggen/llava/model/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig
2
+
3
+
4
+ def auto_upgrade(config):
5
+ cfg = AutoConfig.from_pretrained(config)
6
+ if 'llava' in config and 'llava' not in cfg.model_type:
7
+ assert cfg.model_type == 'llama'
8
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11
+ if confirm.lower() in ["y", "yes"]:
12
+ print("Upgrading checkpoint...")
13
+ assert len(cfg.architectures) == 1
14
+ setattr(cfg.__class__, "model_type", "llava")
15
+ cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16
+ cfg.save_pretrained(config)
17
+ print("Checkpoint upgraded.")
18
+ else:
19
+ print("Checkpoint upgrade aborted.")
20
+ exit(1)
dialoggen/llava/utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+
7
+ import requests
8
+
9
+ from llava.constants import LOGDIR
10
+
11
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13
+
14
+ handler = None
15
+
16
+
17
+ def build_logger(logger_name, logger_filename):
18
+ global handler
19
+
20
+ formatter = logging.Formatter(
21
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22
+ datefmt="%Y-%m-%d %H:%M:%S",
23
+ )
24
+
25
+ # Set the format of root handlers
26
+ if not logging.getLogger().handlers:
27
+ logging.basicConfig(level=logging.INFO)
28
+ logging.getLogger().handlers[0].setFormatter(formatter)
29
+
30
+ # Redirect stdout and stderr to loggers
31
+ stdout_logger = logging.getLogger("stdout")
32
+ stdout_logger.setLevel(logging.INFO)
33
+ sl = StreamToLogger(stdout_logger, logging.INFO)
34
+ sys.stdout = sl
35
+
36
+ stderr_logger = logging.getLogger("stderr")
37
+ stderr_logger.setLevel(logging.ERROR)
38
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
39
+ sys.stderr = sl
40
+
41
+ # Get logger
42
+ logger = logging.getLogger(logger_name)
43
+ logger.setLevel(logging.INFO)
44
+
45
+ # Add a file handler for all loggers
46
+ if handler is None:
47
+ os.makedirs(LOGDIR, exist_ok=True)
48
+ filename = os.path.join(LOGDIR, logger_filename)
49
+ handler = logging.handlers.TimedRotatingFileHandler(
50
+ filename, when='D', utc=True, encoding='UTF-8')
51
+ handler.setFormatter(formatter)
52
+
53
+ for name, item in logging.root.manager.loggerDict.items():
54
+ if isinstance(item, logging.Logger):
55
+ item.addHandler(handler)
56
+
57
+ return logger
58
+
59
+
60
+ class StreamToLogger(object):
61
+ """
62
+ Fake file-like stream object that redirects writes to a logger instance.
63
+ """
64
+ def __init__(self, logger, log_level=logging.INFO):
65
+ self.terminal = sys.stdout
66
+ self.logger = logger
67
+ self.log_level = log_level
68
+ self.linebuf = ''
69
+
70
+ def __getattr__(self, attr):
71
+ return getattr(self.terminal, attr)
72
+
73
+ def write(self, buf):
74
+ temp_linebuf = self.linebuf + buf
75
+ self.linebuf = ''
76
+ for line in temp_linebuf.splitlines(True):
77
+ # From the io.TextIOWrapper docs:
78
+ # On output, if newline is None, any '\n' characters written
79
+ # are translated to the system default line separator.
80
+ # By default sys.stdout.write() expects '\n' newlines and then
81
+ # translates them so this is still cross platform.
82
+ if line[-1] == '\n':
83
+ self.logger.log(self.log_level, line.rstrip())
84
+ else:
85
+ self.linebuf += line
86
+
87
+ def flush(self):
88
+ if self.linebuf != '':
89
+ self.logger.log(self.log_level, self.linebuf.rstrip())
90
+ self.linebuf = ''
91
+
92
+
93
+ def disable_torch_init():
94
+ """
95
+ Disable the redundant torch default initialization to accelerate model creation.
96
+ """
97
+ import torch
98
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100
+
101
+
102
+ def violates_moderation(text):
103
+ """
104
+ Check whether the text violates OpenAI moderation API.
105
+ """
106
+ url = "https://api.openai.com/v1/moderations"
107
+ headers = {"Content-Type": "application/json",
108
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109
+ text = text.replace("\n", "")
110
+ data = "{" + '"input": ' + f'"{text}"' + "}"
111
+ data = data.encode("utf-8")
112
+ try:
113
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
114
+ flagged = ret.json()["results"][0]["flagged"]
115
+ except requests.exceptions.RequestException as e:
116
+ flagged = False
117
+ except KeyError as e:
118
+ flagged = False
119
+
120
+ return flagged
121
+
122
+
123
+ def pretty_print_semaphore(semaphore):
124
+ if semaphore is None:
125
+ return "None"
126
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
en.csv ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ key,value
2
+ size,Size
3
+ sampler,Sampler
4
+ prompt,Prompt
5
+ default prompt,"A cute cat"
6
+ negative_prompt,Negative Prompt
7
+ seed,Seed
8
+ cfg,CFG Scale
9
+ infer steps,Sampling Steps
10
+ batch size,Batch Size
11
+ width cond,Width Cond
12
+ height cond,Height Cond
13
+ enhance,Prompt Enhancement
14
+ run,Submit
15
+ square,Square(1024x1024)
16
+ landscape,Landscape(1280x768)
17
+ portrait,Portrait(768x1280)
18
+ accordion,Advanced Options
19
+ generated image,HunYuanDiT Generated Image
20
+ examples,More Examples
21
+ title,Hunyuan-DiT
22
+ desc,A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding
environment.yml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ name: HunyuanDiT
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python=3.8.12
7
+ - pytorch=1.13.1
8
+ - pip
example_prompts.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 一只聪明的狐狸走在阔叶树林里, 旁边是一条小溪, 细节真实, 摄影
2
+ 湖水清澈,天空湛蓝,阳光灿烂。一只优雅的白天鹅在湖边游泳。它周围有几只小鸭子,看起来非常可爱,整个画面给人一种宁静祥和的感觉。
3
+ 太阳微微升起,花园里的玫瑰花瓣上露珠晶莹剔透,一只瓢虫正在爬向露珠,背景是清晨的花园,微距镜头
4
+ 一位女明星,中国人,头发是黑色,衣服是纯白色短袖,人物风格清新,城市背景
5
+ 后印象主义风格,一条古老的石板路上面散落着金黄色的树叶。路旁的风车在静谧地转动,后面竖着两个风车。背景是一片向日葵田,蓝天上飘着几朵白云
6
+ 一幅细致的油画描绘了一只年轻獾轻轻嗅着一朵明亮的黄色玫瑰时错综复杂的皮毛。背景是一棵大树干的粗糙纹理,獾的爪子轻轻地挖进树皮。在柔和的背景中,一个宁静的瀑布倾泻而下,它的水在绿色植物中闪烁着蓝色。
7
+ 渔舟唱晚
8
+ 请将杞人忧天的样子画出来
9
+ 一只长靴猫手持亮银色的宝剑,身着铠甲,眼神坚毅,站在一堆金币上,背景是暗色调的洞穴,图像上有金币的光影点缀。
10
+ 插画风格,一只狐狸和一只刺猬坐在水边的石头上,刺猬手里拿着一杯茶,狐狸旁边放着一个玻璃杯。周围是茂密的绿色植物和树木,阳光透过树叶洒在水面上,画面宁静温馨。
11
+ 泥塑风格,一座五彩斑斓的花园在画面中展现,各种各样的花朵,绿色的叶子和一只正在嬉戏的小猫形成了一幅生动的图像,背景是蓝天和白云
12
+ 枯藤老树昏鸦,小桥流水人家
13
+ 一张细致的照片捕捉到了一尊雕像的形象,这尊雕像酷似一位古代法老,头上出人意料地戴着一副青铜蒸汽朋克护目镜。这座雕像穿着复古时髦,一件清爽的白色T恤和一件合身的黑色皮夹克,与传统的头饰形成鲜明对比。背景是简单的纯色,突出了雕像的非传统服装和蒸汽朋克眼镜的复杂细节。
14
+ 一朵鲜艳的红色玫瑰花,花瓣撒有一些水珠,晶莹剔透,特写镜头,
15
+ 一只可爱的猫, 细节真实, 摄影
16
+ 飞流直下三千尺,疑是银河落九天
17
+ 成语“鲤鱼跃龙门”
18
+ 一颗新鲜的草莓特写,红色的外表,表面布满许多种子,背景是淡绿色的叶子
19
+ 九寨沟
20
+ 摄影风格,在画面中心是一盘热气腾腾的麻婆豆腐,豆腐呈白色,上面撒着一层红色的辣酱,有些许绿色的葱花点缀,背景是深色木质餐桌,桌子上放有辣椒和葱花作为点缀。
21
+ 一位年轻女子站在春季的火车站月台上。她身着蓝灰色长风衣,白色衬衫。她的深棕色头发扎成低马尾,几缕碎发随风飘扬。她的眼神充满期待,阳光洒在她温暖的脸庞上。
22
+ 一只优雅的白鹤在湖边静静地站立,它的身体纯白色,翅膀轻轻展开,背景是湖面和远处的山脉
23
+ 国画风格,苏州园林中的小桥流水,周围是郁郁葱葱的树,池塘里有几朵绽放的荷花,背景是宁静的江南水乡
24
+ 现实主义风格,画面主要描述一个巴洛克风格的花瓶,带有金色的装饰边框,花瓶上盛开着各种色彩鲜艳的花,白色背景
25
+ 醉后不知天在水,满船清梦压星河
26
+ 长城
27
+ 一个亚洲中年男士在夕阳下的公园长椅上静坐。他穿着一件深蓝色的针织毛衣和灰色裤子。他的头发略显花白,手中拿着一本敞开的书。面带微笑,眼神温和,周围是落日余晖和四周的绿树。
28
+ 风格是写实,画面主要描述一个亚洲戏曲艺术家正在表演,她穿着华丽的戏服,脸上戴着精致的面具,身姿优雅,背景是古色古香的舞台,镜头是近景
hydit/__init__.py ADDED
File without changes
hydit/config.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from .constants import *
4
+ from .modules.models import HUNYUAN_DIT_CONFIG
5
+
6
+
7
+ def get_args(default_args=None):
8
+ parser = argparse.ArgumentParser()
9
+
10
+ # Basic
11
+ parser.add_argument("--prompt", type=str, default="一只小猫", help="The prompt for generating images.")
12
+ parser.add_argument("--model-root", type=str, default="ckpts", help="Model root path.")
13
+ parser.add_argument("--image-size", type=int, nargs='+', default=[1024, 1024],
14
+ help='Image size (h, w). If a single value is provided, the image will be treated to '
15
+ '(value, value).')
16
+ parser.add_argument("--infer-mode", type=str, choices=["fa", "torch", "trt"], default="torch",
17
+ help="Inference mode")
18
+
19
+ # HunYuan-DiT
20
+ parser.add_argument("--model", type=str, choices=list(HUNYUAN_DIT_CONFIG.keys()), default='DiT-g/2')
21
+ parser.add_argument("--norm", type=str, default="layer", help="Normalization layer type")
22
+ parser.add_argument("--load-key", type=str, choices=["ema", "module"], default="ema", help="Load model key for HunYuanDiT checkpoint.")
23
+ parser.add_argument('--size-cond', type=int, nargs='+', default=[1024, 1024],
24
+ help="Size condition used in sampling. 2 values are required for height and width. "
25
+ "If a single value is provided, the image will be treated to (value, value).")
26
+ parser.add_argument("--cfg-scale", type=float, default=6.0, help="Guidance scale for classifier-free.")
27
+
28
+ # Prompt enhancement
29
+ parser.add_argument("--enhance", action="store_true", help="Enhance prompt with dialoggen.")
30
+ parser.add_argument("--no-enhance", dest="enhance", action="store_false")
31
+ parser.set_defaults(enhance=True)
32
+
33
+ # Diffusion
34
+ parser.add_argument("--learn-sigma", action="store_true", help="Learn extra channels for sigma.")
35
+ parser.add_argument("--no-learn-sigma", dest="learn_sigma", action="store_false")
36
+ parser.set_defaults(learn_sigma=True)
37
+ parser.add_argument("--predict-type", type=str, choices=list(PREDICT_TYPE), default="v_prediction",
38
+ help="Diffusion predict type")
39
+ parser.add_argument("--noise-schedule", type=str, choices=list(NOISE_SCHEDULES), default="scaled_linear",
40
+ help="Noise schedule")
41
+ parser.add_argument("--beta-start", type=float, default=0.00085, help="Beta start value")
42
+ parser.add_argument("--beta-end", type=float, default=0.03, help="Beta end value")
43
+
44
+ # Text condition
45
+ parser.add_argument("--text-states-dim", type=int, default=1024, help="Hidden size of CLIP text encoder.")
46
+ parser.add_argument("--text-len", type=int, default=77, help="Token length of CLIP text encoder output.")
47
+ parser.add_argument("--text-states-dim-t5", type=int, default=2048, help="Hidden size of CLIP text encoder.")
48
+ parser.add_argument("--text-len-t5", type=int, default=256, help="Token length of T5 text encoder output.")
49
+ parser.add_argument("--negative", type=str, default=None, help="Negative prompt.")
50
+
51
+ # Acceleration
52
+ parser.add_argument("--use_fp16", action="store_true", help="Use FP16 precision.")
53
+ parser.add_argument("--no-fp16", dest="use_fp16", action="store_false")
54
+ parser.set_defaults(use_fp16=True)
55
+
56
+ # Sampling
57
+ parser.add_argument("--batch-size", type=int, default=1, help="Per-GPU batch size")
58
+ parser.add_argument("--sampler", type=str, choices=SAMPLER_FACTORY, default="ddpm", help="Diffusion sampler")
59
+ parser.add_argument("--infer-steps", type=int, default=100, help="Inference steps")
60
+ parser.add_argument('--seed', type=int, default=42, help="A seed for all the prompts.")
61
+
62
+ # App
63
+ parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="Language")
64
+
65
+ args = parser.parse_args(default_args)
66
+
67
+ return args
hydit/constants.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =======================================================
2
+ NOISE_SCHEDULES = {
3
+ "linear",
4
+ "scaled_linear",
5
+ "squaredcos_cap_v2",
6
+ }
7
+
8
+ PREDICT_TYPE = {
9
+ "epsilon",
10
+ "sample",
11
+ "v_prediction",
12
+ }
13
+
14
+ # =======================================================
15
+ NEGATIVE_PROMPT = '错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,'
16
+
17
+
18
+ # =======================================================
19
+ # Constants about models
20
+ # =======================================================
21
+
22
+ SAMPLER_FACTORY = {
23
+ 'ddpm': {
24
+ 'scheduler': 'DDPMScheduler',
25
+ 'name': 'DDPM',
26
+ 'kwargs': {
27
+ 'steps_offset': 1,
28
+ 'clip_sample': False,
29
+ 'clip_sample_range': 1.0,
30
+ 'beta_schedule': 'scaled_linear',
31
+ 'beta_start': 0.00085,
32
+ 'beta_end': 0.03,
33
+ 'prediction_type': 'v_prediction',
34
+ }
35
+ },
36
+ 'ddim': {
37
+ 'scheduler': 'DDIMScheduler',
38
+ 'name': 'DDIM',
39
+ 'kwargs': {
40
+ 'steps_offset': 1,
41
+ 'clip_sample': False,
42
+ 'clip_sample_range': 1.0,
43
+ 'beta_schedule': 'scaled_linear',
44
+ 'beta_start': 0.00085,
45
+ 'beta_end': 0.03,
46
+ 'prediction_type': 'v_prediction',
47
+ }
48
+ },
49
+ 'dpmms': {
50
+ 'scheduler': 'DPMSolverMultistepScheduler',
51
+ 'name': 'DPMMS',
52
+ 'kwargs': {
53
+ 'beta_schedule': 'scaled_linear',
54
+ 'beta_start': 0.00085,
55
+ 'beta_end': 0.03,
56
+ 'prediction_type': 'v_prediction',
57
+ 'trained_betas': None,
58
+ 'solver_order': 2,
59
+ 'algorithm_type': 'dpmsolver++',
60
+ }
61
+ },
62
+ }
hydit/diffusion/__init__.py ADDED
File without changes
hydit/diffusion/pipeline.py ADDED
@@ -0,0 +1,830 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ import inspect
14
+ from typing import Any, Callable, Dict, List, Optional, Union
15
+
16
+ import PIL
17
+ import numpy as np
18
+ import torch
19
+ import torchvision.transforms as T
20
+ from diffusers.configuration_utils import FrozenDict
21
+ from diffusers.image_processor import VaeImageProcessor
22
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
23
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
24
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
27
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
28
+ from diffusers.schedulers import KarrasDiffusionSchedulers
29
+ from diffusers.utils import (
30
+ PIL_INTERPOLATION,
31
+ deprecate,
32
+ logging,
33
+ replace_example_docstring,
34
+ )
35
+ from diffusers.utils.torch_utils import randn_tensor
36
+ from transformers import BertModel, BertTokenizer
37
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
38
+
39
+ from ..modules.models import HunYuanDiT
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+ EXAMPLE_DOC_STRING = """
44
+ Examples:
45
+ ```py
46
+ >>> import requests
47
+ >>> import torch
48
+ >>> from PIL import Image
49
+ >>> from io import BytesIO
50
+
51
+ >>> from diffusers import StableDiffusionImg2ImgPipeline
52
+
53
+ >>> device = "cuda"
54
+ >>> model_id_or_path = "runwayml/stable-diffusion-v1-5"
55
+ >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
56
+ >>> pipe = pipe.to(device)
57
+
58
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
59
+
60
+ >>> response = requests.get(url)
61
+ >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
62
+ >>> init_image = init_image.resize((768, 512))
63
+
64
+ >>> prompt = "A fantasy landscape, trending on artstation"
65
+
66
+ >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
67
+ >>> images[0].save("fantasy_landscape.png")
68
+ ```
69
+ """
70
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
71
+ """
72
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
73
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
74
+ """
75
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
76
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
77
+ # rescale the results from guidance (fixes overexposure)
78
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
79
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
80
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
81
+ return noise_cfg
82
+
83
+ def preprocess(image):
84
+ deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
85
+ deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
86
+ if isinstance(image, torch.Tensor):
87
+ return image
88
+ elif isinstance(image, PIL.Image.Image):
89
+ image = [image]
90
+
91
+ if isinstance(image[0], PIL.Image.Image):
92
+ w, h = image[0].size
93
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
94
+
95
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
96
+ image = np.concatenate(image, axis=0)
97
+ image = np.array(image).astype(np.float32) / 255.0
98
+ image = image.transpose(0, 3, 1, 2)
99
+ image = 2.0 * image - 1.0
100
+ image = torch.from_numpy(image)
101
+ elif isinstance(image[0], torch.Tensor):
102
+ image = torch.cat(image, dim=0)
103
+ return image
104
+
105
+
106
+ class StableDiffusionPipeline(
107
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
108
+ ):
109
+ r"""
110
+ Pipeline for text-guided image-to-image generation using Stable Diffusion.
111
+
112
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
113
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
114
+
115
+ The pipeline also inherits the following loading methods:
116
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
117
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
118
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
119
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
120
+
121
+ Args:
122
+ vae ([`AutoencoderKL`]):
123
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
124
+ text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
125
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
126
+ tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
127
+ A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
128
+ unet (Optional[`HunYuanDiT`, `UNet2DConditionModel`]):
129
+ A `UNet2DConditionModel` to denoise the encoded image latents.
130
+ scheduler ([`SchedulerMixin`]):
131
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
132
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
133
+ safety_checker ([`StableDiffusionSafetyChecker`]):
134
+ Classification module that estimates whether generated images could be considered offensive or harmful.
135
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
136
+ about a model's potential harms.
137
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
138
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
139
+ """
140
+ model_cpu_offload_seq = "text_encoder->unet->vae"
141
+ _optional_components = ["safety_checker", "feature_extractor"]
142
+ _exclude_from_cpu_offload = ["safety_checker"]
143
+
144
+ def __init__(
145
+ self,
146
+ vae: AutoencoderKL,
147
+ text_encoder: Union[BertModel, CLIPTextModel],
148
+ tokenizer: Union[BertTokenizer, CLIPTokenizer],
149
+ unet: Union[HunYuanDiT, UNet2DConditionModel],
150
+ scheduler: KarrasDiffusionSchedulers,
151
+ safety_checker: StableDiffusionSafetyChecker,
152
+ feature_extractor: CLIPImageProcessor,
153
+ requires_safety_checker: bool = True,
154
+ progress_bar_config: Dict[str, Any] = None,
155
+ embedder_t5=None,
156
+ infer_mode='torch',
157
+ ):
158
+ super().__init__()
159
+
160
+ # ========================================================
161
+ self.embedder_t5 = embedder_t5
162
+ self.infer_mode = infer_mode
163
+
164
+ # ========================================================
165
+ if progress_bar_config is None:
166
+ progress_bar_config = {}
167
+ if not hasattr(self, '_progress_bar_config'):
168
+ self._progress_bar_config = {}
169
+ self._progress_bar_config.update(progress_bar_config)
170
+ # ========================================================
171
+
172
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
173
+ deprecation_message = (
174
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
175
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
176
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
177
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
178
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
179
+ " file"
180
+ )
181
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
182
+ new_config = dict(scheduler.config)
183
+ new_config["steps_offset"] = 1
184
+ scheduler._internal_dict = FrozenDict(new_config)
185
+
186
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
187
+ deprecation_message = (
188
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
189
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
190
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
191
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
192
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
193
+ )
194
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
195
+ new_config = dict(scheduler.config)
196
+ new_config["clip_sample"] = False
197
+ scheduler._internal_dict = FrozenDict(new_config)
198
+
199
+ if safety_checker is None and requires_safety_checker:
200
+ logger.warning(
201
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
202
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
203
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
204
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
205
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
206
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
207
+ )
208
+
209
+ if safety_checker is not None and feature_extractor is None:
210
+ raise ValueError(
211
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
212
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
213
+ )
214
+
215
+ self.register_modules(
216
+ vae=vae,
217
+ text_encoder=text_encoder,
218
+ tokenizer=tokenizer,
219
+ unet=unet,
220
+ scheduler=scheduler,
221
+ safety_checker=safety_checker,
222
+ feature_extractor=feature_extractor,
223
+ )
224
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
225
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
226
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
227
+
228
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
229
+ def _encode_prompt(
230
+ self,
231
+ prompt,
232
+ device,
233
+ num_images_per_prompt,
234
+ do_classifier_free_guidance,
235
+ negative_prompt=None,
236
+ prompt_embeds: Optional[torch.FloatTensor] = None,
237
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
238
+ lora_scale: Optional[float] = None,
239
+ ):
240
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
241
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
242
+
243
+ prompt_embeds_tuple = self.encode_prompt(
244
+ prompt=prompt,
245
+ device=device,
246
+ num_images_per_prompt=num_images_per_prompt,
247
+ do_classifier_free_guidance=do_classifier_free_guidance,
248
+ negative_prompt=negative_prompt,
249
+ prompt_embeds=prompt_embeds,
250
+ negative_prompt_embeds=negative_prompt_embeds,
251
+ lora_scale=lora_scale,
252
+ )
253
+
254
+ # concatenate for backwards comp
255
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
256
+
257
+ return prompt_embeds
258
+
259
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
260
+ def encode_prompt(
261
+ self,
262
+ prompt,
263
+ device,
264
+ num_images_per_prompt,
265
+ do_classifier_free_guidance,
266
+ negative_prompt=None,
267
+ prompt_embeds: Optional[torch.FloatTensor] = None,
268
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
269
+ lora_scale: Optional[float] = None,
270
+ embedder=None,
271
+ ):
272
+ r"""
273
+ Encodes the prompt into text encoder hidden states.
274
+
275
+ Args:
276
+ prompt (`str` or `List[str]`, *optional*):
277
+ prompt to be encoded
278
+ device: (`torch.device`):
279
+ torch device
280
+ num_images_per_prompt (`int`):
281
+ number of images that should be generated per prompt
282
+ do_classifier_free_guidance (`bool`):
283
+ whether to use classifier free guidance or not
284
+ negative_prompt (`str` or `List[str]`, *optional*):
285
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
286
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
287
+ less than `1`).
288
+ prompt_embeds (`torch.FloatTensor`, *optional*):
289
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
290
+ provided, text embeddings will be generated from `prompt` input argument.
291
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
292
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
293
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
294
+ argument.
295
+ lora_scale (`float`, *optional*):
296
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
297
+ embedder:
298
+ T5 embedder (including text encoder and tokenizer)
299
+ """
300
+ if embedder is None:
301
+ text_encoder = self.text_encoder
302
+ tokenizer = self.tokenizer
303
+ max_length = self.tokenizer.model_max_length
304
+ else:
305
+ text_encoder = embedder.model
306
+ tokenizer = embedder.tokenizer
307
+ max_length = embedder.max_length
308
+
309
+ # set lora scale so that monkey patched LoRA
310
+ # function of text encoder can correctly access it
311
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
312
+ self._lora_scale = lora_scale
313
+
314
+ # dynamically adjust the LoRA scale
315
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
316
+
317
+ if prompt is not None and isinstance(prompt, str):
318
+ batch_size = 1
319
+ elif prompt is not None and isinstance(prompt, list):
320
+ batch_size = len(prompt)
321
+ else:
322
+ batch_size = prompt_embeds.shape[0]
323
+
324
+ if prompt_embeds is None:
325
+ # textual inversion: procecss multi-vector tokens if necessary
326
+ if isinstance(self, TextualInversionLoaderMixin):
327
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
328
+
329
+ text_inputs = tokenizer(
330
+ prompt,
331
+ padding="max_length",
332
+ max_length=max_length,
333
+ truncation=True,
334
+ return_attention_mask=True,
335
+ return_tensors="pt",
336
+ )
337
+ text_input_ids = text_inputs.input_ids
338
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
339
+
340
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
341
+ text_input_ids, untruncated_ids
342
+ ):
343
+ removed_text = tokenizer.batch_decode(
344
+ untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
345
+ )
346
+ logger.warning(
347
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
348
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
349
+ )
350
+
351
+ attention_mask = text_inputs.attention_mask.to(device)
352
+ prompt_embeds = text_encoder(
353
+ text_input_ids.to(device),
354
+ attention_mask=attention_mask,
355
+ )
356
+ prompt_embeds = prompt_embeds[0]
357
+ attention_mask = attention_mask.repeat(num_images_per_prompt, 1)
358
+ else:
359
+ attention_mask = None
360
+
361
+ if text_encoder is not None:
362
+ prompt_embeds_dtype = text_encoder.dtype
363
+ elif self.unet is not None:
364
+ prompt_embeds_dtype = self.unet.dtype
365
+ else:
366
+ prompt_embeds_dtype = prompt_embeds.dtype
367
+
368
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
369
+
370
+ bs_embed, seq_len, _ = prompt_embeds.shape
371
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
372
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
373
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
374
+
375
+ # get unconditional embeddings for classifier free guidance
376
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
377
+ uncond_tokens: List[str]
378
+ if negative_prompt is None:
379
+ uncond_tokens = [""] * batch_size
380
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
381
+ raise TypeError(
382
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
383
+ f" {type(prompt)}."
384
+ )
385
+ elif isinstance(negative_prompt, str):
386
+ uncond_tokens = [negative_prompt]
387
+ elif batch_size != len(negative_prompt):
388
+ raise ValueError(
389
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
390
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
391
+ " the batch size of `prompt`."
392
+ )
393
+ else:
394
+ uncond_tokens = negative_prompt
395
+
396
+ # textual inversion: procecss multi-vector tokens if necessary
397
+ if isinstance(self, TextualInversionLoaderMixin):
398
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
399
+
400
+ max_length = prompt_embeds.shape[1]
401
+ uncond_input = tokenizer(
402
+ uncond_tokens,
403
+ padding="max_length",
404
+ max_length=max_length,
405
+ truncation=True,
406
+ return_tensors="pt",
407
+ )
408
+
409
+ uncond_attention_mask = uncond_input.attention_mask.to(device)
410
+ negative_prompt_embeds = text_encoder(
411
+ uncond_input.input_ids.to(device),
412
+ attention_mask=uncond_attention_mask,
413
+ )
414
+ negative_prompt_embeds = negative_prompt_embeds[0]
415
+ uncond_attention_mask = uncond_attention_mask.repeat(num_images_per_prompt, 1)
416
+ else:
417
+ uncond_attention_mask = None
418
+
419
+ if do_classifier_free_guidance:
420
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
421
+ seq_len = negative_prompt_embeds.shape[1]
422
+
423
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
424
+
425
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
426
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
427
+
428
+ return prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask
429
+
430
+ def _convert_to_rgb(self, image):
431
+ return image.convert('RGB')
432
+
433
+ def image_transform(self, image_size=224):
434
+ transform = T.Compose([
435
+ T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC),
436
+ self._convert_to_rgb,
437
+ T.ToTensor(),
438
+ T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
439
+ ])
440
+ return transform
441
+
442
+ def encode_img(self, img, device, do_classifier_free_guidance):
443
+ # print('len', len(img))
444
+ # print('img', img.size)
445
+ img = img[0] # TODO: support batch processing
446
+ image_preprocess = self.image_transform(224)
447
+ img_for_clip = image_preprocess(img)
448
+ # print('img_for_clip', img_for_clip.shape)
449
+ img_for_clip = img_for_clip.unsqueeze(0)
450
+ img_clip_embedding = self.img_encoder(img_for_clip.to(device)).to(dtype=torch.float16)
451
+ # print('img_clip_embedding_1_type', img_clip_embedding.dtype)
452
+ if do_classifier_free_guidance:
453
+ negative_img_clip_embedding = torch.zeros_like(img_clip_embedding)
454
+ return img_clip_embedding, negative_img_clip_embedding
455
+
456
+
457
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
458
+ def run_safety_checker(self, image, device, dtype):
459
+ if self.safety_checker is None:
460
+ has_nsfw_concept = None
461
+ else:
462
+ if torch.is_tensor(image):
463
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
464
+ else:
465
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
466
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
467
+ image, has_nsfw_concept = self.safety_checker(
468
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
469
+ )
470
+ return image, has_nsfw_concept
471
+
472
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
473
+ def decode_latents(self, latents):
474
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
475
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
476
+
477
+ latents = 1 / self.vae.config.scaling_factor * latents
478
+ image = self.vae.decode(latents, return_dict=False)[0]
479
+ image = (image / 2 + 0.5).clamp(0, 1)
480
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
481
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
482
+ return image
483
+
484
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
485
+ def prepare_extra_step_kwargs(self, generator, eta):
486
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
487
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
488
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
489
+ # and should be between [0, 1]
490
+
491
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
492
+ extra_step_kwargs = {}
493
+ if accepts_eta:
494
+ extra_step_kwargs["eta"] = eta
495
+
496
+ # check if the scheduler accepts generator
497
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
498
+ if accepts_generator:
499
+ extra_step_kwargs["generator"] = generator
500
+ return extra_step_kwargs
501
+
502
+ def check_inputs(
503
+ self,
504
+ prompt,
505
+ height,
506
+ width,
507
+ callback_steps,
508
+ negative_prompt=None,
509
+ prompt_embeds=None,
510
+ negative_prompt_embeds=None,
511
+ ):
512
+ if height % 8 != 0 or width % 8 != 0:
513
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
514
+
515
+ if (callback_steps is None) or (
516
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
517
+ ):
518
+ raise ValueError(
519
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
520
+ f" {type(callback_steps)}."
521
+ )
522
+
523
+ if prompt is not None and prompt_embeds is not None:
524
+ raise ValueError(
525
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
526
+ " only forward one of the two."
527
+ )
528
+ elif prompt is None and prompt_embeds is None:
529
+ raise ValueError(
530
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
531
+ )
532
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
533
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
534
+
535
+ if negative_prompt is not None and negative_prompt_embeds is not None:
536
+ raise ValueError(
537
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
538
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
539
+ )
540
+
541
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
542
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
543
+ raise ValueError(
544
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
545
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
546
+ f" {negative_prompt_embeds.shape}."
547
+ )
548
+
549
+ def get_timesteps(self, num_inference_steps, strength, device):
550
+ # get the original timestep using init_timestep
551
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
552
+
553
+ t_start = max(num_inference_steps - init_timestep, 0)
554
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
555
+
556
+ return timesteps, num_inference_steps - t_start
557
+
558
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
559
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
560
+ if isinstance(generator, list) and len(generator) != batch_size:
561
+ raise ValueError(
562
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
563
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
564
+ )
565
+
566
+ if latents is None:
567
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
568
+ else:
569
+ latents = latents.to(device)
570
+
571
+ # scale the initial noise by the standard deviation required by the scheduler
572
+ latents = latents * self.scheduler.init_noise_sigma
573
+ return latents
574
+
575
+ @torch.no_grad()
576
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
577
+ def __call__(
578
+ self,
579
+ height: int,
580
+ width: int,
581
+ prompt: Union[str, List[str]] = None,
582
+ num_inference_steps: Optional[int] = 50,
583
+ guidance_scale: Optional[float] = 7.5,
584
+ negative_prompt: Optional[Union[str, List[str]]] = None,
585
+ num_images_per_prompt: Optional[int] = 1,
586
+ eta: Optional[float] = 0.0,
587
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
588
+ latents: Optional[torch.FloatTensor] = None,
589
+ prompt_embeds: Optional[torch.FloatTensor] = None,
590
+ prompt_embeds_t5: Optional[torch.FloatTensor] = None,
591
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
592
+ negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
593
+ output_type: Optional[str] = "pil",
594
+ return_dict: bool = True,
595
+ callback: Optional[Callable[[int, int, torch.FloatTensor, torch.FloatTensor], None]] = None,
596
+ callback_steps: int = 1,
597
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
598
+ guidance_rescale: float = 0.0,
599
+ image_meta_size: Optional[torch.LongTensor] = None,
600
+ style: Optional[torch.LongTensor] = None,
601
+ progress: bool = True,
602
+ use_fp16: bool = False,
603
+ freqs_cis_img: Optional[tuple] = None,
604
+ learn_sigma: bool = True,
605
+ ):
606
+ r"""
607
+ The call function to the pipeline for generation.
608
+
609
+ Args:
610
+ height (`int`):
611
+ The height in pixels of the generated image.
612
+ width (`int`):
613
+ The width in pixels of the generated image.
614
+ prompt (`str` or `List[str]`, *optional*):
615
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
616
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
617
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
618
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
619
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
620
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
621
+ latents as `image`, but if passing latents directly it is not encoded again.
622
+ strength (`float`, *optional*, defaults to 1.0):
623
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
624
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
625
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
626
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
627
+ essentially ignores `image`.
628
+ num_inference_steps (`int`, *optional*, defaults to 50):
629
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
630
+ expense of slower inference. This parameter is modulated by `strength`.
631
+ guidance_scale (`float`, *optional*, defaults to 7.5):
632
+ A higher guidance scale value encourages the model to generate images closely linked to the text
633
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
634
+ negative_prompt (`str` or `List[str]`, *optional*):
635
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
636
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
637
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
638
+ The number of images to generate per prompt.
639
+ eta (`float`, *optional*, defaults to 0.0):
640
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
641
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
642
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
643
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
644
+ generation deterministic.
645
+ prompt_embeds (`torch.FloatTensor`, *optional*):
646
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
647
+ provided, text embeddings are generated from the `prompt` input argument.
648
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
649
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
650
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
651
+ output_type (`str`, *optional*, defaults to `"pil"`):
652
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
653
+ return_dict (`bool`, *optional*, defaults to `True`):
654
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
655
+ plain tuple.
656
+ callback (`Callable`, *optional*):
657
+ A function that calls every `callback_steps` steps during inference. The function is called with the
658
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor,
659
+ pred_x0: torch.FloatTensor)`.
660
+ callback_steps (`int`, *optional*, defaults to 1):
661
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
662
+ every step.
663
+ cross_attention_kwargs (`dict`, *optional*):
664
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
665
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
666
+
667
+ Examples:
668
+
669
+ Returns:
670
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
671
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
672
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
673
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
674
+ "not-safe-for-work" (nsfw) content.
675
+ """
676
+ # 1. Check inputs. Raise error if not correct
677
+ self.check_inputs(
678
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
679
+ )
680
+
681
+ # 2. Define call parameters
682
+ if prompt is not None and isinstance(prompt, str):
683
+ batch_size = 1
684
+ elif prompt is not None and isinstance(prompt, list):
685
+ batch_size = len(prompt)
686
+ else:
687
+ batch_size = prompt_embeds.shape[0]
688
+
689
+ device = self._execution_device
690
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
691
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
692
+ # corresponds to doing no classifier free guidance.
693
+ do_classifier_free_guidance = guidance_scale > 1.0
694
+
695
+ # 3. Encode input prompt
696
+ text_encoder_lora_scale = (
697
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
698
+ )
699
+
700
+ prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask = \
701
+ self.encode_prompt(prompt,
702
+ device,
703
+ num_images_per_prompt,
704
+ do_classifier_free_guidance,
705
+ negative_prompt,
706
+ prompt_embeds=prompt_embeds,
707
+ negative_prompt_embeds=negative_prompt_embeds,
708
+ lora_scale=text_encoder_lora_scale,
709
+ )
710
+ prompt_embeds_t5, negative_prompt_embeds_t5, attention_mask_t5, uncond_attention_mask_t5 = \
711
+ self.encode_prompt(prompt,
712
+ device,
713
+ num_images_per_prompt,
714
+ do_classifier_free_guidance,
715
+ negative_prompt,
716
+ prompt_embeds=prompt_embeds_t5,
717
+ negative_prompt_embeds=negative_prompt_embeds_t5,
718
+ lora_scale=text_encoder_lora_scale,
719
+ embedder=self.embedder_t5,
720
+ )
721
+
722
+ # For classifier free guidance, we need to do two forward passes.
723
+ # Here we concatenate the unconditional and text embeddings into a single batch
724
+ # to avoid doing two forward passes
725
+ if do_classifier_free_guidance:
726
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
727
+ attention_mask = torch.cat([uncond_attention_mask, attention_mask])
728
+ prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5])
729
+ attention_mask_t5 = torch.cat([uncond_attention_mask_t5, attention_mask_t5])
730
+
731
+ # 4. Prepare timesteps
732
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
733
+ timesteps = self.scheduler.timesteps
734
+
735
+ # 6. Prepare latent variables
736
+ num_channels_latents = self.unet.config.in_channels
737
+ latents = self.prepare_latents(batch_size * num_images_per_prompt,
738
+ num_channels_latents,
739
+ height,
740
+ width,
741
+ prompt_embeds.dtype,
742
+ device,
743
+ generator,
744
+ latents,
745
+ )
746
+
747
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
748
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
749
+
750
+ # 8. Denoising loop
751
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
752
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
753
+ for i, t in enumerate(timesteps):
754
+ # expand the latents if we are doing classifier free guidance
755
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
756
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
757
+ # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
758
+ t_expand = torch.tensor([t] * latent_model_input.shape[0], device=latent_model_input.device)
759
+
760
+ if use_fp16:
761
+ latent_model_input = latent_model_input.half()
762
+ t_expand = t_expand.half()
763
+ prompt_embeds = prompt_embeds.half()
764
+ ims = image_meta_size.half() if image_meta_size is not None else None
765
+ else:
766
+ ims = image_meta_size if image_meta_size is not None else None
767
+
768
+ # predict the noise residual
769
+ if self.infer_mode in ["fa", "torch"]:
770
+ noise_pred = self.unet(
771
+ latent_model_input,
772
+ t_expand,
773
+ encoder_hidden_states=prompt_embeds,
774
+ text_embedding_mask=attention_mask,
775
+ encoder_hidden_states_t5=prompt_embeds_t5,
776
+ text_embedding_mask_t5=attention_mask_t5,
777
+ image_meta_size=ims,
778
+ style=style,
779
+ cos_cis_img=freqs_cis_img[0],
780
+ sin_cis_img=freqs_cis_img[1],
781
+ return_dict=False,
782
+ )
783
+ elif self.infer_mode == "trt":
784
+ raise NotImplementedError("TensorRT model is not supported yet.")
785
+ else:
786
+ raise ValueError("[ERROR] invalid inference mode! please check your config file")
787
+ if learn_sigma:
788
+ noise_pred, _ = noise_pred.chunk(2, dim=1)
789
+
790
+ # perform guidance
791
+ if do_classifier_free_guidance:
792
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
793
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
794
+
795
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
796
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
797
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
798
+
799
+ # compute the previous noisy sample x_t -> x_t-1
800
+ results = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
801
+ latents = results.prev_sample
802
+ pred_x0 = results.pred_original_sample if hasattr(results, 'pred_original_sample') else None
803
+
804
+ # call the callback, if provided
805
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
806
+ progress_bar.update()
807
+ if callback is not None and i % callback_steps == 0:
808
+ callback(i, t, latents, pred_x0)
809
+
810
+ if not output_type == "latent":
811
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
812
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
813
+ else:
814
+ image = latents
815
+ has_nsfw_concept = None
816
+
817
+ if has_nsfw_concept is None:
818
+ do_denormalize = [True] * image.shape[0]
819
+ else:
820
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
821
+
822
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
823
+
824
+ # Offload all models
825
+ self.maybe_free_model_hooks()
826
+
827
+ if not return_dict:
828
+ return (image, has_nsfw_concept)
829
+
830
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
hydit/inference.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import time
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ # For reproducibility
9
+ # torch.backends.cudnn.benchmark = False
10
+ # torch.backends.cudnn.deterministic = True
11
+
12
+ from diffusers import schedulers
13
+ from diffusers.models import AutoencoderKL
14
+ from loguru import logger
15
+ from transformers import BertModel, BertTokenizer
16
+ from transformers.modeling_utils import logger as tf_logger
17
+
18
+ from .constants import SAMPLER_FACTORY, NEGATIVE_PROMPT
19
+ from .diffusion.pipeline import StableDiffusionPipeline
20
+ from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
21
+ from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
22
+ from .modules.text_encoder import MT5Embedder
23
+ from .utils.tools import set_seeds
24
+
25
+
26
+ class Resolution:
27
+ def __init__(self, width, height):
28
+ self.width = width
29
+ self.height = height
30
+
31
+ def __str__(self):
32
+ return f'{self.height}x{self.width}'
33
+
34
+
35
+ class ResolutionGroup:
36
+ def __init__(self):
37
+ self.data = [
38
+ Resolution(768, 768), # 1:1
39
+ Resolution(1024, 1024), # 1:1
40
+ Resolution(1280, 1280), # 1:1
41
+ Resolution(1024, 768), # 4:3
42
+ Resolution(1152, 864), # 4:3
43
+ Resolution(1280, 960), # 4:3
44
+ Resolution(768, 1024), # 3:4
45
+ Resolution(864, 1152), # 3:4
46
+ Resolution(960, 1280), # 3:4
47
+ Resolution(1280, 768), # 16:9
48
+ Resolution(768, 1280), # 9:16
49
+ ]
50
+ self.supported_sizes = set([(r.width, r.height) for r in self.data])
51
+
52
+ def is_valid(self, width, height):
53
+ return (width, height) in self.supported_sizes
54
+
55
+
56
+ STANDARD_RATIO = np.array([
57
+ 1.0, # 1:1
58
+ 4.0 / 3.0, # 4:3
59
+ 3.0 / 4.0, # 3:4
60
+ 16.0 / 9.0, # 16:9
61
+ 9.0 / 16.0, # 9:16
62
+ ])
63
+ STANDARD_SHAPE = [
64
+ [(768, 768), (1024, 1024), (1280, 1280)], # 1:1
65
+ [(1024, 768), (1152, 864), (1280, 960)], # 4:3
66
+ [(768, 1024), (864, 1152), (960, 1280)], # 3:4
67
+ [(1280, 768)], # 16:9
68
+ [(768, 1280)], # 9:16
69
+ ]
70
+ STANDARD_AREA = [
71
+ np.array([w * h for w, h in shapes])
72
+ for shapes in STANDARD_SHAPE
73
+ ]
74
+
75
+
76
+ def get_standard_shape(target_width, target_height):
77
+ """
78
+ Map image size to standard size.
79
+ """
80
+ target_ratio = target_width / target_height
81
+ closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
82
+ closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
83
+ width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
84
+ return width, height
85
+
86
+
87
+ def _to_tuple(val):
88
+ if isinstance(val, (list, tuple)):
89
+ if len(val) == 1:
90
+ val = [val[0], val[0]]
91
+ elif len(val) == 2:
92
+ val = tuple(val)
93
+ else:
94
+ raise ValueError(f"Invalid value: {val}")
95
+ elif isinstance(val, (int, float)):
96
+ val = (val, val)
97
+ else:
98
+ raise ValueError(f"Invalid value: {val}")
99
+ return val
100
+
101
+
102
+ def get_pipeline(args, vae, text_encoder, tokenizer, model, device, rank,
103
+ embedder_t5, infer_mode, sampler=None):
104
+ """
105
+ Get scheduler and pipeline for sampling. The sampler and pipeline are both
106
+ based on diffusers and make some modifications.
107
+
108
+ Returns
109
+ -------
110
+ pipeline: StableDiffusionPipeline
111
+ sampler_name: str
112
+ """
113
+ sampler = sampler or args.sampler
114
+
115
+ # Load sampler from factory
116
+ kwargs = SAMPLER_FACTORY[sampler]['kwargs']
117
+ scheduler = SAMPLER_FACTORY[sampler]['scheduler']
118
+
119
+ # Update sampler according to the arguments
120
+ kwargs['beta_schedule'] = args.noise_schedule
121
+ kwargs['beta_start'] = args.beta_start
122
+ kwargs['beta_end'] = args.beta_end
123
+ kwargs['prediction_type'] = args.predict_type
124
+
125
+ # Build scheduler according to the sampler.
126
+ scheduler_class = getattr(schedulers, scheduler)
127
+ scheduler = scheduler_class(**kwargs)
128
+
129
+ # Set timesteps for inference steps.
130
+ scheduler.set_timesteps(args.infer_steps, device)
131
+
132
+ # Only enable progress bar for rank 0
133
+ progress_bar_config = {} if rank == 0 else {'disable': True}
134
+
135
+ pipeline = StableDiffusionPipeline(vae=vae,
136
+ text_encoder=text_encoder,
137
+ tokenizer=tokenizer,
138
+ unet=model,
139
+ scheduler=scheduler,
140
+ feature_extractor=None,
141
+ safety_checker=None,
142
+ requires_safety_checker=False,
143
+ progress_bar_config=progress_bar_config,
144
+ embedder_t5=embedder_t5,
145
+ infer_mode=infer_mode,
146
+ )
147
+
148
+ pipeline = pipeline.to(device)
149
+
150
+ return pipeline, sampler
151
+
152
+
153
+ class End2End(object):
154
+ def __init__(self, args, models_root_path):
155
+ self.args = args
156
+
157
+ # Check arguments
158
+ t2i_root_path = Path(models_root_path) / "t2i"
159
+ self.root = t2i_root_path
160
+ logger.info(f"Got text-to-image model root path: {t2i_root_path}")
161
+
162
+ # Set device and disable gradient
163
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
164
+ torch.set_grad_enabled(False)
165
+ # Disable BertModel logging checkpoint info
166
+ tf_logger.setLevel('ERROR')
167
+
168
+ # ========================================================================
169
+ model_dir = self.root / "model"
170
+
171
+ # ========================================================================
172
+ logger.info(f"Loading CLIP Text Encoder...")
173
+ text_encoder_path = self.root / "clip_text_encoder"
174
+ self.clip_text_encoder = BertModel.from_pretrained(str(text_encoder_path), False, revision=None).to(self.device)
175
+ logger.info(f"Loading CLIP Text Encoder finished")
176
+
177
+ # ========================================================================
178
+ logger.info(f"Loading CLIP Tokenizer...")
179
+ tokenizer_path = self.root / "tokenizer"
180
+ self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
181
+ logger.info(f"Loading CLIP Tokenizer finished")
182
+
183
+ # ========================================================================
184
+ logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...")
185
+ t5_text_encoder_path = self.root / 'mt5'
186
+ embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
187
+ self.embedder_t5 = embedder_t5
188
+ logger.info(f"Loading t5_text_encoder and t5_tokenizer finished")
189
+
190
+ # ========================================================================
191
+ logger.info(f"Loading VAE...")
192
+ vae_path = self.root / "sdxl-vae-fp16-fix"
193
+ self.vae = AutoencoderKL.from_pretrained(str(vae_path)).to(self.device)
194
+ logger.info(f"Loading VAE finished")
195
+
196
+ # ========================================================================
197
+ # Create model structure and load the checkpoint
198
+ logger.info(f"Building HunYuan-DiT model...")
199
+ model_config = HUNYUAN_DIT_CONFIG[self.args.model]
200
+ self.patch_size = model_config['patch_size']
201
+ self.head_size = model_config['hidden_size'] // model_config['num_heads']
202
+ self.resolutions, self.freqs_cis_img = self.standard_shapes() # Used for TensorRT models
203
+ self.image_size = _to_tuple(self.args.image_size)
204
+ latent_size = (self.image_size[0] // 8, self.image_size[1] // 8)
205
+
206
+ self.infer_mode = self.args.infer_mode
207
+ if self.infer_mode in ['fa', 'torch']:
208
+ model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt"
209
+ if not model_path.exists():
210
+ raise ValueError(f"model_path not exists: {model_path}")
211
+ # Build model structure
212
+ self.model = HunYuanDiT(self.args,
213
+ input_size=latent_size,
214
+ **model_config,
215
+ log_fn=logger.info,
216
+ ).half().to(self.device) # Force to use fp16
217
+ # Load model checkpoint
218
+ logger.info(f"Loading model checkpoint {model_path}...")
219
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
220
+ self.model.load_state_dict(state_dict)
221
+ self.model.eval()
222
+ elif self.infer_mode == 'trt':
223
+ raise NotImplementedError("TensorRT model is not supported yet.")
224
+ else:
225
+ raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
226
+
227
+ # ========================================================================
228
+ # Build inference pipeline. We use a customized StableDiffusionPipeline.
229
+ logger.info(f"Loading inference pipeline...")
230
+ self.pipeline, self.sampler = self.load_sampler()
231
+ logger.info(f'Loading pipeline finished')
232
+
233
+ # ========================================================================
234
+ self.default_negative_prompt = NEGATIVE_PROMPT
235
+ logger.info("==================================================")
236
+ logger.info(f" Model is ready. ")
237
+ logger.info("==================================================")
238
+
239
+ def load_sampler(self, sampler=None):
240
+ pipeline, sampler = get_pipeline(self.args,
241
+ self.vae,
242
+ self.clip_text_encoder,
243
+ self.tokenizer,
244
+ self.model,
245
+ device=self.device,
246
+ rank=0,
247
+ embedder_t5=self.embedder_t5,
248
+ infer_mode=self.infer_mode,
249
+ sampler=sampler,
250
+ )
251
+ return pipeline, sampler
252
+
253
+ def calc_rope(self, height, width):
254
+ th = height // 8 // self.patch_size
255
+ tw = width // 8 // self.patch_size
256
+ base_size = 512 // 8 // self.patch_size
257
+ start, stop = get_fill_resize_and_crop((th, tw), base_size)
258
+ sub_args = [start, stop, (th, tw)]
259
+ rope = get_2d_rotary_pos_embed(self.head_size, *sub_args)
260
+ return rope
261
+
262
+ def standard_shapes(self):
263
+ resolutions = ResolutionGroup()
264
+ freqs_cis_img = {}
265
+ for reso in resolutions.data:
266
+ freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width)
267
+ return resolutions, freqs_cis_img
268
+
269
+ def predict(self,
270
+ user_prompt,
271
+ height=1024,
272
+ width=1024,
273
+ seed=None,
274
+ enhanced_prompt=None,
275
+ negative_prompt=None,
276
+ infer_steps=100,
277
+ guidance_scale=6,
278
+ batch_size=1,
279
+ src_size_cond=(1024, 1024),
280
+ sampler=None,
281
+ ):
282
+ # ========================================================================
283
+ # Arguments: seed
284
+ # ========================================================================
285
+ if seed is None:
286
+ seed = random.randint(0, 1_000_000)
287
+ if not isinstance(seed, int):
288
+ raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
289
+ generator = set_seeds(seed)
290
+
291
+ # ========================================================================
292
+ # Arguments: target_width, target_height
293
+ # ========================================================================
294
+ if width <= 0 or height <= 0:
295
+ raise ValueError(f"`height` and `width` must be positive integers, got height={height}, width={width}")
296
+ logger.info(f"Input (height, width) = ({height}, {width})")
297
+ if self.infer_mode in ['fa', 'torch']:
298
+ # We must force height and width to align to 16 and to be an integer.
299
+ target_height = int((height // 16) * 16)
300
+ target_width = int((width // 16) * 16)
301
+ logger.info(f"Align to 16: (height, width) = ({target_height}, {target_width})")
302
+ elif self.infer_mode == 'trt':
303
+ target_width, target_height = get_standard_shape(width, height)
304
+ logger.info(f"Align to standard shape: (height, width) = ({target_height}, {target_width})")
305
+ else:
306
+ raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
307
+
308
+ # ========================================================================
309
+ # Arguments: prompt, new_prompt, negative_prompt
310
+ # ========================================================================
311
+ if not isinstance(user_prompt, str):
312
+ raise TypeError(f"`user_prompt` must be a string, but got {type(user_prompt)}")
313
+ user_prompt = user_prompt.strip()
314
+ prompt = user_prompt
315
+
316
+ if enhanced_prompt is not None:
317
+ if not isinstance(enhanced_prompt, str):
318
+ raise TypeError(f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}")
319
+ enhanced_prompt = enhanced_prompt.strip()
320
+ prompt = enhanced_prompt
321
+
322
+ # negative prompt
323
+ if negative_prompt is None or negative_prompt == '':
324
+ negative_prompt = self.default_negative_prompt
325
+ if not isinstance(negative_prompt, str):
326
+ raise TypeError(f"`negative_prompt` must be a string, but got {type(negative_prompt)}")
327
+
328
+ # ========================================================================
329
+ # Arguments: style. (A fixed argument. Don't Change it.)
330
+ # ========================================================================
331
+ style = torch.as_tensor([0, 0] * batch_size, device=self.device)
332
+
333
+ # ========================================================================
334
+ # Inner arguments: image_meta_size (Please refer to SDXL.)
335
+ # ========================================================================
336
+ if isinstance(src_size_cond, int):
337
+ src_size_cond = [src_size_cond, src_size_cond]
338
+ if not isinstance(src_size_cond, (list, tuple)):
339
+ raise TypeError(f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}")
340
+ if len(src_size_cond) != 2:
341
+ raise ValueError(f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)}")
342
+ size_cond = list(src_size_cond) + [target_width, target_height, 0, 0]
343
+ image_meta_size = torch.as_tensor([size_cond] * 2 * batch_size, device=self.device)
344
+
345
+ # ========================================================================
346
+ start_time = time.time()
347
+ logger.debug(f"""
348
+ prompt: {user_prompt}
349
+ enhanced prompt: {enhanced_prompt}
350
+ seed: {seed}
351
+ (height, width): {(target_height, target_width)}
352
+ negative_prompt: {negative_prompt}
353
+ batch_size: {batch_size}
354
+ guidance_scale: {guidance_scale}
355
+ infer_steps: {infer_steps}
356
+ image_meta_size: {size_cond}
357
+ """)
358
+ reso = f'{target_height}x{target_width}'
359
+ if reso in self.freqs_cis_img:
360
+ freqs_cis_img = self.freqs_cis_img[reso]
361
+ else:
362
+ freqs_cis_img = self.calc_rope(target_height, target_width)
363
+
364
+ if sampler is not None and sampler != self.sampler:
365
+ self.pipeline, self.sampler = self.load_sampler(sampler)
366
+
367
+ samples = self.pipeline(
368
+ height=target_height,
369
+ width=target_width,
370
+ prompt=prompt,
371
+ negative_prompt=negative_prompt,
372
+ num_images_per_prompt=batch_size,
373
+ guidance_scale=guidance_scale,
374
+ num_inference_steps=infer_steps,
375
+ image_meta_size=image_meta_size,
376
+ style=style,
377
+ return_dict=False,
378
+ generator=generator,
379
+ freqs_cis_img=freqs_cis_img,
380
+ use_fp16=self.args.use_fp16,
381
+ learn_sigma=self.args.learn_sigma,
382
+ )[0]
383
+ gen_time = time.time() - start_time
384
+ logger.debug(f"Success, time: {gen_time}")
385
+
386
+ return {
387
+ 'images': samples,
388
+ 'seed': seed,
389
+ }
hydit/modules/__init__.py ADDED
File without changes
hydit/modules/attn_layers.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Tuple, Union, Optional
4
+
5
+ try:
6
+ import flash_attn
7
+ if hasattr(flash_attn, '__version__') and int(flash_attn.__version__[0]) == 2:
8
+ from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
9
+ from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention
10
+ else:
11
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
12
+ from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention
13
+ except Exception as e:
14
+ print(f'flash_attn import failed: {e}')
15
+
16
+
17
+ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
18
+ """
19
+ Reshape frequency tensor for broadcasting it with another tensor.
20
+
21
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
22
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
23
+
24
+ Args:
25
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
26
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
27
+ head_first (bool): head dimension first (except batch dim) or not.
28
+
29
+ Returns:
30
+ torch.Tensor: Reshaped frequency tensor.
31
+
32
+ Raises:
33
+ AssertionError: If the frequency tensor doesn't match the expected shape.
34
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
35
+ """
36
+ ndim = x.ndim
37
+ assert 0 <= 1 < ndim
38
+
39
+ if isinstance(freqs_cis, tuple):
40
+ # freqs_cis: (cos, sin) in real space
41
+ if head_first:
42
+ assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
43
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
44
+ else:
45
+ assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
46
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
47
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
48
+ else:
49
+ # freqs_cis: values in complex space
50
+ if head_first:
51
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
52
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
53
+ else:
54
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
55
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
56
+ return freqs_cis.view(*shape)
57
+
58
+
59
+ def rotate_half(x):
60
+ x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
61
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
62
+
63
+
64
+ def apply_rotary_emb(
65
+ xq: torch.Tensor,
66
+ xk: Optional[torch.Tensor],
67
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
68
+ head_first: bool = False,
69
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
70
+ """
71
+ Apply rotary embeddings to input tensors using the given frequency tensor.
72
+
73
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
74
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
75
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
76
+ returned as real tensors.
77
+
78
+ Args:
79
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
80
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
81
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
82
+ head_first (bool): head dimension first (except batch dim) or not.
83
+
84
+ Returns:
85
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
86
+
87
+ """
88
+ xk_out = None
89
+ if isinstance(freqs_cis, tuple):
90
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
91
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
92
+ xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
93
+ if xk is not None:
94
+ xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
95
+ else:
96
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
97
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
98
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
99
+ if xk is not None:
100
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
101
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
102
+
103
+ return xq_out, xk_out
104
+
105
+
106
+ class FlashSelfMHAModified(nn.Module):
107
+ """
108
+ Use QK Normalization.
109
+ """
110
+ def __init__(self,
111
+ dim,
112
+ num_heads,
113
+ qkv_bias=True,
114
+ qk_norm=False,
115
+ attn_drop=0.0,
116
+ proj_drop=0.0,
117
+ device=None,
118
+ dtype=None,
119
+ norm_layer=nn.LayerNorm,
120
+ ):
121
+ factory_kwargs = {'device': device, 'dtype': dtype}
122
+ super().__init__()
123
+ self.dim = dim
124
+ self.num_heads = num_heads
125
+ assert self.dim % num_heads == 0, "self.kdim must be divisible by num_heads"
126
+ self.head_dim = self.dim // num_heads
127
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
128
+
129
+ self.Wqkv = nn.Linear(dim, 3 * dim, bias=qkv_bias, **factory_kwargs)
130
+ # TODO: eps should be 1 / 65530 if using fp16
131
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
132
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
133
+ self.inner_attn = FlashSelfAttention(attention_dropout=attn_drop)
134
+ self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
135
+ self.proj_drop = nn.Dropout(proj_drop)
136
+
137
+ def forward(self, x, freqs_cis_img=None):
138
+ """
139
+ Parameters
140
+ ----------
141
+ x: torch.Tensor
142
+ (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
143
+ freqs_cis_img: torch.Tensor
144
+ (batch, hidden_dim // 2), RoPE for image
145
+ """
146
+ b, s, d = x.shape
147
+
148
+ qkv = self.Wqkv(x)
149
+ qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, h, d]
150
+ q, k, v = qkv.unbind(dim=2) # [b, s, h, d]
151
+ q = self.q_norm(q).half() # [b, s, h, d]
152
+ k = self.k_norm(k).half()
153
+
154
+ # Apply RoPE if needed
155
+ if freqs_cis_img is not None:
156
+ qq, kk = apply_rotary_emb(q, k, freqs_cis_img)
157
+ assert qq.shape == q.shape and kk.shape == k.shape, f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
158
+ q, k = qq, kk
159
+
160
+ qkv = torch.stack([q, k, v], dim=2) # [b, s, 3, h, d]
161
+ context = self.inner_attn(qkv)
162
+ out = self.out_proj(context.view(b, s, d))
163
+ out = self.proj_drop(out)
164
+
165
+ out_tuple = (out,)
166
+
167
+ return out_tuple
168
+
169
+
170
+ class FlashCrossMHAModified(nn.Module):
171
+ """
172
+ Use QK Normalization.
173
+ """
174
+ def __init__(self,
175
+ qdim,
176
+ kdim,
177
+ num_heads,
178
+ qkv_bias=True,
179
+ qk_norm=False,
180
+ attn_drop=0.0,
181
+ proj_drop=0.0,
182
+ device=None,
183
+ dtype=None,
184
+ norm_layer=nn.LayerNorm,
185
+ ):
186
+ factory_kwargs = {'device': device, 'dtype': dtype}
187
+ super().__init__()
188
+ self.qdim = qdim
189
+ self.kdim = kdim
190
+ self.num_heads = num_heads
191
+ assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
192
+ self.head_dim = self.qdim // num_heads
193
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
194
+
195
+ self.scale = self.head_dim ** -0.5
196
+
197
+ self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
198
+ self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
199
+
200
+ # TODO: eps should be 1 / 65530 if using fp16
201
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
202
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
203
+
204
+ self.inner_attn = FlashCrossAttention(attention_dropout=attn_drop)
205
+ self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
206
+ self.proj_drop = nn.Dropout(proj_drop)
207
+
208
+ def forward(self, x, y, freqs_cis_img=None):
209
+ """
210
+ Parameters
211
+ ----------
212
+ x: torch.Tensor
213
+ (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim)
214
+ y: torch.Tensor
215
+ (batch, seqlen2, hidden_dim2)
216
+ freqs_cis_img: torch.Tensor
217
+ (batch, hidden_dim // num_heads), RoPE for image
218
+ """
219
+ b, s1, _ = x.shape # [b, s1, D]
220
+ _, s2, _ = y.shape # [b, s2, 1024]
221
+
222
+ q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
223
+ kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d]
224
+ k, v = kv.unbind(dim=2) # [b, s2, h, d]
225
+ q = self.q_norm(q).half() # [b, s1, h, d]
226
+ k = self.k_norm(k).half() # [b, s2, h, d]
227
+
228
+ # Apply RoPE if needed
229
+ if freqs_cis_img is not None:
230
+ qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
231
+ assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
232
+ q = qq # [b, s1, h, d]
233
+ kv = torch.stack([k, v], dim=2) # [b, s1, 2, h, d]
234
+ context = self.inner_attn(q, kv) # [b, s1, h, d]
235
+ context = context.view(b, s1, -1) # [b, s1, D]
236
+
237
+ out = self.out_proj(context)
238
+ out = self.proj_drop(out)
239
+
240
+ out_tuple = (out,)
241
+
242
+ return out_tuple
243
+
244
+
245
+ class CrossAttention(nn.Module):
246
+ """
247
+ Use QK Normalization.
248
+ """
249
+ def __init__(self,
250
+ qdim,
251
+ kdim,
252
+ num_heads,
253
+ qkv_bias=True,
254
+ qk_norm=False,
255
+ attn_drop=0.0,
256
+ proj_drop=0.0,
257
+ device=None,
258
+ dtype=None,
259
+ norm_layer=nn.LayerNorm,
260
+ ):
261
+ factory_kwargs = {'device': device, 'dtype': dtype}
262
+ super().__init__()
263
+ self.qdim = qdim
264
+ self.kdim = kdim
265
+ self.num_heads = num_heads
266
+ assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
267
+ self.head_dim = self.qdim // num_heads
268
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
269
+ self.scale = self.head_dim ** -0.5
270
+
271
+ self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
272
+ self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
273
+
274
+ # TODO: eps should be 1 / 65530 if using fp16
275
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
276
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
277
+ self.attn_drop = nn.Dropout(attn_drop)
278
+ self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
279
+ self.proj_drop = nn.Dropout(proj_drop)
280
+
281
+ def forward(self, x, y, freqs_cis_img=None):
282
+ """
283
+ Parameters
284
+ ----------
285
+ x: torch.Tensor
286
+ (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
287
+ y: torch.Tensor
288
+ (batch, seqlen2, hidden_dim2)
289
+ freqs_cis_img: torch.Tensor
290
+ (batch, hidden_dim // 2), RoPE for image
291
+ """
292
+ b, s1, c = x.shape # [b, s1, D]
293
+ _, s2, c = y.shape # [b, s2, 1024]
294
+
295
+ q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
296
+ kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d]
297
+ k, v = kv.unbind(dim=2) # [b, s, h, d]
298
+ q = self.q_norm(q)
299
+ k = self.k_norm(k)
300
+
301
+ # Apply RoPE if needed
302
+ if freqs_cis_img is not None:
303
+ qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
304
+ assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
305
+ q = qq
306
+
307
+ q = q * self.scale
308
+ q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
309
+ k = k.permute(0, 2, 3, 1).contiguous() # k -> B, L2, H, C - B, H, C, L2
310
+ attn = q @ k # attn -> B, H, L1, L2
311
+ attn = attn.softmax(dim=-1) # attn -> B, H, L1, L2
312
+ attn = self.attn_drop(attn)
313
+ x = attn @ v.transpose(-2, -3) # v -> B, L2, H, C - B, H, L2, C x-> B, H, L1, C
314
+ context = x.transpose(1, 2) # context -> B, H, L1, C - B, L1, H, C
315
+
316
+ context = context.contiguous().view(b, s1, -1)
317
+
318
+ out = self.out_proj(context) # context.reshape - B, L1, -1
319
+ out = self.proj_drop(out)
320
+
321
+ out_tuple = (out,)
322
+
323
+ return out_tuple
324
+
325
+
326
+ class Attention(nn.Module):
327
+ """
328
+ We rename some layer names to align with flash attention
329
+ """
330
+ def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0.,
331
+ norm_layer=nn.LayerNorm,
332
+ ):
333
+ super().__init__()
334
+ self.dim = dim
335
+ self.num_heads = num_heads
336
+ assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
337
+ self.head_dim = self.dim // num_heads
338
+ # This assertion is aligned with flash attention
339
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
340
+ self.scale = self.head_dim ** -0.5
341
+
342
+ # qkv --> Wqkv
343
+ self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
344
+ # TODO: eps should be 1 / 65530 if using fp16
345
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
346
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
347
+ self.attn_drop = nn.Dropout(attn_drop)
348
+ self.out_proj = nn.Linear(dim, dim)
349
+ self.proj_drop = nn.Dropout(proj_drop)
350
+
351
+ def forward(self, x, freqs_cis_img=None):
352
+ B, N, C = x.shape
353
+ qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [3, b, h, s, d]
354
+ q, k, v = qkv.unbind(0) # [b, h, s, d]
355
+ q = self.q_norm(q) # [b, h, s, d]
356
+ k = self.k_norm(k) # [b, h, s, d]
357
+
358
+ # Apply RoPE if needed
359
+ if freqs_cis_img is not None:
360
+ qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
361
+ assert qq.shape == q.shape and kk.shape == k.shape, \
362
+ f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
363
+ q, k = qq, kk
364
+
365
+ q = q * self.scale
366
+ attn = q @ k.transpose(-2, -1) # [b, h, s, d] @ [b, h, d, s]
367
+ attn = attn.softmax(dim=-1) # [b, h, s, s]
368
+ attn = self.attn_drop(attn)
369
+ x = attn @ v # [b, h, s, d]
370
+
371
+ x = x.transpose(1, 2).reshape(B, N, C) # [b, s, h, d]
372
+ x = self.out_proj(x)
373
+ x = self.proj_drop(x)
374
+
375
+ out_tuple = (x,)
376
+
377
+ return out_tuple
hydit/modules/embedders.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import repeat
5
+
6
+ from timm.models.layers import to_2tuple
7
+
8
+
9
+ class PatchEmbed(nn.Module):
10
+ """ 2D Image to Patch Embedding
11
+
12
+ Image to Patch Embedding using Conv2d
13
+
14
+ A convolution based approach to patchifying a 2D image w/ embedding projection.
15
+
16
+ Based on the impl in https://github.com/google-research/vision_transformer
17
+
18
+ Hacked together by / Copyright 2020 Ross Wightman
19
+
20
+ Remove the _assert function in forward function to be compatible with multi-resolution images.
21
+ """
22
+ def __init__(
23
+ self,
24
+ img_size=224,
25
+ patch_size=16,
26
+ in_chans=3,
27
+ embed_dim=768,
28
+ norm_layer=None,
29
+ flatten=True,
30
+ bias=True,
31
+ ):
32
+ super().__init__()
33
+ if isinstance(img_size, int):
34
+ img_size = to_2tuple(img_size)
35
+ elif isinstance(img_size, (tuple, list)) and len(img_size) == 2:
36
+ img_size = tuple(img_size)
37
+ else:
38
+ raise ValueError(f"img_size must be int or tuple/list of length 2. Got {img_size}")
39
+ patch_size = to_2tuple(patch_size)
40
+ self.img_size = img_size
41
+ self.patch_size = patch_size
42
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
43
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
44
+ self.flatten = flatten
45
+
46
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
47
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
48
+
49
+ def update_image_size(self, img_size):
50
+ self.img_size = img_size
51
+ self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1])
52
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
53
+
54
+ def forward(self, x):
55
+ # B, C, H, W = x.shape
56
+ # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
57
+ # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
58
+ x = self.proj(x)
59
+ if self.flatten:
60
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
61
+ x = self.norm(x)
62
+ return x
63
+
64
+
65
+ def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
66
+ """
67
+ Create sinusoidal timestep embeddings.
68
+ :param t: a 1-D Tensor of N indices, one per batch element.
69
+ These may be fractional.
70
+ :param dim: the dimension of the output.
71
+ :param max_period: controls the minimum frequency of the embeddings.
72
+ :return: an (N, D) Tensor of positional embeddings.
73
+ """
74
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
75
+ if not repeat_only:
76
+ half = dim // 2
77
+ freqs = torch.exp(
78
+ -math.log(max_period)
79
+ * torch.arange(start=0, end=half, dtype=torch.float32)
80
+ / half
81
+ ).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线
82
+ args = t[:, None].float() * freqs[None]
83
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
84
+ if dim % 2:
85
+ embedding = torch.cat(
86
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
87
+ )
88
+ else:
89
+ embedding = repeat(t, "b -> b d", d=dim)
90
+ return embedding
91
+
92
+
93
+ class TimestepEmbedder(nn.Module):
94
+ """
95
+ Embeds scalar timesteps into vector representations.
96
+ """
97
+ def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None):
98
+ super().__init__()
99
+ if out_size is None:
100
+ out_size = hidden_size
101
+ self.mlp = nn.Sequential(
102
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
103
+ nn.SiLU(),
104
+ nn.Linear(hidden_size, out_size, bias=True),
105
+ )
106
+ self.frequency_embedding_size = frequency_embedding_size
107
+
108
+ def forward(self, t):
109
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
110
+ t_emb = self.mlp(t_freq)
111
+ return t_emb
hydit/modules/models.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
5
+ from diffusers.models import ModelMixin
6
+ from timm.models.vision_transformer import Mlp
7
+
8
+ from .attn_layers import Attention, FlashCrossMHAModified, FlashSelfMHAModified, CrossAttention
9
+ from .embedders import TimestepEmbedder, PatchEmbed, timestep_embedding
10
+ from .norm_layers import RMSNorm
11
+ from .poolers import AttentionPool
12
+
13
+
14
+ def modulate(x, shift, scale):
15
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
16
+
17
+
18
+ class FP32_Layernorm(nn.LayerNorm):
19
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
20
+ origin_dtype = inputs.dtype
21
+ return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(),
22
+ self.eps).to(origin_dtype)
23
+
24
+
25
+ class FP32_SiLU(nn.SiLU):
26
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
27
+ return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype)
28
+
29
+
30
+ class HunYuanDiTBlock(nn.Module):
31
+ """
32
+ A HunYuanDiT block with `add` conditioning.
33
+ """
34
+ def __init__(self,
35
+ hidden_size,
36
+ c_emb_size,
37
+ num_heads,
38
+ mlp_ratio=4.0,
39
+ text_states_dim=1024,
40
+ use_flash_attn=False,
41
+ qk_norm=False,
42
+ norm_type="layer",
43
+ skip=False,
44
+ ):
45
+ super().__init__()
46
+ self.use_flash_attn = use_flash_attn
47
+ use_ele_affine = True
48
+
49
+ if norm_type == "layer":
50
+ norm_layer = FP32_Layernorm
51
+ elif norm_type == "rms":
52
+ norm_layer = RMSNorm
53
+ else:
54
+ raise ValueError(f"Unknown norm_type: {norm_type}")
55
+
56
+ # ========================= Self-Attention =========================
57
+ self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
58
+ if use_flash_attn:
59
+ self.attn1 = FlashSelfMHAModified(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm)
60
+ else:
61
+ self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm)
62
+
63
+ # ========================= FFN =========================
64
+ self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
65
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
66
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
67
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
68
+
69
+ # ========================= Add =========================
70
+ # Simply use add like SDXL.
71
+ self.default_modulation = nn.Sequential(
72
+ FP32_SiLU(),
73
+ nn.Linear(c_emb_size, hidden_size, bias=True)
74
+ )
75
+
76
+ # ========================= Cross-Attention =========================
77
+ if use_flash_attn:
78
+ self.attn2 = FlashCrossMHAModified(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
79
+ qk_norm=qk_norm)
80
+ else:
81
+ self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
82
+ qk_norm=qk_norm)
83
+ self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
84
+
85
+ # ========================= Skip Connection =========================
86
+ if skip:
87
+ self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6)
88
+ self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
89
+ else:
90
+ self.skip_linear = None
91
+
92
+ def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
93
+ # Long Skip Connection
94
+ if self.skip_linear is not None:
95
+ cat = torch.cat([x, skip], dim=-1)
96
+ cat = self.skip_norm(cat)
97
+ x = self.skip_linear(cat)
98
+
99
+ # Self-Attention
100
+ shift_msa = self.default_modulation(c).unsqueeze(dim=1)
101
+ attn_inputs = (
102
+ self.norm1(x) + shift_msa, freq_cis_img,
103
+ )
104
+ x = x + self.attn1(*attn_inputs)[0]
105
+
106
+ # Cross-Attention
107
+ cross_inputs = (
108
+ self.norm3(x), text_states, freq_cis_img
109
+ )
110
+ x = x + self.attn2(*cross_inputs)[0]
111
+
112
+ # FFN Layer
113
+ mlp_inputs = self.norm2(x)
114
+ x = x + self.mlp(mlp_inputs)
115
+
116
+ return x
117
+
118
+
119
+ class FinalLayer(nn.Module):
120
+ """
121
+ The final layer of HunYuanDiT.
122
+ """
123
+ def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
124
+ super().__init__()
125
+ self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
126
+ self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
127
+ self.adaLN_modulation = nn.Sequential(
128
+ FP32_SiLU(),
129
+ nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
130
+ )
131
+
132
+ def forward(self, x, c):
133
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
134
+ x = modulate(self.norm_final(x), shift, scale)
135
+ x = self.linear(x)
136
+ return x
137
+
138
+
139
+ class HunYuanDiT(ModelMixin, ConfigMixin):
140
+ """
141
+ HunYuanDiT: Diffusion model with a Transformer backbone.
142
+
143
+ Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
144
+
145
+ Parameters
146
+ ----------
147
+ args: argparse.Namespace
148
+ The arguments parsed by argparse.
149
+ input_size: tuple
150
+ The size of the input image.
151
+ patch_size: int
152
+ The size of the patch.
153
+ in_channels: int
154
+ The number of input channels.
155
+ hidden_size: int
156
+ The hidden size of the transformer backbone.
157
+ depth: int
158
+ The number of transformer blocks.
159
+ num_heads: int
160
+ The number of attention heads.
161
+ mlp_ratio: float
162
+ The ratio of the hidden size of the MLP in the transformer block.
163
+ log_fn: callable
164
+ The logging function.
165
+ """
166
+ @register_to_config
167
+ def __init__(
168
+ self, args,
169
+ input_size=(32, 32),
170
+ patch_size=2,
171
+ in_channels=4,
172
+ hidden_size=1152,
173
+ depth=28,
174
+ num_heads=16,
175
+ mlp_ratio=4.0,
176
+ log_fn=print,
177
+ ):
178
+ super().__init__()
179
+ self.args = args
180
+ self.log_fn = log_fn
181
+ self.depth = depth
182
+ self.learn_sigma = args.learn_sigma
183
+ self.in_channels = in_channels
184
+ self.out_channels = in_channels * 2 if args.learn_sigma else in_channels
185
+ self.patch_size = patch_size
186
+ self.num_heads = num_heads
187
+ self.hidden_size = hidden_size
188
+ self.text_states_dim = args.text_states_dim
189
+ self.text_states_dim_t5 = args.text_states_dim_t5
190
+ self.text_len = args.text_len
191
+ self.text_len_t5 = args.text_len_t5
192
+ self.norm = args.norm
193
+
194
+ use_flash_attn = args.infer_mode == 'fa'
195
+ if use_flash_attn:
196
+ log_fn(f" Enable Flash Attention.")
197
+ qk_norm = True # See http://arxiv.org/abs/2302.05442 for details.
198
+
199
+ self.mlp_t5 = nn.Sequential(
200
+ nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True),
201
+ FP32_SiLU(),
202
+ nn.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True),
203
+ )
204
+ # learnable replace
205
+ self.text_embedding_padding = nn.Parameter(
206
+ torch.randn(self.text_len + self.text_len_t5, self.text_states_dim, dtype=torch.float32))
207
+
208
+ # Attention pooling
209
+ self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=1024)
210
+
211
+ # Here we use a default learned embedder layer for future extension.
212
+ self.style_embedder = nn.Embedding(1, hidden_size)
213
+
214
+ # Image size and crop size conditions
215
+ self.extra_in_dim = 256 * 6 + hidden_size
216
+
217
+ # Text embedding for `add`
218
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size)
219
+ self.t_embedder = TimestepEmbedder(hidden_size)
220
+ self.extra_in_dim += 1024
221
+ self.extra_embedder = nn.Sequential(
222
+ nn.Linear(self.extra_in_dim, hidden_size * 4),
223
+ FP32_SiLU(),
224
+ nn.Linear(hidden_size * 4, hidden_size, bias=True),
225
+ )
226
+
227
+ # Image embedding
228
+ num_patches = self.x_embedder.num_patches
229
+ log_fn(f" Number of tokens: {num_patches}")
230
+
231
+ # HUnYuanDiT Blocks
232
+ self.blocks = nn.ModuleList([
233
+ HunYuanDiTBlock(hidden_size=hidden_size,
234
+ c_emb_size=hidden_size,
235
+ num_heads=num_heads,
236
+ mlp_ratio=mlp_ratio,
237
+ text_states_dim=self.text_states_dim,
238
+ use_flash_attn=use_flash_attn,
239
+ qk_norm=qk_norm,
240
+ norm_type=self.norm,
241
+ skip=layer > depth // 2,
242
+ )
243
+ for layer in range(depth)
244
+ ])
245
+
246
+ self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels)
247
+ self.unpatchify_channels = self.out_channels
248
+
249
+ self.initialize_weights()
250
+
251
+ def forward(self,
252
+ x,
253
+ t,
254
+ encoder_hidden_states=None,
255
+ text_embedding_mask=None,
256
+ encoder_hidden_states_t5=None,
257
+ text_embedding_mask_t5=None,
258
+ image_meta_size=None,
259
+ style=None,
260
+ cos_cis_img=None,
261
+ sin_cis_img=None,
262
+ return_dict=True,
263
+ ):
264
+ """
265
+ Forward pass of the encoder.
266
+
267
+ Parameters
268
+ ----------
269
+ x: torch.Tensor
270
+ (B, D, H, W)
271
+ t: torch.Tensor
272
+ (B)
273
+ encoder_hidden_states: torch.Tensor
274
+ CLIP text embedding, (B, L_clip, D)
275
+ text_embedding_mask: torch.Tensor
276
+ CLIP text embedding mask, (B, L_clip)
277
+ encoder_hidden_states_t5: torch.Tensor
278
+ T5 text embedding, (B, L_t5, D)
279
+ text_embedding_mask_t5: torch.Tensor
280
+ T5 text embedding mask, (B, L_t5)
281
+ image_meta_size: torch.Tensor
282
+ (B, 6)
283
+ style: torch.Tensor
284
+ (B)
285
+ cos_cis_img: torch.Tensor
286
+ sin_cis_img: torch.Tensor
287
+ return_dict: bool
288
+ Whether to return a dictionary.
289
+ """
290
+
291
+ text_states = encoder_hidden_states # 2,77,1024
292
+ text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
293
+ text_states_mask = text_embedding_mask.bool() # 2,77
294
+ text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
295
+ b_t5, l_t5, c_t5 = text_states_t5.shape
296
+ text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5))
297
+ text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1) # 2,205,1024
298
+ clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
299
+
300
+ clip_t5_mask = clip_t5_mask
301
+ text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states))
302
+
303
+ _, _, oh, ow = x.shape
304
+ th, tw = oh // self.patch_size, ow // self.patch_size
305
+
306
+ # ========================= Build time and image embedding =========================
307
+ t = self.t_embedder(t)
308
+ x = self.x_embedder(x)
309
+
310
+ # Get image RoPE embedding according to `reso`lution.
311
+ freqs_cis_img = (cos_cis_img, sin_cis_img)
312
+
313
+ # ========================= Concatenate all extra vectors =========================
314
+ # Build text tokens with pooling
315
+ extra_vec = self.pooler(encoder_hidden_states_t5)
316
+
317
+ # Build image meta size tokens
318
+ image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256]
319
+ if self.args.use_fp16:
320
+ image_meta_size = image_meta_size.half()
321
+ image_meta_size = image_meta_size.view(-1, 6 * 256)
322
+ extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
323
+
324
+ # Build style tokens
325
+ style_embedding = self.style_embedder(style)
326
+ extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
327
+
328
+ # Concatenate all extra vectors
329
+ c = t + self.extra_embedder(extra_vec) # [B, D]
330
+
331
+ # ========================= Forward pass through HunYuanDiT blocks =========================
332
+ skips = []
333
+ for layer, block in enumerate(self.blocks):
334
+ if layer > self.depth // 2:
335
+ skip = skips.pop()
336
+ x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
337
+ else:
338
+ x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
339
+
340
+ if layer < (self.depth // 2 - 1):
341
+ skips.append(x)
342
+
343
+ # ========================= Final layer =========================
344
+ x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels)
345
+ x = self.unpatchify(x, th, tw) # (N, out_channels, H, W)
346
+
347
+ if return_dict:
348
+ return {'x': x}
349
+ return x
350
+
351
+ def initialize_weights(self):
352
+ # Initialize transformer layers:
353
+ def _basic_init(module):
354
+ if isinstance(module, nn.Linear):
355
+ torch.nn.init.xavier_uniform_(module.weight)
356
+ if module.bias is not None:
357
+ nn.init.constant_(module.bias, 0)
358
+ self.apply(_basic_init)
359
+
360
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
361
+ w = self.x_embedder.proj.weight.data
362
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
363
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
364
+
365
+ # Initialize label embedding table:
366
+ nn.init.normal_(self.extra_embedder[0].weight, std=0.02)
367
+ nn.init.normal_(self.extra_embedder[2].weight, std=0.02)
368
+
369
+ # Initialize timestep embedding MLP:
370
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
371
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
372
+
373
+ # Zero-out adaLN modulation layers in HunYuanDiT blocks:
374
+ for block in self.blocks:
375
+ nn.init.constant_(block.default_modulation[-1].weight, 0)
376
+ nn.init.constant_(block.default_modulation[-1].bias, 0)
377
+
378
+ # Zero-out output layers:
379
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
380
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
381
+ nn.init.constant_(self.final_layer.linear.weight, 0)
382
+ nn.init.constant_(self.final_layer.linear.bias, 0)
383
+
384
+ def unpatchify(self, x, h, w):
385
+ """
386
+ x: (N, T, patch_size**2 * C)
387
+ imgs: (N, H, W, C)
388
+ """
389
+ c = self.unpatchify_channels
390
+ p = self.x_embedder.patch_size[0]
391
+ # h = w = int(x.shape[1] ** 0.5)
392
+ assert h * w == x.shape[1]
393
+
394
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
395
+ x = torch.einsum('nhwpqc->nchpwq', x)
396
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
397
+ return imgs
398
+
399
+
400
+ #################################################################################
401
+ # HunYuanDiT Configs #
402
+ #################################################################################
403
+
404
+ HUNYUAN_DIT_CONFIG = {
405
+ 'DiT-g/2': {'depth': 40, 'hidden_size': 1408, 'patch_size': 2, 'num_heads': 16, 'mlp_ratio': 4.3637},
406
+ 'DiT-XL/2': {'depth': 28, 'hidden_size': 1152, 'patch_size': 2, 'num_heads': 16},
407
+ 'DiT-L/2': {'depth': 24, 'hidden_size': 1024, 'patch_size': 2, 'num_heads': 16},
408
+ 'DiT-B/2': {'depth': 12, 'hidden_size': 768, 'patch_size': 2, 'num_heads': 12},
409
+ }
hydit/modules/norm_layers.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class RMSNorm(nn.Module):
6
+ def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6):
7
+ """
8
+ Initialize the RMSNorm normalization layer.
9
+
10
+ Args:
11
+ dim (int): The dimension of the input tensor.
12
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
13
+
14
+ Attributes:
15
+ eps (float): A small value added to the denominator for numerical stability.
16
+ weight (nn.Parameter): Learnable scaling parameter.
17
+
18
+ """
19
+ super().__init__()
20
+ self.eps = eps
21
+ if elementwise_affine:
22
+ self.weight = nn.Parameter(torch.ones(dim))
23
+
24
+ def _norm(self, x):
25
+ """
26
+ Apply the RMSNorm normalization to the input tensor.
27
+
28
+ Args:
29
+ x (torch.Tensor): The input tensor.
30
+
31
+ Returns:
32
+ torch.Tensor: The normalized tensor.
33
+
34
+ """
35
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
36
+
37
+ def forward(self, x):
38
+ """
39
+ Forward pass through the RMSNorm layer.
40
+
41
+ Args:
42
+ x (torch.Tensor): The input tensor.
43
+
44
+ Returns:
45
+ torch.Tensor: The output tensor after applying RMSNorm.
46
+
47
+ """
48
+ output = self._norm(x.float()).type_as(x)
49
+ if hasattr(self, "weight"):
50
+ output = output * self.weight
51
+ return output
52
+
53
+
54
+ class GroupNorm32(nn.GroupNorm):
55
+ def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None):
56
+ super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype)
57
+
58
+ def forward(self, x):
59
+ y = super().forward(x).to(x.dtype)
60
+ return y
61
+
62
+ def normalization(channels, dtype=None):
63
+ """
64
+ Make a standard normalization layer.
65
+ :param channels: number of input channels.
66
+ :return: an nn.Module for normalization.
67
+ """
68
+ return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype)
hydit/modules/poolers.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class AttentionPool(nn.Module):
7
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
8
+ super().__init__()
9
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
10
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
11
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
12
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
13
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
14
+ self.num_heads = num_heads
15
+
16
+ def forward(self, x):
17
+ x = x.permute(1, 0, 2) # NLC -> LNC
18
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
19
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
20
+ x, _ = F.multi_head_attention_forward(
21
+ query=x[:1], key=x, value=x,
22
+ embed_dim_to_check=x.shape[-1],
23
+ num_heads=self.num_heads,
24
+ q_proj_weight=self.q_proj.weight,
25
+ k_proj_weight=self.k_proj.weight,
26
+ v_proj_weight=self.v_proj.weight,
27
+ in_proj_weight=None,
28
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
29
+ bias_k=None,
30
+ bias_v=None,
31
+ add_zero_attn=False,
32
+ dropout_p=0,
33
+ out_proj_weight=self.c_proj.weight,
34
+ out_proj_bias=self.c_proj.bias,
35
+ use_separate_proj_weight=True,
36
+ training=self.training,
37
+ need_weights=False
38
+ )
39
+ return x.squeeze(0)
hydit/modules/posemb_layers.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from typing import Union
4
+
5
+
6
+ def _to_tuple(x):
7
+ if isinstance(x, int):
8
+ return x, x
9
+ else:
10
+ return x
11
+
12
+
13
+ def get_fill_resize_and_crop(src, tgt): # src 来源的分辨率 tgt base 分辨率
14
+ th, tw = _to_tuple(tgt)
15
+ h, w = _to_tuple(src)
16
+
17
+ tr = th / tw # base 分辨率
18
+ r = h / w # 目标分辨率
19
+
20
+ # resize
21
+ if r > tr:
22
+ resize_height = th
23
+ resize_width = int(round(th / h * w))
24
+ else:
25
+ resize_width = tw
26
+ resize_height = int(round(tw / w * h)) # 根据base分辨率,将目标分辨率resize下来
27
+
28
+ crop_top = int(round((th - resize_height) / 2.0))
29
+ crop_left = int(round((tw - resize_width) / 2.0))
30
+
31
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
32
+
33
+
34
+ def get_meshgrid(start, *args):
35
+ if len(args) == 0:
36
+ # start is grid_size
37
+ num = _to_tuple(start)
38
+ start = (0, 0)
39
+ stop = num
40
+ elif len(args) == 1:
41
+ # start is start, args[0] is stop, step is 1
42
+ start = _to_tuple(start)
43
+ stop = _to_tuple(args[0])
44
+ num = (stop[0] - start[0], stop[1] - start[1])
45
+ elif len(args) == 2:
46
+ # start is start, args[0] is stop, args[1] is num
47
+ start = _to_tuple(start) # 左上角 eg: 12,0
48
+ stop = _to_tuple(args[0]) # 右下角 eg: 20,32
49
+ num = _to_tuple(args[1]) # 目标大小 eg: 32,124
50
+ else:
51
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
52
+
53
+ grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) # 12-20 中间差值32份 0-32 中间差值124份
54
+ grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
55
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
56
+ grid = np.stack(grid, axis=0) # [2, W, H]
57
+ return grid
58
+
59
+ #################################################################################
60
+ # Sine/Cosine Positional Embedding Functions #
61
+ #################################################################################
62
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
63
+
64
+ def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
65
+ """
66
+ grid_size: int of the grid height and width
67
+ return:
68
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
69
+ """
70
+ grid = get_meshgrid(start, *args) # [2, H, w]
71
+ # grid_h = np.arange(grid_size, dtype=np.float32)
72
+ # grid_w = np.arange(grid_size, dtype=np.float32)
73
+ # grid = np.meshgrid(grid_w, grid_h) # here w goes first
74
+ # grid = np.stack(grid, axis=0) # [2, W, H]
75
+
76
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
77
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
78
+ if cls_token and extra_tokens > 0:
79
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
80
+ return pos_embed
81
+
82
+
83
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
84
+ assert embed_dim % 2 == 0
85
+
86
+ # use half of dimensions to encode grid_h
87
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
88
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
89
+
90
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
91
+ return emb
92
+
93
+
94
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
95
+ """
96
+ embed_dim: output dimension for each position
97
+ pos: a list of positions to be encoded: size (W,H)
98
+ out: (M, D)
99
+ """
100
+ assert embed_dim % 2 == 0
101
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
102
+ omega /= embed_dim / 2.
103
+ omega = 1. / 10000**omega # (D/2,)
104
+
105
+ pos = pos.reshape(-1) # (M,)
106
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
107
+
108
+ emb_sin = np.sin(out) # (M, D/2)
109
+ emb_cos = np.cos(out) # (M, D/2)
110
+
111
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
112
+ return emb
113
+
114
+
115
+ #################################################################################
116
+ # Rotary Positional Embedding Functions #
117
+ #################################################################################
118
+ # https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
119
+
120
+ def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
121
+ """
122
+ This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
123
+
124
+ Parameters
125
+ ----------
126
+ embed_dim: int
127
+ embedding dimension size
128
+ start: int or tuple of int
129
+ If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
130
+ If len(args) == 2, start is start, args[0] is stop, args[1] is num.
131
+ use_real: bool
132
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
133
+
134
+ Returns
135
+ -------
136
+ pos_embed: torch.Tensor
137
+ [HW, D/2]
138
+ """
139
+ grid = get_meshgrid(start, *args) # [2, H, w]
140
+ grid = grid.reshape([2, 1, *grid.shape[1:]]) # 返回一个采样矩阵 分辨率与目标分辨率一致
141
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
142
+ return pos_embed
143
+
144
+
145
+ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
146
+ assert embed_dim % 4 == 0
147
+
148
+ # use half of dimensions to encode grid_h
149
+ emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
150
+ emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
151
+
152
+ if use_real:
153
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
154
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
155
+ return cos, sin
156
+ else:
157
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
158
+ return emb
159
+
160
+
161
+ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
162
+ """
163
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
164
+
165
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
166
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
167
+ The returned tensor contains complex values in complex64 data type.
168
+
169
+ Args:
170
+ dim (int): Dimension of the frequency tensor.
171
+ pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
172
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
173
+ use_real (bool, optional): If True, return real part and imaginary part separately.
174
+ Otherwise, return complex numbers.
175
+
176
+ Returns:
177
+ torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
178
+
179
+ """
180
+ if isinstance(pos, int):
181
+ pos = np.arange(pos)
182
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
183
+ t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
184
+ freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
185
+ if use_real:
186
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
187
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
188
+ return freqs_cos, freqs_sin
189
+ else:
190
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
191
+ return freqs_cis
192
+
193
+
194
+
195
+ def calc_sizes(rope_img, patch_size, th, tw):
196
+ """ 计算 RoPE 的尺寸. """
197
+ if rope_img == 'extend':
198
+ # 拓展模式
199
+ sub_args = [(th, tw)]
200
+ elif rope_img.startswith('base'):
201
+ # 基于一个尺寸, 其他尺寸插值获得.
202
+ base_size = int(rope_img[4:]) // 8 // patch_size # 基于512作为base,其他根据512差值得到
203
+ start, stop = get_fill_resize_and_crop((th, tw), base_size) # 需要在32x32里面 crop的左上角和右下角
204
+ sub_args = [start, stop, (th, tw)]
205
+ else:
206
+ raise ValueError(f"Unknown rope_img: {rope_img}")
207
+ return sub_args
208
+
209
+
210
+ def init_image_posemb(rope_img,
211
+ resolutions,
212
+ patch_size,
213
+ hidden_size,
214
+ num_heads,
215
+ log_fn,
216
+ rope_real=True,
217
+ ):
218
+ freqs_cis_img = {}
219
+ for reso in resolutions:
220
+ th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
221
+ sub_args = calc_sizes(rope_img, patch_size, th, tw) # [左上角, 右下角, 目标高宽] 需要在32x32里面 crop的左上角和右下角
222
+ freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
223
+ log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
224
+ f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
225
+ return freqs_cis_img