Spaces:
Running
Running
Upload 57 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- README.md +1 -1
- app/hydit_app.py +170 -0
- app/lang/en.csv +22 -0
- app/lang/zh.csv +22 -0
- asset/Hunyuan_DiT_Tech_Report_05140553.pdf +3 -0
- asset/chinese elements understanding.png +3 -0
- asset/cover.png +0 -0
- asset/framework.png +0 -0
- asset/logo.png +0 -0
- asset/long text understanding.png +3 -0
- asset/mllm.png +0 -0
- asset/radar.png +0 -0
- dialoggen/dialoggen_demo.py +172 -0
- dialoggen/images/demo1.jpeg +0 -0
- dialoggen/images/demo2.jpeg +0 -0
- dialoggen/llava/__init__.py +1 -0
- dialoggen/llava/constants.py +13 -0
- dialoggen/llava/conversation.py +396 -0
- dialoggen/llava/mm_utils.py +247 -0
- dialoggen/llava/model/__init__.py +6 -0
- dialoggen/llava/model/apply_delta.py +48 -0
- dialoggen/llava/model/builder.py +167 -0
- dialoggen/llava/model/consolidate.py +29 -0
- dialoggen/llava/model/language_model/llava_llama.py +158 -0
- dialoggen/llava/model/language_model/llava_mistral.py +158 -0
- dialoggen/llava/model/language_model/llava_mpt.py +97 -0
- dialoggen/llava/model/llava_arch.py +368 -0
- dialoggen/llava/model/make_delta.py +52 -0
- dialoggen/llava/model/multimodal_encoder/builder.py +11 -0
- dialoggen/llava/model/multimodal_encoder/clip_encoder.py +88 -0
- dialoggen/llava/model/multimodal_projector/builder.py +51 -0
- dialoggen/llava/model/utils.py +20 -0
- dialoggen/llava/utils.py +126 -0
- en.csv +22 -0
- environment.yml +8 -0
- example_prompts.txt +28 -0
- hydit/__init__.py +0 -0
- hydit/config.py +67 -0
- hydit/constants.py +62 -0
- hydit/diffusion/__init__.py +0 -0
- hydit/diffusion/pipeline.py +830 -0
- hydit/inference.py +389 -0
- hydit/modules/__init__.py +0 -0
- hydit/modules/attn_layers.py +377 -0
- hydit/modules/embedders.py +111 -0
- hydit/modules/models.py +409 -0
- hydit/modules/norm_layers.py +68 -0
- hydit/modules/poolers.py +39 -0
- hydit/modules/posemb_layers.py +225 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
asset/chinese[[:space:]]elements[[:space:]]understanding.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
asset/Hunyuan_DiT_Tech_Report_05140553.pdf filter=lfs diff=lfs merge=lfs -text
|
38 |
+
asset/long[[:space:]]text[[:space:]]understanding.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: indigo
|
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.31.1
|
8 |
-
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
|
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.31.1
|
8 |
+
app_file: app/hydit_app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
app/hydit_app.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
from pathlib import Path
|
4 |
+
from PIL import Image
|
5 |
+
import sys
|
6 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
7 |
+
|
8 |
+
from hydit.constants import SAMPLER_FACTORY
|
9 |
+
from sample_t2i import inferencer
|
10 |
+
|
11 |
+
ROOT = Path(__file__).parent.parent
|
12 |
+
SAMPLERS = list(SAMPLER_FACTORY.keys())
|
13 |
+
SIZES = {
|
14 |
+
"square": (1024, 1024),
|
15 |
+
"landscape": (768, 1280),
|
16 |
+
"portrait": (1280, 768),
|
17 |
+
}
|
18 |
+
|
19 |
+
def get_strings(lang):
|
20 |
+
lang_file = Path(f"app/lang/{lang}.csv")
|
21 |
+
strings = pd.read_csv(lang_file, header=0)
|
22 |
+
strings = strings.set_index("key")['value'].to_dict()
|
23 |
+
return strings
|
24 |
+
|
25 |
+
|
26 |
+
args, gen, enhancer = inferencer()
|
27 |
+
strings = get_strings("en")
|
28 |
+
|
29 |
+
|
30 |
+
def infer(
|
31 |
+
prompt,
|
32 |
+
negative_prompt,
|
33 |
+
seed,
|
34 |
+
cfg_scale,
|
35 |
+
infer_steps,
|
36 |
+
oriW, oriH,
|
37 |
+
sampler,
|
38 |
+
size,
|
39 |
+
enhance
|
40 |
+
):
|
41 |
+
if enhance and enhancer is not None:
|
42 |
+
success, enhanced_prompt = enhancer(prompt)
|
43 |
+
if not success:
|
44 |
+
fail_image = Image.open(ROOT / 'app/fail.png')
|
45 |
+
return fail_image
|
46 |
+
else:
|
47 |
+
enhanced_prompt = None
|
48 |
+
|
49 |
+
height, width = SIZES[size]
|
50 |
+
results = gen.predict(prompt,
|
51 |
+
height=height,
|
52 |
+
width=width,
|
53 |
+
seed=seed,
|
54 |
+
enhanced_prompt=enhanced_prompt,
|
55 |
+
negative_prompt=negative_prompt,
|
56 |
+
infer_steps=infer_steps,
|
57 |
+
guidance_scale=cfg_scale,
|
58 |
+
batch_size=1,
|
59 |
+
src_size_cond=(oriW, oriH),
|
60 |
+
sampler=sampler,
|
61 |
+
)
|
62 |
+
image = results['images'][0]
|
63 |
+
return image
|
64 |
+
|
65 |
+
|
66 |
+
def ui():
|
67 |
+
block = gr.Blocks()
|
68 |
+
|
69 |
+
description = f"""
|
70 |
+
# {strings['title']}
|
71 |
+
|
72 |
+
## {strings['desc']}
|
73 |
+
|
74 |
+
"""
|
75 |
+
|
76 |
+
with block:
|
77 |
+
with gr.Row():
|
78 |
+
gr.Markdown(description)
|
79 |
+
with gr.Row():
|
80 |
+
with gr.Column():
|
81 |
+
with gr.Row():
|
82 |
+
size = gr.Radio(
|
83 |
+
label=strings['size'], choices=[
|
84 |
+
(strings['square'], 'square'),
|
85 |
+
(strings['landscape'], 'landscape'),
|
86 |
+
(strings['portrait'], 'portrait'),
|
87 |
+
],
|
88 |
+
value="square"
|
89 |
+
)
|
90 |
+
prompt = gr.Textbox(label=strings['prompt'], value=strings['default prompt'], lines=3)
|
91 |
+
with gr.Row():
|
92 |
+
infer_steps = gr.Slider(
|
93 |
+
label=strings['infer steps'], minimum=1, maximum=200, value=100, step=1,
|
94 |
+
)
|
95 |
+
seed = gr.Number(
|
96 |
+
label=strings['seed'], minimum=-1, maximum=1_000_000_000, value=1, step=1, precision=0,
|
97 |
+
)
|
98 |
+
enhance = gr.Checkbox(
|
99 |
+
label=strings['enhance'], value=enhancer is not None, interactive=True,
|
100 |
+
)
|
101 |
+
|
102 |
+
with gr.Accordion(
|
103 |
+
strings['accordion'], open=False
|
104 |
+
):
|
105 |
+
with gr.Row():
|
106 |
+
negative_prompt = gr.Textbox(label=strings['negative_prompt'],
|
107 |
+
value=gen.default_negative_prompt,
|
108 |
+
lines=2,
|
109 |
+
)
|
110 |
+
with gr.Row():
|
111 |
+
sampler = gr.Dropdown(SAMPLERS, label=strings['sampler'], value="ddpm")
|
112 |
+
cfg_scale = gr.Slider(
|
113 |
+
label=strings['cfg'], minimum=1.0, maximum=16.0, value=6.0, step=1
|
114 |
+
)
|
115 |
+
oriW = gr.Number(
|
116 |
+
label=strings['width cond'], minimum=1024, maximum=4096, value=1024, step=64, precision=0,
|
117 |
+
min_width=80,
|
118 |
+
)
|
119 |
+
oriH = gr.Number(
|
120 |
+
label=strings['height cond'], minimum=1024, maximum=4096, value=1024, step=64, precision=0,
|
121 |
+
min_width=80,
|
122 |
+
)
|
123 |
+
with gr.Row():
|
124 |
+
advanced_button = gr.Button(strings['run'])
|
125 |
+
with gr.Column():
|
126 |
+
default_img = Image.open(ROOT / 'app/default.png')
|
127 |
+
output_img = gr.Image(
|
128 |
+
label=strings['generated image'],
|
129 |
+
interactive=False,
|
130 |
+
format='png',
|
131 |
+
value=default_img,
|
132 |
+
)
|
133 |
+
advanced_button.click(
|
134 |
+
fn=infer,
|
135 |
+
inputs=[
|
136 |
+
prompt, negative_prompt, seed, cfg_scale, infer_steps,
|
137 |
+
oriW, oriH, sampler, size, enhance,
|
138 |
+
],
|
139 |
+
outputs=output_img,
|
140 |
+
)
|
141 |
+
|
142 |
+
with gr.Row():
|
143 |
+
gr.Examples([
|
144 |
+
['一只小猫'],
|
145 |
+
['现实主义风格,画面主要描述一个巴洛克风格的花瓶,带有金色的装饰边框,花瓶上盛开着各种色彩鲜艳的花,白色背景'],
|
146 |
+
['一只聪明的狐狸走在阔叶树林里, 旁边是一条小溪, 细节真实, 摄影'],
|
147 |
+
['飞流直下三千尺,疑是银河落九天'],
|
148 |
+
['一只长靴猫手持亮银色的宝剑,身着铠甲,眼神坚毅,站在一堆金币上,背景是暗色调的洞穴,图像上有金币的光影点缀。'],
|
149 |
+
['麻婆豆腐'],
|
150 |
+
['苏州园林'],
|
151 |
+
['一颗新鲜的草莓特写,红色的外表,表面布满许多种子,背景是淡绿色的叶子'],
|
152 |
+
['请画出“忽如一夜春风来 千树万树梨花开”'],
|
153 |
+
['请将“杞人忧天”的样子画出来'],
|
154 |
+
['枯藤老树昏鸦,小桥流水人家'],
|
155 |
+
['湖水清澈,天空湛蓝,阳光灿烂。一只优雅的白天鹅在湖边游泳。它周围有几只小鸭子,看起来非常可爱,整个画面给人一种宁静祥和的感觉。'],
|
156 |
+
['一朵鲜艳的红色玫瑰花,花瓣撒有一些水珠,晶莹剔透,特写镜头'],
|
157 |
+
['臭豆腐'],
|
158 |
+
['九寨沟'],
|
159 |
+
['俗语“鲤鱼跃龙门”'],
|
160 |
+
['风格是写实,画面主要描述一个亚洲戏曲艺术家正在表演,她穿着华丽的戏服,脸上戴着精致的面具,身姿优雅,背景是古色古香的舞台,镜头是近景'],
|
161 |
+
],
|
162 |
+
[prompt],
|
163 |
+
label=strings['examples']
|
164 |
+
)
|
165 |
+
return block
|
166 |
+
|
167 |
+
|
168 |
+
if __name__ == "__main__":
|
169 |
+
interface = ui()
|
170 |
+
interface.launch()
|
app/lang/en.csv
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
key,value
|
2 |
+
size,Size
|
3 |
+
sampler,Sampler
|
4 |
+
prompt,Prompt
|
5 |
+
default prompt,"A cute cat"
|
6 |
+
negative_prompt,Negative Prompt
|
7 |
+
seed,Seed
|
8 |
+
cfg,CFG Scale
|
9 |
+
infer steps,Sampling Steps
|
10 |
+
batch size,Batch Size
|
11 |
+
width cond,Width Cond
|
12 |
+
height cond,Height Cond
|
13 |
+
enhance,Prompt Enhancement
|
14 |
+
run,Submit
|
15 |
+
square,Square(1024x1024)
|
16 |
+
landscape,Landscape(1280x768)
|
17 |
+
portrait,Portrait(768x1280)
|
18 |
+
accordion,Advanced Options
|
19 |
+
generated image,HunYuanDiT Generated Image
|
20 |
+
examples,More Examples
|
21 |
+
title,Hunyuan-DiT
|
22 |
+
desc,A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding
|
app/lang/zh.csv
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
key,value
|
2 |
+
size,尺寸
|
3 |
+
sampler,采样器
|
4 |
+
prompt,文本描述
|
5 |
+
default prompt,"一只可爱的猫"
|
6 |
+
negative_prompt,负向词
|
7 |
+
seed,种子
|
8 |
+
cfg,CFG系数
|
9 |
+
infer steps,采样步数
|
10 |
+
batch size,批大小
|
11 |
+
width cond,宽度条件
|
12 |
+
height cond,高度条件
|
13 |
+
enhance,文本增强
|
14 |
+
run,提交生成
|
15 |
+
square,方形(1024x1024)
|
16 |
+
portrait,竖屏(1280x768)
|
17 |
+
landscape,横屏(768x1280)
|
18 |
+
accordion,高级设置
|
19 |
+
generated image,HunYuanDiT 生成
|
20 |
+
examples,更多示例
|
21 |
+
title,混元-DiT
|
22 |
+
desc,具有细粒度中文理解的高性能多分辨率 Diffusion Transformer 模型
|
asset/Hunyuan_DiT_Tech_Report_05140553.pdf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3f8514b002ba3bb4704575096683f65e09df06693a54bf3004f0b351138ab1e5
|
3 |
+
size 42132252
|
asset/chinese elements understanding.png
ADDED
Git LFS Details
|
asset/cover.png
ADDED
asset/framework.png
ADDED
asset/logo.png
ADDED
asset/long text understanding.png
ADDED
Git LFS Details
|
asset/mllm.png
ADDED
asset/radar.png
ADDED
dialoggen/dialoggen_demo.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
# 添加当前命令行运行的目录到 sys.path
|
6 |
+
sys.path.append(os.getcwd()+"/dialoggen")
|
7 |
+
|
8 |
+
|
9 |
+
from llava.constants import (
|
10 |
+
IMAGE_TOKEN_INDEX,
|
11 |
+
DEFAULT_IMAGE_TOKEN,
|
12 |
+
DEFAULT_IM_START_TOKEN,
|
13 |
+
DEFAULT_IM_END_TOKEN,
|
14 |
+
IMAGE_PLACEHOLDER,
|
15 |
+
)
|
16 |
+
from llava.conversation import conv_templates, SeparatorStyle
|
17 |
+
from llava.model.builder import load_pretrained_model
|
18 |
+
from llava.utils import disable_torch_init
|
19 |
+
from llava.mm_utils import (
|
20 |
+
process_images,
|
21 |
+
tokenizer_image_token,
|
22 |
+
get_model_name_from_path,
|
23 |
+
)
|
24 |
+
|
25 |
+
import requests
|
26 |
+
from PIL import Image
|
27 |
+
from io import BytesIO
|
28 |
+
import re
|
29 |
+
|
30 |
+
|
31 |
+
def image_parser(image_file, sep=','):
|
32 |
+
out = image_file.split(sep)
|
33 |
+
return out
|
34 |
+
|
35 |
+
|
36 |
+
def load_image(image_file):
|
37 |
+
if image_file.startswith("http") or image_file.startswith("https"):
|
38 |
+
response = requests.get(image_file)
|
39 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
40 |
+
else:
|
41 |
+
image = Image.open(image_file).convert("RGB")
|
42 |
+
return image
|
43 |
+
|
44 |
+
|
45 |
+
def load_images(image_files):
|
46 |
+
out = []
|
47 |
+
for image_file in image_files:
|
48 |
+
image = load_image(image_file)
|
49 |
+
out.append(image)
|
50 |
+
return out
|
51 |
+
|
52 |
+
|
53 |
+
def init_dialoggen_model(model_path, model_base=None):
|
54 |
+
model_name = get_model_name_from_path(model_path)
|
55 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
56 |
+
model_path, model_base, model_name, llava_type_model=True)
|
57 |
+
return {"tokenizer": tokenizer,
|
58 |
+
"model": model,
|
59 |
+
"image_processor": image_processor}
|
60 |
+
|
61 |
+
|
62 |
+
def eval_model(models,
|
63 |
+
query='详细描述一下这张图片',
|
64 |
+
image_file=None,
|
65 |
+
sep=',',
|
66 |
+
temperature=0.2,
|
67 |
+
top_p=None,
|
68 |
+
num_beams=1,
|
69 |
+
max_new_tokens=512,
|
70 |
+
):
|
71 |
+
# Model
|
72 |
+
disable_torch_init()
|
73 |
+
|
74 |
+
qs = query
|
75 |
+
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
|
76 |
+
if IMAGE_PLACEHOLDER in qs:
|
77 |
+
if models["model"].config.mm_use_im_start_end:
|
78 |
+
qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
|
79 |
+
else:
|
80 |
+
qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
|
81 |
+
else:
|
82 |
+
if models["model"].config.mm_use_im_start_end:
|
83 |
+
qs = image_token_se + "\n" + qs
|
84 |
+
else:
|
85 |
+
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
|
86 |
+
|
87 |
+
conv = conv_templates['llava_v1'].copy()
|
88 |
+
conv.append_message(conv.roles[0], qs)
|
89 |
+
conv.append_message(conv.roles[1], None)
|
90 |
+
prompt = conv.get_prompt()
|
91 |
+
|
92 |
+
if image_file is not None:
|
93 |
+
image_files = image_parser(image_file, sep=sep)
|
94 |
+
images = load_images(image_files)
|
95 |
+
image_sizes = [x.size for x in images]
|
96 |
+
images_tensor = process_images(
|
97 |
+
images,
|
98 |
+
models["image_processor"],
|
99 |
+
models["model"].config
|
100 |
+
).to(models["model"].device, dtype=torch.float16)
|
101 |
+
else:
|
102 |
+
# fomatted input as training data
|
103 |
+
image_sizes = [(1024, 1024)]
|
104 |
+
images_tensor = torch.zeros(1, 5, 3, models["image_processor"].crop_size["height"], models["image_processor"].crop_size["width"])
|
105 |
+
images_tensor = images_tensor.to(models["model"].device, dtype=torch.float16)
|
106 |
+
|
107 |
+
input_ids = (
|
108 |
+
tokenizer_image_token(prompt, models["tokenizer"], IMAGE_TOKEN_INDEX, return_tensors="pt")
|
109 |
+
.unsqueeze(0)
|
110 |
+
.cuda()
|
111 |
+
)
|
112 |
+
with torch.inference_mode():
|
113 |
+
output_ids = models["model"].generate(
|
114 |
+
input_ids,
|
115 |
+
images=images_tensor,
|
116 |
+
image_sizes=image_sizes,
|
117 |
+
do_sample=True if temperature > 0 else False,
|
118 |
+
temperature=temperature,
|
119 |
+
top_p=top_p,
|
120 |
+
num_beams=num_beams,
|
121 |
+
max_new_tokens=max_new_tokens,
|
122 |
+
use_cache=True,
|
123 |
+
)
|
124 |
+
|
125 |
+
outputs = models["tokenizer"].batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
126 |
+
return outputs
|
127 |
+
|
128 |
+
|
129 |
+
def remove_prefix(text):
|
130 |
+
if text.startswith("<画图>"):
|
131 |
+
return text[len("<画图>"):], True
|
132 |
+
elif text.startswith("对不起"):
|
133 |
+
# 拒绝画图
|
134 |
+
return "", False
|
135 |
+
else:
|
136 |
+
return text, True
|
137 |
+
|
138 |
+
|
139 |
+
class DialogGen(object):
|
140 |
+
def __init__(self, model_path):
|
141 |
+
self.models = init_dialoggen_model(model_path)
|
142 |
+
self.query_template = "请先判断用户的意图,若为画图则在输出前加入<画图>:{}"
|
143 |
+
|
144 |
+
def __call__(self, prompt):
|
145 |
+
enhanced_prompt = eval_model(
|
146 |
+
models=self.models,
|
147 |
+
query=self.query_template.format(prompt),
|
148 |
+
image_file=None,
|
149 |
+
)
|
150 |
+
|
151 |
+
enhanced_prompt, compliance = remove_prefix(enhanced_prompt)
|
152 |
+
if not compliance:
|
153 |
+
return False, ""
|
154 |
+
return True, enhanced_prompt
|
155 |
+
|
156 |
+
|
157 |
+
if __name__ == "__main__":
|
158 |
+
parser = argparse.ArgumentParser()
|
159 |
+
parser.add_argument('--model_path', type=str, default='./ckpts/dialoggen')
|
160 |
+
parser.add_argument('--prompt', type=str, default='画一只小猫')
|
161 |
+
parser.add_argument('--image_file', type=str, default=None) # 'images/demo1.jpeg'
|
162 |
+
args = parser.parse_args()
|
163 |
+
|
164 |
+
query = f"请先判断用户的意图,若为画图则在输出前加入<画图>:{args.prompt}"
|
165 |
+
|
166 |
+
models = init_dialoggen_model(args.model_path)
|
167 |
+
|
168 |
+
res = eval_model(models,
|
169 |
+
query=query,
|
170 |
+
image_file=args.image_file,
|
171 |
+
)
|
172 |
+
print(res)
|
dialoggen/images/demo1.jpeg
ADDED
dialoggen/images/demo2.jpeg
ADDED
dialoggen/llava/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model import LlavaLlamaForCausalLM
|
dialoggen/llava/constants.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
IMAGE_TOKEN_INDEX = -200
|
9 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
10 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
11 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
12 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
13 |
+
IMAGE_PLACEHOLDER = "<image-placeholder>"
|
dialoggen/llava/conversation.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
import base64
|
5 |
+
from io import BytesIO
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
|
9 |
+
class SeparatorStyle(Enum):
|
10 |
+
"""Different separator style."""
|
11 |
+
SINGLE = auto()
|
12 |
+
TWO = auto()
|
13 |
+
MPT = auto()
|
14 |
+
PLAIN = auto()
|
15 |
+
LLAMA_2 = auto()
|
16 |
+
|
17 |
+
|
18 |
+
@dataclasses.dataclass
|
19 |
+
class Conversation:
|
20 |
+
"""A class that keeps all conversation history."""
|
21 |
+
system: str
|
22 |
+
roles: List[str]
|
23 |
+
messages: List[List[str]]
|
24 |
+
offset: int
|
25 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
26 |
+
sep: str = "###"
|
27 |
+
sep2: str = None
|
28 |
+
version: str = "Unknown"
|
29 |
+
|
30 |
+
skip_next: bool = False
|
31 |
+
|
32 |
+
def get_prompt(self):
|
33 |
+
messages = self.messages
|
34 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
35 |
+
messages = self.messages.copy()
|
36 |
+
init_role, init_msg = messages[0].copy()
|
37 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
38 |
+
if 'mmtag' in self.version:
|
39 |
+
messages[0] = (init_role, init_msg)
|
40 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
41 |
+
messages.insert(1, (self.roles[1], "Received."))
|
42 |
+
else:
|
43 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
44 |
+
|
45 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
46 |
+
ret = self.system + self.sep
|
47 |
+
for role, message in messages:
|
48 |
+
if message:
|
49 |
+
if type(message) is tuple:
|
50 |
+
message, _, _ = message
|
51 |
+
ret += role + ": " + message + self.sep
|
52 |
+
else:
|
53 |
+
ret += role + ":"
|
54 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
55 |
+
seps = [self.sep, self.sep2]
|
56 |
+
ret = self.system + seps[0]
|
57 |
+
for i, (role, message) in enumerate(messages):
|
58 |
+
if message:
|
59 |
+
if type(message) is tuple:
|
60 |
+
message, _, _ = message
|
61 |
+
ret += role + ": " + message + seps[i % 2]
|
62 |
+
else:
|
63 |
+
ret += role + ":"
|
64 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
65 |
+
ret = self.system + self.sep
|
66 |
+
for role, message in messages:
|
67 |
+
if message:
|
68 |
+
if type(message) is tuple:
|
69 |
+
message, _, _ = message
|
70 |
+
ret += role + message + self.sep
|
71 |
+
else:
|
72 |
+
ret += role
|
73 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
74 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
|
75 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
76 |
+
ret = ""
|
77 |
+
|
78 |
+
for i, (role, message) in enumerate(messages):
|
79 |
+
if i == 0:
|
80 |
+
assert message, "first message should not be none"
|
81 |
+
assert role == self.roles[0], "first message should come from user"
|
82 |
+
if message:
|
83 |
+
if type(message) is tuple:
|
84 |
+
message, _, _ = message
|
85 |
+
if i == 0: message = wrap_sys(self.system) + message
|
86 |
+
if i % 2 == 0:
|
87 |
+
message = wrap_inst(message)
|
88 |
+
ret += self.sep + message
|
89 |
+
else:
|
90 |
+
ret += " " + message + " " + self.sep2
|
91 |
+
else:
|
92 |
+
ret += ""
|
93 |
+
ret = ret.lstrip(self.sep)
|
94 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
95 |
+
seps = [self.sep, self.sep2]
|
96 |
+
ret = self.system
|
97 |
+
for i, (role, message) in enumerate(messages):
|
98 |
+
if message:
|
99 |
+
if type(message) is tuple:
|
100 |
+
message, _, _ = message
|
101 |
+
ret += message + seps[i % 2]
|
102 |
+
else:
|
103 |
+
ret += ""
|
104 |
+
else:
|
105 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
106 |
+
|
107 |
+
return ret
|
108 |
+
|
109 |
+
def append_message(self, role, message):
|
110 |
+
self.messages.append([role, message])
|
111 |
+
|
112 |
+
def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
|
113 |
+
if image_process_mode == "Pad":
|
114 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
115 |
+
width, height = pil_img.size
|
116 |
+
if width == height:
|
117 |
+
return pil_img
|
118 |
+
elif width > height:
|
119 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
120 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
121 |
+
return result
|
122 |
+
else:
|
123 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
124 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
125 |
+
return result
|
126 |
+
image = expand2square(image)
|
127 |
+
elif image_process_mode in ["Default", "Crop"]:
|
128 |
+
pass
|
129 |
+
elif image_process_mode == "Resize":
|
130 |
+
image = image.resize((336, 336))
|
131 |
+
else:
|
132 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
133 |
+
if max(image.size) > max_len:
|
134 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
135 |
+
aspect_ratio = max_hw / min_hw
|
136 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
137 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
138 |
+
W, H = image.size
|
139 |
+
if H > W:
|
140 |
+
H, W = longest_edge, shortest_edge
|
141 |
+
else:
|
142 |
+
H, W = shortest_edge, longest_edge
|
143 |
+
image = image.resize((W, H))
|
144 |
+
if return_pil:
|
145 |
+
return image
|
146 |
+
else:
|
147 |
+
buffered = BytesIO()
|
148 |
+
image.save(buffered, format=image_format)
|
149 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
150 |
+
return img_b64_str
|
151 |
+
|
152 |
+
def get_images(self, return_pil=False):
|
153 |
+
images = []
|
154 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
155 |
+
if i % 2 == 0:
|
156 |
+
if type(msg) is tuple:
|
157 |
+
msg, image, image_process_mode = msg
|
158 |
+
image = self.process_image(image, image_process_mode, return_pil=return_pil)
|
159 |
+
images.append(image)
|
160 |
+
return images
|
161 |
+
|
162 |
+
def to_gradio_chatbot(self):
|
163 |
+
ret = []
|
164 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
165 |
+
if i % 2 == 0:
|
166 |
+
if type(msg) is tuple:
|
167 |
+
msg, image, image_process_mode = msg
|
168 |
+
img_b64_str = self.process_image(
|
169 |
+
image, "Default", return_pil=False,
|
170 |
+
image_format='JPEG')
|
171 |
+
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
|
172 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
173 |
+
ret.append([msg, None])
|
174 |
+
else:
|
175 |
+
ret.append([msg, None])
|
176 |
+
else:
|
177 |
+
ret[-1][-1] = msg
|
178 |
+
return ret
|
179 |
+
|
180 |
+
def copy(self):
|
181 |
+
return Conversation(
|
182 |
+
system=self.system,
|
183 |
+
roles=self.roles,
|
184 |
+
messages=[[x, y] for x, y in self.messages],
|
185 |
+
offset=self.offset,
|
186 |
+
sep_style=self.sep_style,
|
187 |
+
sep=self.sep,
|
188 |
+
sep2=self.sep2,
|
189 |
+
version=self.version)
|
190 |
+
|
191 |
+
def dict(self):
|
192 |
+
if len(self.get_images()) > 0:
|
193 |
+
return {
|
194 |
+
"system": self.system,
|
195 |
+
"roles": self.roles,
|
196 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
197 |
+
"offset": self.offset,
|
198 |
+
"sep": self.sep,
|
199 |
+
"sep2": self.sep2,
|
200 |
+
}
|
201 |
+
return {
|
202 |
+
"system": self.system,
|
203 |
+
"roles": self.roles,
|
204 |
+
"messages": self.messages,
|
205 |
+
"offset": self.offset,
|
206 |
+
"sep": self.sep,
|
207 |
+
"sep2": self.sep2,
|
208 |
+
}
|
209 |
+
|
210 |
+
|
211 |
+
conv_vicuna_v0 = Conversation(
|
212 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
213 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
214 |
+
roles=("Human", "Assistant"),
|
215 |
+
messages=(
|
216 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
217 |
+
("Assistant",
|
218 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
219 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
220 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
221 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
222 |
+
"renewable and non-renewable energy sources:\n"
|
223 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
224 |
+
"energy sources are finite and will eventually run out.\n"
|
225 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
226 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
227 |
+
"and other negative effects.\n"
|
228 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
229 |
+
"have lower operational costs than non-renewable sources.\n"
|
230 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
231 |
+
"locations than non-renewable sources.\n"
|
232 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
233 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
234 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
235 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
236 |
+
),
|
237 |
+
offset=2,
|
238 |
+
sep_style=SeparatorStyle.SINGLE,
|
239 |
+
sep="###",
|
240 |
+
)
|
241 |
+
|
242 |
+
conv_vicuna_v1 = Conversation(
|
243 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
244 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
245 |
+
roles=("USER", "ASSISTANT"),
|
246 |
+
version="v1",
|
247 |
+
messages=(),
|
248 |
+
offset=0,
|
249 |
+
sep_style=SeparatorStyle.TWO,
|
250 |
+
sep=" ",
|
251 |
+
sep2="</s>",
|
252 |
+
)
|
253 |
+
|
254 |
+
conv_llama_2 = Conversation(
|
255 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
256 |
+
|
257 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
258 |
+
roles=("USER", "ASSISTANT"),
|
259 |
+
version="llama_v2",
|
260 |
+
messages=(),
|
261 |
+
offset=0,
|
262 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
263 |
+
sep="<s>",
|
264 |
+
sep2="</s>",
|
265 |
+
)
|
266 |
+
|
267 |
+
conv_llava_llama_2 = Conversation(
|
268 |
+
system="You are a helpful language and vision assistant. "
|
269 |
+
"You are able to understand the visual content that the user provides, "
|
270 |
+
"and assist the user with a variety of tasks using natural language.",
|
271 |
+
roles=("USER", "ASSISTANT"),
|
272 |
+
version="llama_v2",
|
273 |
+
messages=(),
|
274 |
+
offset=0,
|
275 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
276 |
+
sep="<s>",
|
277 |
+
sep2="</s>",
|
278 |
+
)
|
279 |
+
|
280 |
+
conv_mpt = Conversation(
|
281 |
+
system="""<|im_start|>system
|
282 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
283 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
284 |
+
version="mpt",
|
285 |
+
messages=(),
|
286 |
+
offset=0,
|
287 |
+
sep_style=SeparatorStyle.MPT,
|
288 |
+
sep="<|im_end|>",
|
289 |
+
)
|
290 |
+
|
291 |
+
conv_llava_plain = Conversation(
|
292 |
+
system="",
|
293 |
+
roles=("", ""),
|
294 |
+
messages=(
|
295 |
+
),
|
296 |
+
offset=0,
|
297 |
+
sep_style=SeparatorStyle.PLAIN,
|
298 |
+
sep="\n",
|
299 |
+
)
|
300 |
+
|
301 |
+
conv_llava_v0 = Conversation(
|
302 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
303 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
304 |
+
roles=("Human", "Assistant"),
|
305 |
+
messages=(
|
306 |
+
),
|
307 |
+
offset=0,
|
308 |
+
sep_style=SeparatorStyle.SINGLE,
|
309 |
+
sep="###",
|
310 |
+
)
|
311 |
+
|
312 |
+
conv_llava_v0_mmtag = Conversation(
|
313 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
314 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
315 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
316 |
+
roles=("Human", "Assistant"),
|
317 |
+
messages=(
|
318 |
+
),
|
319 |
+
offset=0,
|
320 |
+
sep_style=SeparatorStyle.SINGLE,
|
321 |
+
sep="###",
|
322 |
+
version="v0_mmtag",
|
323 |
+
)
|
324 |
+
|
325 |
+
conv_llava_v1 = Conversation(
|
326 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
327 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
328 |
+
roles=("USER", "ASSISTANT"),
|
329 |
+
version="v1",
|
330 |
+
messages=(),
|
331 |
+
offset=0,
|
332 |
+
sep_style=SeparatorStyle.TWO,
|
333 |
+
sep=" ",
|
334 |
+
sep2="</s>",
|
335 |
+
)
|
336 |
+
|
337 |
+
conv_llava_v1_mmtag = Conversation(
|
338 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
339 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
340 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
341 |
+
roles=("USER", "ASSISTANT"),
|
342 |
+
messages=(),
|
343 |
+
offset=0,
|
344 |
+
sep_style=SeparatorStyle.TWO,
|
345 |
+
sep=" ",
|
346 |
+
sep2="</s>",
|
347 |
+
version="v1_mmtag",
|
348 |
+
)
|
349 |
+
|
350 |
+
conv_mistral_instruct = Conversation(
|
351 |
+
system="",
|
352 |
+
roles=("USER", "ASSISTANT"),
|
353 |
+
version="llama_v2",
|
354 |
+
messages=(),
|
355 |
+
offset=0,
|
356 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
357 |
+
sep="",
|
358 |
+
sep2="</s>",
|
359 |
+
)
|
360 |
+
|
361 |
+
conv_chatml_direct = Conversation(
|
362 |
+
system="""<|im_start|>system
|
363 |
+
Answer the questions.""",
|
364 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
365 |
+
version="mpt",
|
366 |
+
messages=(),
|
367 |
+
offset=0,
|
368 |
+
sep_style=SeparatorStyle.MPT,
|
369 |
+
sep="<|im_end|>",
|
370 |
+
)
|
371 |
+
|
372 |
+
default_conversation = conv_vicuna_v1
|
373 |
+
conv_templates = {
|
374 |
+
"default": conv_vicuna_v0,
|
375 |
+
"v0": conv_vicuna_v0,
|
376 |
+
"v1": conv_vicuna_v1,
|
377 |
+
"vicuna_v1": conv_vicuna_v1,
|
378 |
+
"llama_2": conv_llama_2,
|
379 |
+
"mistral_instruct": conv_mistral_instruct,
|
380 |
+
"chatml_direct": conv_chatml_direct,
|
381 |
+
"mistral_direct": conv_chatml_direct,
|
382 |
+
|
383 |
+
"plain": conv_llava_plain,
|
384 |
+
"v0_plain": conv_llava_plain,
|
385 |
+
"llava_v0": conv_llava_v0,
|
386 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
387 |
+
"llava_v1": conv_llava_v1,
|
388 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
389 |
+
"llava_llama_2": conv_llava_llama_2,
|
390 |
+
|
391 |
+
"mpt": conv_mpt,
|
392 |
+
}
|
393 |
+
|
394 |
+
|
395 |
+
if __name__ == "__main__":
|
396 |
+
print(default_conversation.get_prompt())
|
dialoggen/llava/mm_utils.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from io import BytesIO
|
3 |
+
import base64
|
4 |
+
import torch
|
5 |
+
import math
|
6 |
+
import ast
|
7 |
+
|
8 |
+
from transformers import StoppingCriteria
|
9 |
+
from llava.constants import IMAGE_TOKEN_INDEX
|
10 |
+
|
11 |
+
|
12 |
+
def select_best_resolution(original_size, possible_resolutions):
|
13 |
+
"""
|
14 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
18 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
tuple: The best fit resolution in the format (width, height).
|
22 |
+
"""
|
23 |
+
original_width, original_height = original_size
|
24 |
+
best_fit = None
|
25 |
+
max_effective_resolution = 0
|
26 |
+
min_wasted_resolution = float('inf')
|
27 |
+
|
28 |
+
for width, height in possible_resolutions:
|
29 |
+
scale = min(width / original_width, height / original_height)
|
30 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
31 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
32 |
+
wasted_resolution = (width * height) - effective_resolution
|
33 |
+
|
34 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
35 |
+
max_effective_resolution = effective_resolution
|
36 |
+
min_wasted_resolution = wasted_resolution
|
37 |
+
best_fit = (width, height)
|
38 |
+
|
39 |
+
return best_fit
|
40 |
+
|
41 |
+
|
42 |
+
def resize_and_pad_image(image, target_resolution):
|
43 |
+
"""
|
44 |
+
Resize and pad an image to a target resolution while maintaining aspect ratio.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
image (PIL.Image.Image): The input image.
|
48 |
+
target_resolution (tuple): The target resolution (width, height) of the image.
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
PIL.Image.Image: The resized and padded image.
|
52 |
+
"""
|
53 |
+
original_width, original_height = image.size
|
54 |
+
target_width, target_height = target_resolution
|
55 |
+
|
56 |
+
scale_w = target_width / original_width
|
57 |
+
scale_h = target_height / original_height
|
58 |
+
|
59 |
+
if scale_w < scale_h:
|
60 |
+
new_width = target_width
|
61 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
62 |
+
else:
|
63 |
+
new_height = target_height
|
64 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
65 |
+
|
66 |
+
# Resize the image
|
67 |
+
resized_image = image.resize((new_width, new_height))
|
68 |
+
|
69 |
+
new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
|
70 |
+
paste_x = (target_width - new_width) // 2
|
71 |
+
paste_y = (target_height - new_height) // 2
|
72 |
+
new_image.paste(resized_image, (paste_x, paste_y))
|
73 |
+
|
74 |
+
return new_image
|
75 |
+
|
76 |
+
|
77 |
+
def divide_to_patches(image, patch_size):
|
78 |
+
"""
|
79 |
+
Divides an image into patches of a specified size.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
image (PIL.Image.Image): The input image.
|
83 |
+
patch_size (int): The size of each patch.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
87 |
+
"""
|
88 |
+
patches = []
|
89 |
+
width, height = image.size
|
90 |
+
for i in range(0, height, patch_size):
|
91 |
+
for j in range(0, width, patch_size):
|
92 |
+
box = (j, i, j + patch_size, i + patch_size)
|
93 |
+
patch = image.crop(box)
|
94 |
+
patches.append(patch)
|
95 |
+
|
96 |
+
return patches
|
97 |
+
|
98 |
+
|
99 |
+
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
100 |
+
"""
|
101 |
+
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
image_size (tuple): The size of the input image in the format (width, height).
|
105 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
106 |
+
patch_size (int): The size of each image patch.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
tuple: The shape of the image patch grid in the format (width, height).
|
110 |
+
"""
|
111 |
+
if type(grid_pinpoints) is list:
|
112 |
+
possible_resolutions = grid_pinpoints
|
113 |
+
else:
|
114 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
115 |
+
width, height = select_best_resolution(image_size, possible_resolutions)
|
116 |
+
return width // patch_size, height // patch_size
|
117 |
+
|
118 |
+
|
119 |
+
def process_anyres_image(image, processor, grid_pinpoints):
|
120 |
+
"""
|
121 |
+
Process an image with variable resolutions.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
image (PIL.Image.Image): The input image to be processed.
|
125 |
+
processor: The image processor object.
|
126 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
torch.Tensor: A tensor containing the processed image patches.
|
130 |
+
"""
|
131 |
+
if type(grid_pinpoints) is list:
|
132 |
+
possible_resolutions = grid_pinpoints
|
133 |
+
else:
|
134 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
135 |
+
best_resolution = select_best_resolution(image.size, possible_resolutions)
|
136 |
+
image_padded = resize_and_pad_image(image, best_resolution)
|
137 |
+
|
138 |
+
patches = divide_to_patches(image_padded, processor.crop_size['height'])
|
139 |
+
|
140 |
+
image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
|
141 |
+
|
142 |
+
image_patches = [image_original_resize] + patches
|
143 |
+
image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
|
144 |
+
for image_patch in image_patches]
|
145 |
+
return torch.stack(image_patches, dim=0)
|
146 |
+
|
147 |
+
|
148 |
+
def load_image_from_base64(image):
|
149 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
150 |
+
|
151 |
+
|
152 |
+
def expand2square(pil_img, background_color):
|
153 |
+
width, height = pil_img.size
|
154 |
+
if width == height:
|
155 |
+
return pil_img
|
156 |
+
elif width > height:
|
157 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
158 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
159 |
+
return result
|
160 |
+
else:
|
161 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
162 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
163 |
+
return result
|
164 |
+
|
165 |
+
|
166 |
+
def process_images(images, image_processor, model_cfg):
|
167 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
168 |
+
new_images = []
|
169 |
+
if image_aspect_ratio == 'pad':
|
170 |
+
for image in images:
|
171 |
+
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
|
172 |
+
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
173 |
+
new_images.append(image)
|
174 |
+
elif image_aspect_ratio == "anyres":
|
175 |
+
for image in images:
|
176 |
+
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
|
177 |
+
new_images.append(image)
|
178 |
+
else:
|
179 |
+
return image_processor(images, return_tensors='pt')['pixel_values']
|
180 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
181 |
+
new_images = torch.stack(new_images, dim=0)
|
182 |
+
return new_images
|
183 |
+
|
184 |
+
|
185 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
186 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
187 |
+
|
188 |
+
def insert_separator(X, sep):
|
189 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
190 |
+
|
191 |
+
input_ids = []
|
192 |
+
offset = 0
|
193 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
194 |
+
offset = 1
|
195 |
+
input_ids.append(prompt_chunks[0][0])
|
196 |
+
|
197 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
198 |
+
input_ids.extend(x[offset:])
|
199 |
+
|
200 |
+
if return_tensors is not None:
|
201 |
+
if return_tensors == 'pt':
|
202 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
203 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
204 |
+
return input_ids
|
205 |
+
|
206 |
+
|
207 |
+
def get_model_name_from_path(model_path):
|
208 |
+
model_path = model_path.strip("/")
|
209 |
+
model_paths = model_path.split("/")
|
210 |
+
if model_paths[-1].startswith('checkpoint-'):
|
211 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
212 |
+
else:
|
213 |
+
return model_paths[-1]
|
214 |
+
|
215 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
216 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
217 |
+
self.keywords = keywords
|
218 |
+
self.keyword_ids = []
|
219 |
+
self.max_keyword_len = 0
|
220 |
+
for keyword in keywords:
|
221 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
222 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
223 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
224 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
225 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
226 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
227 |
+
self.tokenizer = tokenizer
|
228 |
+
self.start_len = input_ids.shape[1]
|
229 |
+
|
230 |
+
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
231 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
232 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
233 |
+
for keyword_id in self.keyword_ids:
|
234 |
+
truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
|
235 |
+
if torch.equal(truncated_output_ids, keyword_id):
|
236 |
+
return True
|
237 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
238 |
+
for keyword in self.keywords:
|
239 |
+
if keyword in outputs:
|
240 |
+
return True
|
241 |
+
return False
|
242 |
+
|
243 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
244 |
+
outputs = []
|
245 |
+
for i in range(output_ids.shape[0]):
|
246 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
247 |
+
return all(outputs)
|
dialoggen/llava/model/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
try:
|
2 |
+
from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
|
3 |
+
from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
|
4 |
+
from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
|
5 |
+
except:
|
6 |
+
pass
|
dialoggen/llava/model/apply_delta.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
10 |
+
from llava import LlavaLlamaForCausalLM
|
11 |
+
|
12 |
+
|
13 |
+
def apply_delta(base_model_path, target_model_path, delta_path):
|
14 |
+
print("Loading base model")
|
15 |
+
base = AutoModelForCausalLM.from_pretrained(
|
16 |
+
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
17 |
+
|
18 |
+
print("Loading delta")
|
19 |
+
delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
20 |
+
delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
|
21 |
+
|
22 |
+
print("Applying delta")
|
23 |
+
for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
|
24 |
+
if name not in base.state_dict():
|
25 |
+
assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
|
26 |
+
continue
|
27 |
+
if param.data.shape == base.state_dict()[name].shape:
|
28 |
+
param.data += base.state_dict()[name]
|
29 |
+
else:
|
30 |
+
assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
|
31 |
+
f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
|
32 |
+
bparam = base.state_dict()[name]
|
33 |
+
param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
|
34 |
+
|
35 |
+
print("Saving target model")
|
36 |
+
delta.save_pretrained(target_model_path)
|
37 |
+
delta_tokenizer.save_pretrained(target_model_path)
|
38 |
+
|
39 |
+
|
40 |
+
if __name__ == "__main__":
|
41 |
+
parser = argparse.ArgumentParser()
|
42 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
43 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
44 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
45 |
+
|
46 |
+
args = parser.parse_args()
|
47 |
+
|
48 |
+
apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
|
dialoggen/llava/model/builder.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import os
|
17 |
+
import warnings
|
18 |
+
import shutil
|
19 |
+
|
20 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
21 |
+
import torch
|
22 |
+
from llava.model import *
|
23 |
+
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
24 |
+
|
25 |
+
|
26 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, llava_type_model=True, **kwargs):
|
27 |
+
kwargs = {"device_map": device_map, **kwargs}
|
28 |
+
|
29 |
+
if device != "cuda":
|
30 |
+
kwargs['device_map'] = {"": device}
|
31 |
+
|
32 |
+
if load_8bit:
|
33 |
+
kwargs['load_in_8bit'] = True
|
34 |
+
elif load_4bit:
|
35 |
+
kwargs['load_in_4bit'] = True
|
36 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
37 |
+
load_in_4bit=True,
|
38 |
+
bnb_4bit_compute_dtype=torch.float16,
|
39 |
+
bnb_4bit_use_double_quant=True,
|
40 |
+
bnb_4bit_quant_type='nf4'
|
41 |
+
)
|
42 |
+
else:
|
43 |
+
kwargs['torch_dtype'] = torch.float16
|
44 |
+
|
45 |
+
if use_flash_attn:
|
46 |
+
kwargs['attn_implementation'] = 'flash_attention_2'
|
47 |
+
|
48 |
+
if 'llava' in model_name.lower():
|
49 |
+
# Load LLaVA model
|
50 |
+
if 'lora' in model_name.lower() and model_base is None:
|
51 |
+
warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
|
52 |
+
if 'lora' in model_name.lower() and model_base is not None:
|
53 |
+
from llava.model.language_model.llava_llama import LlavaConfig
|
54 |
+
lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
|
55 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
56 |
+
print('Loading LLaVA from base model...')
|
57 |
+
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
|
58 |
+
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
59 |
+
if model.lm_head.weight.shape[0] != token_num:
|
60 |
+
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
61 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
62 |
+
|
63 |
+
print('Loading additional LLaVA weights...')
|
64 |
+
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
65 |
+
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
66 |
+
else:
|
67 |
+
# this is probably from HF Hub
|
68 |
+
from huggingface_hub import hf_hub_download
|
69 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
70 |
+
cache_file = hf_hub_download(
|
71 |
+
repo_id=repo_id,
|
72 |
+
filename=filename,
|
73 |
+
subfolder=subfolder)
|
74 |
+
return torch.load(cache_file, map_location='cpu')
|
75 |
+
non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
|
76 |
+
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
|
77 |
+
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
78 |
+
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
79 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
80 |
+
|
81 |
+
from peft import PeftModel
|
82 |
+
print('Loading LoRA weights...')
|
83 |
+
model = PeftModel.from_pretrained(model, model_path)
|
84 |
+
print('Merging LoRA weights...')
|
85 |
+
model = model.merge_and_unload()
|
86 |
+
print('Model is loaded...')
|
87 |
+
elif model_base is not None:
|
88 |
+
# this may be mm projector only
|
89 |
+
print('Loading LLaVA from base model...')
|
90 |
+
if 'mpt' in model_name.lower():
|
91 |
+
if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
|
92 |
+
shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
|
93 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
|
94 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
95 |
+
model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
96 |
+
else:
|
97 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
98 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
99 |
+
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
100 |
+
|
101 |
+
mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
102 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
103 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
104 |
+
else:
|
105 |
+
if 'mpt' in model_name.lower():
|
106 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
107 |
+
model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
108 |
+
elif 'mistral' in model_name.lower():
|
109 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
110 |
+
model = LlavaMistralForCausalLM.from_pretrained(
|
111 |
+
model_path,
|
112 |
+
low_cpu_mem_usage=True,
|
113 |
+
**kwargs
|
114 |
+
)
|
115 |
+
else:
|
116 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
117 |
+
model = LlavaLlamaForCausalLM.from_pretrained(
|
118 |
+
model_path,
|
119 |
+
low_cpu_mem_usage=True,
|
120 |
+
**kwargs
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
# Load language model
|
124 |
+
if model_base is not None:
|
125 |
+
# PEFT model
|
126 |
+
from peft import PeftModel
|
127 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
128 |
+
model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
|
129 |
+
print(f"Loading LoRA weights from {model_path}")
|
130 |
+
model = PeftModel.from_pretrained(model, model_path)
|
131 |
+
print(f"Merging weights")
|
132 |
+
model = model.merge_and_unload()
|
133 |
+
print('Convert to FP16...')
|
134 |
+
model.to(torch.float16)
|
135 |
+
else:
|
136 |
+
use_fast = False
|
137 |
+
if 'mpt' in model_name.lower():
|
138 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
139 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
|
140 |
+
else:
|
141 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
142 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
143 |
+
|
144 |
+
image_processor = None
|
145 |
+
|
146 |
+
if llava_type_model:
|
147 |
+
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
148 |
+
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
149 |
+
if mm_use_im_patch_token:
|
150 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
151 |
+
if mm_use_im_start_end:
|
152 |
+
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
153 |
+
model.resize_token_embeddings(len(tokenizer))
|
154 |
+
|
155 |
+
vision_tower = model.get_vision_tower()
|
156 |
+
if not vision_tower.is_loaded:
|
157 |
+
vision_tower.load_model(device_map=device_map)
|
158 |
+
if device_map != 'auto':
|
159 |
+
vision_tower.to(device=device_map, dtype=torch.float16)
|
160 |
+
image_processor = vision_tower.image_processor
|
161 |
+
|
162 |
+
if hasattr(model.config, "max_sequence_length"):
|
163 |
+
context_len = model.config.max_sequence_length
|
164 |
+
else:
|
165 |
+
context_len = 2048
|
166 |
+
|
167 |
+
return tokenizer, model, image_processor, context_len
|
dialoggen/llava/model/consolidate.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
9 |
+
from llava.model import *
|
10 |
+
from llava.model.utils import auto_upgrade
|
11 |
+
|
12 |
+
|
13 |
+
def consolidate_ckpt(src_path, dst_path):
|
14 |
+
print("Loading model")
|
15 |
+
auto_upgrade(src_path)
|
16 |
+
src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
17 |
+
src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
|
18 |
+
src_model.save_pretrained(dst_path)
|
19 |
+
src_tokenizer.save_pretrained(dst_path)
|
20 |
+
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
parser = argparse.ArgumentParser()
|
24 |
+
parser.add_argument("--src", type=str, required=True)
|
25 |
+
parser.add_argument("--dst", type=str, required=True)
|
26 |
+
|
27 |
+
args = parser.parse_args()
|
28 |
+
|
29 |
+
consolidate_ckpt(args.src, args.dst)
|
dialoggen/llava/model/language_model/llava_llama.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
from transformers import AutoConfig, AutoModelForCausalLM, \
|
22 |
+
LlamaConfig, LlamaModel, LlamaForCausalLM
|
23 |
+
|
24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
25 |
+
from transformers.generation.utils import GenerateOutput
|
26 |
+
|
27 |
+
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
28 |
+
|
29 |
+
|
30 |
+
class LlavaConfig(LlamaConfig):
|
31 |
+
model_type = "llava_llama"
|
32 |
+
|
33 |
+
|
34 |
+
class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
|
35 |
+
config_class = LlavaConfig
|
36 |
+
|
37 |
+
def __init__(self, config: LlamaConfig):
|
38 |
+
super(LlavaLlamaModel, self).__init__(config)
|
39 |
+
|
40 |
+
|
41 |
+
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
|
42 |
+
config_class = LlavaConfig
|
43 |
+
|
44 |
+
def __init__(self, config):
|
45 |
+
super(LlamaForCausalLM, self).__init__(config)
|
46 |
+
self.model = LlavaLlamaModel(config)
|
47 |
+
self.pretraining_tp = config.pretraining_tp
|
48 |
+
self.vocab_size = config.vocab_size
|
49 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
50 |
+
|
51 |
+
# Initialize weights and apply final processing
|
52 |
+
self.post_init()
|
53 |
+
|
54 |
+
def get_model(self):
|
55 |
+
return self.model
|
56 |
+
|
57 |
+
def forward(
|
58 |
+
self,
|
59 |
+
input_ids: torch.LongTensor = None,
|
60 |
+
attention_mask: Optional[torch.Tensor] = None,
|
61 |
+
position_ids: Optional[torch.LongTensor] = None,
|
62 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
63 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
64 |
+
labels: Optional[torch.LongTensor] = None,
|
65 |
+
use_cache: Optional[bool] = None,
|
66 |
+
output_attentions: Optional[bool] = None,
|
67 |
+
output_hidden_states: Optional[bool] = None,
|
68 |
+
images: Optional[torch.FloatTensor] = None,
|
69 |
+
image_sizes: Optional[List[List[int]]] = None,
|
70 |
+
return_dict: Optional[bool] = None,
|
71 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
72 |
+
|
73 |
+
if inputs_embeds is None:
|
74 |
+
(
|
75 |
+
input_ids,
|
76 |
+
position_ids,
|
77 |
+
attention_mask,
|
78 |
+
past_key_values,
|
79 |
+
inputs_embeds,
|
80 |
+
labels
|
81 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
82 |
+
input_ids,
|
83 |
+
position_ids,
|
84 |
+
attention_mask,
|
85 |
+
past_key_values,
|
86 |
+
labels,
|
87 |
+
images,
|
88 |
+
image_sizes
|
89 |
+
)
|
90 |
+
|
91 |
+
return super().forward(
|
92 |
+
input_ids=input_ids,
|
93 |
+
attention_mask=attention_mask,
|
94 |
+
position_ids=position_ids,
|
95 |
+
past_key_values=past_key_values,
|
96 |
+
inputs_embeds=inputs_embeds,
|
97 |
+
labels=labels,
|
98 |
+
use_cache=use_cache,
|
99 |
+
output_attentions=output_attentions,
|
100 |
+
output_hidden_states=output_hidden_states,
|
101 |
+
return_dict=return_dict
|
102 |
+
)
|
103 |
+
|
104 |
+
@torch.no_grad()
|
105 |
+
def generate(
|
106 |
+
self,
|
107 |
+
inputs: Optional[torch.Tensor] = None,
|
108 |
+
images: Optional[torch.Tensor] = None,
|
109 |
+
image_sizes: Optional[torch.Tensor] = None,
|
110 |
+
**kwargs,
|
111 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
112 |
+
position_ids = kwargs.pop("position_ids", None)
|
113 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
114 |
+
if "inputs_embeds" in kwargs:
|
115 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
116 |
+
|
117 |
+
if images is not None:
|
118 |
+
(
|
119 |
+
inputs,
|
120 |
+
position_ids,
|
121 |
+
attention_mask,
|
122 |
+
_,
|
123 |
+
inputs_embeds,
|
124 |
+
_
|
125 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
126 |
+
inputs,
|
127 |
+
position_ids,
|
128 |
+
attention_mask,
|
129 |
+
None,
|
130 |
+
None,
|
131 |
+
images,
|
132 |
+
image_sizes=image_sizes
|
133 |
+
)
|
134 |
+
else:
|
135 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
136 |
+
|
137 |
+
return super().generate(
|
138 |
+
position_ids=position_ids,
|
139 |
+
attention_mask=attention_mask,
|
140 |
+
inputs_embeds=inputs_embeds,
|
141 |
+
**kwargs
|
142 |
+
)
|
143 |
+
|
144 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
|
145 |
+
inputs_embeds=None, **kwargs):
|
146 |
+
images = kwargs.pop("images", None)
|
147 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
148 |
+
inputs = super().prepare_inputs_for_generation(
|
149 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
150 |
+
)
|
151 |
+
if images is not None:
|
152 |
+
inputs['images'] = images
|
153 |
+
if image_sizes is not None:
|
154 |
+
inputs['image_sizes'] = image_sizes
|
155 |
+
return inputs
|
156 |
+
|
157 |
+
AutoConfig.register("llava_llama", LlavaConfig)
|
158 |
+
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
|
dialoggen/llava/model/language_model/llava_mistral.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from torch.nn import CrossEntropyLoss
|
21 |
+
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM, \
|
23 |
+
MistralConfig, MistralModel, MistralForCausalLM
|
24 |
+
|
25 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
26 |
+
from transformers.generation.utils import GenerateOutput
|
27 |
+
|
28 |
+
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
29 |
+
|
30 |
+
|
31 |
+
class LlavaMistralConfig(MistralConfig):
|
32 |
+
model_type = "llava_mistral"
|
33 |
+
|
34 |
+
|
35 |
+
class LlavaMistralModel(LlavaMetaModel, MistralModel):
|
36 |
+
config_class = LlavaMistralConfig
|
37 |
+
|
38 |
+
def __init__(self, config: MistralConfig):
|
39 |
+
super(LlavaMistralModel, self).__init__(config)
|
40 |
+
|
41 |
+
|
42 |
+
class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
|
43 |
+
config_class = LlavaMistralConfig
|
44 |
+
|
45 |
+
def __init__(self, config):
|
46 |
+
super(MistralForCausalLM, self).__init__(config)
|
47 |
+
self.model = LlavaMistralModel(config)
|
48 |
+
|
49 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
50 |
+
|
51 |
+
# Initialize weights and apply final processing
|
52 |
+
self.post_init()
|
53 |
+
|
54 |
+
def get_model(self):
|
55 |
+
return self.model
|
56 |
+
|
57 |
+
def forward(
|
58 |
+
self,
|
59 |
+
input_ids: torch.LongTensor = None,
|
60 |
+
attention_mask: Optional[torch.Tensor] = None,
|
61 |
+
position_ids: Optional[torch.LongTensor] = None,
|
62 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
63 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
64 |
+
labels: Optional[torch.LongTensor] = None,
|
65 |
+
use_cache: Optional[bool] = None,
|
66 |
+
output_attentions: Optional[bool] = None,
|
67 |
+
output_hidden_states: Optional[bool] = None,
|
68 |
+
images: Optional[torch.FloatTensor] = None,
|
69 |
+
image_sizes: Optional[List[List[int]]] = None,
|
70 |
+
return_dict: Optional[bool] = None,
|
71 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
72 |
+
|
73 |
+
if inputs_embeds is None:
|
74 |
+
(
|
75 |
+
input_ids,
|
76 |
+
position_ids,
|
77 |
+
attention_mask,
|
78 |
+
past_key_values,
|
79 |
+
inputs_embeds,
|
80 |
+
labels
|
81 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
82 |
+
input_ids,
|
83 |
+
position_ids,
|
84 |
+
attention_mask,
|
85 |
+
past_key_values,
|
86 |
+
labels,
|
87 |
+
images,
|
88 |
+
image_sizes
|
89 |
+
)
|
90 |
+
|
91 |
+
return super().forward(
|
92 |
+
input_ids=input_ids,
|
93 |
+
attention_mask=attention_mask,
|
94 |
+
position_ids=position_ids,
|
95 |
+
past_key_values=past_key_values,
|
96 |
+
inputs_embeds=inputs_embeds,
|
97 |
+
labels=labels,
|
98 |
+
use_cache=use_cache,
|
99 |
+
output_attentions=output_attentions,
|
100 |
+
output_hidden_states=output_hidden_states,
|
101 |
+
return_dict=return_dict
|
102 |
+
)
|
103 |
+
|
104 |
+
@torch.no_grad()
|
105 |
+
def generate(
|
106 |
+
self,
|
107 |
+
inputs: Optional[torch.Tensor] = None,
|
108 |
+
images: Optional[torch.Tensor] = None,
|
109 |
+
image_sizes: Optional[torch.Tensor] = None,
|
110 |
+
**kwargs,
|
111 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
112 |
+
position_ids = kwargs.pop("position_ids", None)
|
113 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
114 |
+
if "inputs_embeds" in kwargs:
|
115 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
116 |
+
|
117 |
+
if images is not None:
|
118 |
+
(
|
119 |
+
inputs,
|
120 |
+
position_ids,
|
121 |
+
attention_mask,
|
122 |
+
_,
|
123 |
+
inputs_embeds,
|
124 |
+
_
|
125 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
126 |
+
inputs,
|
127 |
+
position_ids,
|
128 |
+
attention_mask,
|
129 |
+
None,
|
130 |
+
None,
|
131 |
+
images,
|
132 |
+
image_sizes=image_sizes
|
133 |
+
)
|
134 |
+
else:
|
135 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
136 |
+
|
137 |
+
return super().generate(
|
138 |
+
position_ids=position_ids,
|
139 |
+
attention_mask=attention_mask,
|
140 |
+
inputs_embeds=inputs_embeds,
|
141 |
+
**kwargs
|
142 |
+
)
|
143 |
+
|
144 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
|
145 |
+
inputs_embeds=None, **kwargs):
|
146 |
+
images = kwargs.pop("images", None)
|
147 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
148 |
+
inputs = super().prepare_inputs_for_generation(
|
149 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
150 |
+
)
|
151 |
+
if images is not None:
|
152 |
+
inputs['images'] = images
|
153 |
+
if image_sizes is not None:
|
154 |
+
inputs['image_sizes'] = image_sizes
|
155 |
+
return inputs
|
156 |
+
|
157 |
+
AutoConfig.register("llava_mistral", LlavaMistralConfig)
|
158 |
+
AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
|
dialoggen/llava/model/language_model/llava_mpt.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import Optional, Tuple
|
17 |
+
|
18 |
+
import torch
|
19 |
+
|
20 |
+
from transformers import AutoConfig, AutoModelForCausalLM, \
|
21 |
+
MptConfig, MptForCausalLM, MptModel
|
22 |
+
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
23 |
+
|
24 |
+
|
25 |
+
class LlavaMptConfig(MptConfig):
|
26 |
+
model_type = "llava_mpt"
|
27 |
+
|
28 |
+
|
29 |
+
class LlavaMptModel(LlavaMetaModel, MptModel):
|
30 |
+
config_class = LlavaMptConfig
|
31 |
+
|
32 |
+
def __init__(self, config: MptConfig):
|
33 |
+
config.hidden_size = config.d_model
|
34 |
+
super(LlavaMptModel, self).__init__(config)
|
35 |
+
|
36 |
+
def embed_tokens(self, x):
|
37 |
+
return self.wte(x)
|
38 |
+
|
39 |
+
|
40 |
+
class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
|
41 |
+
config_class = LlavaMptConfig
|
42 |
+
supports_gradient_checkpointing = True
|
43 |
+
|
44 |
+
def __init__(self, config):
|
45 |
+
super(MptForCausalLM, self).__init__(config)
|
46 |
+
|
47 |
+
self.transformer = LlavaMptModel(config)
|
48 |
+
self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
49 |
+
|
50 |
+
# Initialize weights and apply final processing
|
51 |
+
self.post_init()
|
52 |
+
|
53 |
+
def get_model(self):
|
54 |
+
return self.transformer
|
55 |
+
|
56 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
57 |
+
if isinstance(module, LlavaMptModel):
|
58 |
+
module.gradient_checkpointing = value
|
59 |
+
|
60 |
+
def forward(
|
61 |
+
self,
|
62 |
+
input_ids: Optional[torch.LongTensor] = None,
|
63 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
64 |
+
attention_mask: Optional[torch.Tensor] = None,
|
65 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
66 |
+
labels: Optional[torch.Tensor] = None,
|
67 |
+
use_cache: Optional[bool] = None,
|
68 |
+
output_attentions: Optional[bool] = None,
|
69 |
+
output_hidden_states: Optional[bool] = None,
|
70 |
+
return_dict: Optional[bool] = None,
|
71 |
+
images=None):
|
72 |
+
|
73 |
+
input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
|
74 |
+
|
75 |
+
return super().forward(
|
76 |
+
input_ids,
|
77 |
+
past_key_values=past_key_values,
|
78 |
+
attention_mask=attention_mask,
|
79 |
+
inputs_embeds=inputs_embeds,
|
80 |
+
labels=labels,
|
81 |
+
use_cache=use_cache,
|
82 |
+
output_attentions=output_attentions,
|
83 |
+
output_hidden_states=output_hidden_states,
|
84 |
+
return_dict=return_dict,
|
85 |
+
)
|
86 |
+
|
87 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
88 |
+
images = kwargs.pop("images", None)
|
89 |
+
_inputs = super().prepare_inputs_for_generation(
|
90 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
91 |
+
)
|
92 |
+
_inputs['images'] = images
|
93 |
+
return _inputs
|
94 |
+
|
95 |
+
|
96 |
+
AutoConfig.register("llava_mpt", LlavaMptConfig)
|
97 |
+
AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
|
dialoggen/llava/model/llava_arch.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from abc import ABC, abstractmethod
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
from .multimodal_encoder.builder import build_vision_tower
|
22 |
+
from .multimodal_projector.builder import build_vision_projector
|
23 |
+
|
24 |
+
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
25 |
+
|
26 |
+
from llava.mm_utils import get_anyres_image_grid_shape
|
27 |
+
|
28 |
+
|
29 |
+
class LlavaMetaModel:
|
30 |
+
|
31 |
+
def __init__(self, config):
|
32 |
+
super(LlavaMetaModel, self).__init__(config)
|
33 |
+
|
34 |
+
if hasattr(config, "mm_vision_tower"):
|
35 |
+
self.vision_tower = build_vision_tower(config, delay_load=True)
|
36 |
+
self.mm_projector = build_vision_projector(config)
|
37 |
+
|
38 |
+
if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
|
39 |
+
self.image_newline = nn.Parameter(
|
40 |
+
torch.empty(config.hidden_size, dtype=self.dtype)
|
41 |
+
)
|
42 |
+
|
43 |
+
def get_vision_tower(self):
|
44 |
+
vision_tower = getattr(self, 'vision_tower', None)
|
45 |
+
if type(vision_tower) is list:
|
46 |
+
vision_tower = vision_tower[0]
|
47 |
+
return vision_tower
|
48 |
+
|
49 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
50 |
+
vision_tower = model_args.vision_tower
|
51 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
52 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
53 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
54 |
+
mm_patch_merge_type = model_args.mm_patch_merge_type
|
55 |
+
|
56 |
+
self.config.mm_vision_tower = vision_tower
|
57 |
+
|
58 |
+
if self.get_vision_tower() is None:
|
59 |
+
vision_tower = build_vision_tower(model_args)
|
60 |
+
|
61 |
+
if fsdp is not None and len(fsdp) > 0:
|
62 |
+
self.vision_tower = [vision_tower]
|
63 |
+
else:
|
64 |
+
self.vision_tower = vision_tower
|
65 |
+
else:
|
66 |
+
if fsdp is not None and len(fsdp) > 0:
|
67 |
+
vision_tower = self.vision_tower[0]
|
68 |
+
else:
|
69 |
+
vision_tower = self.vision_tower
|
70 |
+
vision_tower.load_model()
|
71 |
+
|
72 |
+
self.config.use_mm_proj = True
|
73 |
+
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
|
74 |
+
self.config.mm_hidden_size = vision_tower.hidden_size
|
75 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
76 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
77 |
+
self.config.mm_patch_merge_type = mm_patch_merge_type
|
78 |
+
|
79 |
+
if getattr(self, 'mm_projector', None) is None:
|
80 |
+
self.mm_projector = build_vision_projector(self.config)
|
81 |
+
|
82 |
+
if 'unpad' in mm_patch_merge_type:
|
83 |
+
embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
|
84 |
+
self.image_newline = nn.Parameter(
|
85 |
+
torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
|
86 |
+
)
|
87 |
+
else:
|
88 |
+
# In case it is frozen by LoRA
|
89 |
+
for p in self.mm_projector.parameters():
|
90 |
+
p.requires_grad = True
|
91 |
+
|
92 |
+
if pretrain_mm_mlp_adapter is not None:
|
93 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
94 |
+
def get_w(weights, keyword):
|
95 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
96 |
+
|
97 |
+
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
|
98 |
+
|
99 |
+
|
100 |
+
def unpad_image(tensor, original_size):
|
101 |
+
"""
|
102 |
+
Unpads a PyTorch tensor of a padded and resized image.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
|
106 |
+
original_size (tuple): The original size of the image (height, width).
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
torch.Tensor: The unpadded image tensor.
|
110 |
+
"""
|
111 |
+
original_width, original_height = original_size
|
112 |
+
current_height, current_width = tensor.shape[1:]
|
113 |
+
|
114 |
+
original_aspect_ratio = original_width / original_height
|
115 |
+
current_aspect_ratio = current_width / current_height
|
116 |
+
|
117 |
+
if original_aspect_ratio > current_aspect_ratio:
|
118 |
+
scale_factor = current_width / original_width
|
119 |
+
new_height = int(original_height * scale_factor)
|
120 |
+
padding = (current_height - new_height) // 2
|
121 |
+
unpadded_tensor = tensor[:, padding:current_height - padding, :]
|
122 |
+
else:
|
123 |
+
scale_factor = current_height / original_height
|
124 |
+
new_width = int(original_width * scale_factor)
|
125 |
+
padding = (current_width - new_width) // 2
|
126 |
+
unpadded_tensor = tensor[:, :, padding:current_width - padding]
|
127 |
+
|
128 |
+
return unpadded_tensor
|
129 |
+
|
130 |
+
|
131 |
+
class LlavaMetaForCausalLM(ABC):
|
132 |
+
|
133 |
+
@abstractmethod
|
134 |
+
def get_model(self):
|
135 |
+
pass
|
136 |
+
|
137 |
+
def get_vision_tower(self):
|
138 |
+
return self.get_model().get_vision_tower()
|
139 |
+
|
140 |
+
def encode_images(self, images):
|
141 |
+
image_features = self.get_model().get_vision_tower()(images)
|
142 |
+
image_features = self.get_model().mm_projector(image_features)
|
143 |
+
return image_features
|
144 |
+
|
145 |
+
def prepare_inputs_labels_for_multimodal(
|
146 |
+
self, input_ids, position_ids, attention_mask, past_key_values, labels,
|
147 |
+
images, image_sizes=None
|
148 |
+
):
|
149 |
+
vision_tower = self.get_vision_tower()
|
150 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
151 |
+
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
152 |
+
|
153 |
+
if type(images) is list or images.ndim == 5:
|
154 |
+
if type(images) is list:
|
155 |
+
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
|
156 |
+
concat_images = torch.cat([image for image in images], dim=0)
|
157 |
+
image_features = self.encode_images(concat_images)
|
158 |
+
split_sizes = [image.shape[0] for image in images]
|
159 |
+
image_features = torch.split(image_features, split_sizes, dim=0)
|
160 |
+
mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
|
161 |
+
image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
|
162 |
+
if mm_patch_merge_type == 'flat':
|
163 |
+
image_features = [x.flatten(0, 1) for x in image_features]
|
164 |
+
elif mm_patch_merge_type.startswith('spatial'):
|
165 |
+
new_image_features = []
|
166 |
+
for image_idx, image_feature in enumerate(image_features):
|
167 |
+
if image_feature.shape[0] > 1:
|
168 |
+
base_image_feature = image_feature[0]
|
169 |
+
image_feature = image_feature[1:]
|
170 |
+
height = width = self.get_vision_tower().num_patches_per_side
|
171 |
+
assert height * width == base_image_feature.shape[0]
|
172 |
+
if image_aspect_ratio == 'anyres':
|
173 |
+
num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
|
174 |
+
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
|
175 |
+
else:
|
176 |
+
raise NotImplementedError
|
177 |
+
if 'unpad' in mm_patch_merge_type:
|
178 |
+
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
179 |
+
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
180 |
+
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
181 |
+
image_feature = torch.cat((
|
182 |
+
image_feature,
|
183 |
+
self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
|
184 |
+
), dim=-1)
|
185 |
+
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
186 |
+
else:
|
187 |
+
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
|
188 |
+
image_feature = image_feature.flatten(0, 3)
|
189 |
+
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
|
190 |
+
else:
|
191 |
+
image_feature = image_feature[0]
|
192 |
+
if 'unpad' in mm_patch_merge_type:
|
193 |
+
image_feature = torch.cat((
|
194 |
+
image_feature,
|
195 |
+
self.model.image_newline[None].to(image_feature.device)
|
196 |
+
), dim=0)
|
197 |
+
new_image_features.append(image_feature)
|
198 |
+
image_features = new_image_features
|
199 |
+
else:
|
200 |
+
raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
|
201 |
+
else:
|
202 |
+
image_features = self.encode_images(images)
|
203 |
+
|
204 |
+
# TODO: image start / end is not implemented here to support pretraining.
|
205 |
+
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
|
206 |
+
raise NotImplementedError
|
207 |
+
|
208 |
+
# Let's just add dummy tensors if they do not exist,
|
209 |
+
# it is a headache to deal with None all the time.
|
210 |
+
# But it is not ideal, and if you have a better idea,
|
211 |
+
# please open an issue / submit a PR, thanks.
|
212 |
+
_labels = labels
|
213 |
+
_position_ids = position_ids
|
214 |
+
_attention_mask = attention_mask
|
215 |
+
if attention_mask is None:
|
216 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
217 |
+
else:
|
218 |
+
attention_mask = attention_mask.bool()
|
219 |
+
if position_ids is None:
|
220 |
+
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
221 |
+
if labels is None:
|
222 |
+
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
223 |
+
|
224 |
+
# remove the padding using attention_mask -- FIXME
|
225 |
+
_input_ids = input_ids
|
226 |
+
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
|
227 |
+
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
228 |
+
|
229 |
+
new_input_embeds = []
|
230 |
+
new_labels = []
|
231 |
+
cur_image_idx = 0
|
232 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
233 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
234 |
+
if num_images == 0:
|
235 |
+
cur_image_features = image_features[cur_image_idx]
|
236 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
237 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
|
238 |
+
new_input_embeds.append(cur_input_embeds)
|
239 |
+
new_labels.append(labels[batch_idx])
|
240 |
+
cur_image_idx += 1
|
241 |
+
continue
|
242 |
+
|
243 |
+
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
|
244 |
+
cur_input_ids_noim = []
|
245 |
+
cur_labels = labels[batch_idx]
|
246 |
+
cur_labels_noim = []
|
247 |
+
for i in range(len(image_token_indices) - 1):
|
248 |
+
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
|
249 |
+
cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
|
250 |
+
split_sizes = [x.shape[0] for x in cur_labels_noim]
|
251 |
+
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
|
252 |
+
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
253 |
+
cur_new_input_embeds = []
|
254 |
+
cur_new_labels = []
|
255 |
+
|
256 |
+
for i in range(num_images + 1):
|
257 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
258 |
+
cur_new_labels.append(cur_labels_noim[i])
|
259 |
+
if i < num_images:
|
260 |
+
cur_image_features = image_features[cur_image_idx]
|
261 |
+
cur_image_idx += 1
|
262 |
+
cur_new_input_embeds.append(cur_image_features)
|
263 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
|
264 |
+
|
265 |
+
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
|
266 |
+
|
267 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
268 |
+
cur_new_labels = torch.cat(cur_new_labels)
|
269 |
+
|
270 |
+
new_input_embeds.append(cur_new_input_embeds)
|
271 |
+
new_labels.append(cur_new_labels)
|
272 |
+
|
273 |
+
# Truncate sequences to max length as image embeddings can make the sequence longer
|
274 |
+
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
|
275 |
+
if tokenizer_model_max_length is not None:
|
276 |
+
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
|
277 |
+
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
|
278 |
+
|
279 |
+
# Combine them
|
280 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
281 |
+
batch_size = len(new_input_embeds)
|
282 |
+
|
283 |
+
new_input_embeds_padded = []
|
284 |
+
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
|
285 |
+
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
|
286 |
+
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
|
287 |
+
|
288 |
+
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
289 |
+
cur_len = cur_new_embed.shape[0]
|
290 |
+
if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
|
291 |
+
new_input_embeds_padded.append(torch.cat((
|
292 |
+
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
|
293 |
+
cur_new_embed
|
294 |
+
), dim=0))
|
295 |
+
if cur_len > 0:
|
296 |
+
new_labels_padded[i, -cur_len:] = cur_new_labels
|
297 |
+
attention_mask[i, -cur_len:] = True
|
298 |
+
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
299 |
+
else:
|
300 |
+
new_input_embeds_padded.append(torch.cat((
|
301 |
+
cur_new_embed,
|
302 |
+
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
|
303 |
+
), dim=0))
|
304 |
+
if cur_len > 0:
|
305 |
+
new_labels_padded[i, :cur_len] = cur_new_labels
|
306 |
+
attention_mask[i, :cur_len] = True
|
307 |
+
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
308 |
+
|
309 |
+
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
310 |
+
|
311 |
+
if _labels is None:
|
312 |
+
new_labels = None
|
313 |
+
else:
|
314 |
+
new_labels = new_labels_padded
|
315 |
+
|
316 |
+
if _attention_mask is None:
|
317 |
+
attention_mask = None
|
318 |
+
else:
|
319 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
320 |
+
|
321 |
+
if _position_ids is None:
|
322 |
+
position_ids = None
|
323 |
+
|
324 |
+
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
325 |
+
|
326 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
327 |
+
if model_args.mm_use_im_patch_token:
|
328 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
329 |
+
self.resize_token_embeddings(len(tokenizer))
|
330 |
+
|
331 |
+
if model_args.mm_use_im_start_end:
|
332 |
+
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
333 |
+
self.resize_token_embeddings(len(tokenizer))
|
334 |
+
|
335 |
+
if num_new_tokens > 0:
|
336 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
337 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
338 |
+
|
339 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
340 |
+
dim=0, keepdim=True)
|
341 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
342 |
+
dim=0, keepdim=True)
|
343 |
+
|
344 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
345 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
346 |
+
|
347 |
+
if model_args.tune_mm_mlp_adapter:
|
348 |
+
for p in self.get_input_embeddings().parameters():
|
349 |
+
p.requires_grad = True
|
350 |
+
for p in self.get_output_embeddings().parameters():
|
351 |
+
p.requires_grad = False
|
352 |
+
|
353 |
+
if model_args.pretrain_mm_mlp_adapter:
|
354 |
+
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
|
355 |
+
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
|
356 |
+
assert num_new_tokens == 2
|
357 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
358 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
359 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
360 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
361 |
+
else:
|
362 |
+
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
|
363 |
+
elif model_args.mm_use_im_patch_token:
|
364 |
+
if model_args.tune_mm_mlp_adapter:
|
365 |
+
for p in self.get_input_embeddings().parameters():
|
366 |
+
p.requires_grad = False
|
367 |
+
for p in self.get_output_embeddings().parameters():
|
368 |
+
p.requires_grad = False
|
dialoggen/llava/model/make_delta.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
10 |
+
from llava.model.utils import auto_upgrade
|
11 |
+
|
12 |
+
|
13 |
+
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
|
14 |
+
print("Loading base model")
|
15 |
+
base = AutoModelForCausalLM.from_pretrained(
|
16 |
+
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
17 |
+
|
18 |
+
print("Loading target model")
|
19 |
+
auto_upgrade(target_model_path)
|
20 |
+
target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
21 |
+
|
22 |
+
print("Calculating delta")
|
23 |
+
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
|
24 |
+
if name not in base.state_dict():
|
25 |
+
assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
|
26 |
+
continue
|
27 |
+
if param.data.shape == base.state_dict()[name].shape:
|
28 |
+
param.data -= base.state_dict()[name]
|
29 |
+
else:
|
30 |
+
assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
|
31 |
+
bparam = base.state_dict()[name]
|
32 |
+
param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
|
33 |
+
|
34 |
+
print("Saving delta")
|
35 |
+
if hub_repo_id:
|
36 |
+
kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
|
37 |
+
else:
|
38 |
+
kwargs = {}
|
39 |
+
target.save_pretrained(delta_path, **kwargs)
|
40 |
+
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
|
41 |
+
target_tokenizer.save_pretrained(delta_path, **kwargs)
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == "__main__":
|
45 |
+
parser = argparse.ArgumentParser()
|
46 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
47 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
48 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
49 |
+
parser.add_argument("--hub-repo-id", type=str, default=None)
|
50 |
+
args = parser.parse_args()
|
51 |
+
|
52 |
+
make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
|
dialoggen/llava/model/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from .clip_encoder import CLIPVisionTower
|
3 |
+
|
4 |
+
|
5 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
|
6 |
+
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
|
7 |
+
is_absolute_path_exists = os.path.exists(vision_tower)
|
8 |
+
if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
|
9 |
+
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
10 |
+
|
11 |
+
raise ValueError(f'Unknown vision tower: {vision_tower}')
|
dialoggen/llava/model/multimodal_encoder/clip_encoder.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
5 |
+
|
6 |
+
|
7 |
+
class CLIPVisionTower(nn.Module):
|
8 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
self.is_loaded = False
|
12 |
+
|
13 |
+
self.vision_tower_name = vision_tower
|
14 |
+
self.select_layer = args.mm_vision_select_layer
|
15 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
16 |
+
|
17 |
+
if not delay_load:
|
18 |
+
self.load_model()
|
19 |
+
elif getattr(args, 'unfreeze_mm_vision_tower', False):
|
20 |
+
self.load_model()
|
21 |
+
else:
|
22 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
|
23 |
+
|
24 |
+
def load_model(self, device_map=None):
|
25 |
+
if self.is_loaded:
|
26 |
+
print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
|
27 |
+
return
|
28 |
+
|
29 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
30 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
|
31 |
+
self.vision_tower.requires_grad_(False)
|
32 |
+
|
33 |
+
self.is_loaded = True
|
34 |
+
|
35 |
+
def feature_select(self, image_forward_outs):
|
36 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
37 |
+
if self.select_feature == 'patch':
|
38 |
+
image_features = image_features[:, 1:]
|
39 |
+
elif self.select_feature == 'cls_patch':
|
40 |
+
image_features = image_features
|
41 |
+
else:
|
42 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
43 |
+
return image_features
|
44 |
+
|
45 |
+
@torch.no_grad()
|
46 |
+
def forward(self, images):
|
47 |
+
if type(images) is list:
|
48 |
+
image_features = []
|
49 |
+
for image in images:
|
50 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
51 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
52 |
+
image_features.append(image_feature)
|
53 |
+
else:
|
54 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
55 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
56 |
+
|
57 |
+
return image_features
|
58 |
+
|
59 |
+
@property
|
60 |
+
def dummy_feature(self):
|
61 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
62 |
+
|
63 |
+
@property
|
64 |
+
def dtype(self):
|
65 |
+
return self.vision_tower.dtype
|
66 |
+
|
67 |
+
@property
|
68 |
+
def device(self):
|
69 |
+
return self.vision_tower.device
|
70 |
+
|
71 |
+
@property
|
72 |
+
def config(self):
|
73 |
+
if self.is_loaded:
|
74 |
+
return self.vision_tower.config
|
75 |
+
else:
|
76 |
+
return self.cfg_only
|
77 |
+
|
78 |
+
@property
|
79 |
+
def hidden_size(self):
|
80 |
+
return self.config.hidden_size
|
81 |
+
|
82 |
+
@property
|
83 |
+
def num_patches_per_side(self):
|
84 |
+
return self.config.image_size // self.config.patch_size
|
85 |
+
|
86 |
+
@property
|
87 |
+
def num_patches(self):
|
88 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
dialoggen/llava/model/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import re
|
4 |
+
|
5 |
+
|
6 |
+
class IdentityMap(nn.Module):
|
7 |
+
def __init__(self):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
def forward(self, x, *args, **kwargs):
|
11 |
+
return x
|
12 |
+
|
13 |
+
@property
|
14 |
+
def config(self):
|
15 |
+
return {"mm_projector_type": 'identity'}
|
16 |
+
|
17 |
+
|
18 |
+
class SimpleResBlock(nn.Module):
|
19 |
+
def __init__(self, channels):
|
20 |
+
super().__init__()
|
21 |
+
self.pre_norm = nn.LayerNorm(channels)
|
22 |
+
|
23 |
+
self.proj = nn.Sequential(
|
24 |
+
nn.Linear(channels, channels),
|
25 |
+
nn.GELU(),
|
26 |
+
nn.Linear(channels, channels)
|
27 |
+
)
|
28 |
+
def forward(self, x):
|
29 |
+
x = self.pre_norm(x)
|
30 |
+
return x + self.proj(x)
|
31 |
+
|
32 |
+
|
33 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
34 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
35 |
+
|
36 |
+
if projector_type == 'linear':
|
37 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
38 |
+
|
39 |
+
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
40 |
+
if mlp_gelu_match:
|
41 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
42 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
43 |
+
for _ in range(1, mlp_depth):
|
44 |
+
modules.append(nn.GELU())
|
45 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
46 |
+
return nn.Sequential(*modules)
|
47 |
+
|
48 |
+
if projector_type == 'identity':
|
49 |
+
return IdentityMap()
|
50 |
+
|
51 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|
dialoggen/llava/model/utils.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoConfig
|
2 |
+
|
3 |
+
|
4 |
+
def auto_upgrade(config):
|
5 |
+
cfg = AutoConfig.from_pretrained(config)
|
6 |
+
if 'llava' in config and 'llava' not in cfg.model_type:
|
7 |
+
assert cfg.model_type == 'llama'
|
8 |
+
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
|
9 |
+
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
|
10 |
+
confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
|
11 |
+
if confirm.lower() in ["y", "yes"]:
|
12 |
+
print("Upgrading checkpoint...")
|
13 |
+
assert len(cfg.architectures) == 1
|
14 |
+
setattr(cfg.__class__, "model_type", "llava")
|
15 |
+
cfg.architectures[0] = 'LlavaLlamaForCausalLM'
|
16 |
+
cfg.save_pretrained(config)
|
17 |
+
print("Checkpoint upgraded.")
|
18 |
+
else:
|
19 |
+
print("Checkpoint upgrade aborted.")
|
20 |
+
exit(1)
|
dialoggen/llava/utils.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import logging
|
3 |
+
import logging.handlers
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import requests
|
8 |
+
|
9 |
+
from llava.constants import LOGDIR
|
10 |
+
|
11 |
+
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
12 |
+
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
|
13 |
+
|
14 |
+
handler = None
|
15 |
+
|
16 |
+
|
17 |
+
def build_logger(logger_name, logger_filename):
|
18 |
+
global handler
|
19 |
+
|
20 |
+
formatter = logging.Formatter(
|
21 |
+
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
22 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
23 |
+
)
|
24 |
+
|
25 |
+
# Set the format of root handlers
|
26 |
+
if not logging.getLogger().handlers:
|
27 |
+
logging.basicConfig(level=logging.INFO)
|
28 |
+
logging.getLogger().handlers[0].setFormatter(formatter)
|
29 |
+
|
30 |
+
# Redirect stdout and stderr to loggers
|
31 |
+
stdout_logger = logging.getLogger("stdout")
|
32 |
+
stdout_logger.setLevel(logging.INFO)
|
33 |
+
sl = StreamToLogger(stdout_logger, logging.INFO)
|
34 |
+
sys.stdout = sl
|
35 |
+
|
36 |
+
stderr_logger = logging.getLogger("stderr")
|
37 |
+
stderr_logger.setLevel(logging.ERROR)
|
38 |
+
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
39 |
+
sys.stderr = sl
|
40 |
+
|
41 |
+
# Get logger
|
42 |
+
logger = logging.getLogger(logger_name)
|
43 |
+
logger.setLevel(logging.INFO)
|
44 |
+
|
45 |
+
# Add a file handler for all loggers
|
46 |
+
if handler is None:
|
47 |
+
os.makedirs(LOGDIR, exist_ok=True)
|
48 |
+
filename = os.path.join(LOGDIR, logger_filename)
|
49 |
+
handler = logging.handlers.TimedRotatingFileHandler(
|
50 |
+
filename, when='D', utc=True, encoding='UTF-8')
|
51 |
+
handler.setFormatter(formatter)
|
52 |
+
|
53 |
+
for name, item in logging.root.manager.loggerDict.items():
|
54 |
+
if isinstance(item, logging.Logger):
|
55 |
+
item.addHandler(handler)
|
56 |
+
|
57 |
+
return logger
|
58 |
+
|
59 |
+
|
60 |
+
class StreamToLogger(object):
|
61 |
+
"""
|
62 |
+
Fake file-like stream object that redirects writes to a logger instance.
|
63 |
+
"""
|
64 |
+
def __init__(self, logger, log_level=logging.INFO):
|
65 |
+
self.terminal = sys.stdout
|
66 |
+
self.logger = logger
|
67 |
+
self.log_level = log_level
|
68 |
+
self.linebuf = ''
|
69 |
+
|
70 |
+
def __getattr__(self, attr):
|
71 |
+
return getattr(self.terminal, attr)
|
72 |
+
|
73 |
+
def write(self, buf):
|
74 |
+
temp_linebuf = self.linebuf + buf
|
75 |
+
self.linebuf = ''
|
76 |
+
for line in temp_linebuf.splitlines(True):
|
77 |
+
# From the io.TextIOWrapper docs:
|
78 |
+
# On output, if newline is None, any '\n' characters written
|
79 |
+
# are translated to the system default line separator.
|
80 |
+
# By default sys.stdout.write() expects '\n' newlines and then
|
81 |
+
# translates them so this is still cross platform.
|
82 |
+
if line[-1] == '\n':
|
83 |
+
self.logger.log(self.log_level, line.rstrip())
|
84 |
+
else:
|
85 |
+
self.linebuf += line
|
86 |
+
|
87 |
+
def flush(self):
|
88 |
+
if self.linebuf != '':
|
89 |
+
self.logger.log(self.log_level, self.linebuf.rstrip())
|
90 |
+
self.linebuf = ''
|
91 |
+
|
92 |
+
|
93 |
+
def disable_torch_init():
|
94 |
+
"""
|
95 |
+
Disable the redundant torch default initialization to accelerate model creation.
|
96 |
+
"""
|
97 |
+
import torch
|
98 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
99 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
100 |
+
|
101 |
+
|
102 |
+
def violates_moderation(text):
|
103 |
+
"""
|
104 |
+
Check whether the text violates OpenAI moderation API.
|
105 |
+
"""
|
106 |
+
url = "https://api.openai.com/v1/moderations"
|
107 |
+
headers = {"Content-Type": "application/json",
|
108 |
+
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
|
109 |
+
text = text.replace("\n", "")
|
110 |
+
data = "{" + '"input": ' + f'"{text}"' + "}"
|
111 |
+
data = data.encode("utf-8")
|
112 |
+
try:
|
113 |
+
ret = requests.post(url, headers=headers, data=data, timeout=5)
|
114 |
+
flagged = ret.json()["results"][0]["flagged"]
|
115 |
+
except requests.exceptions.RequestException as e:
|
116 |
+
flagged = False
|
117 |
+
except KeyError as e:
|
118 |
+
flagged = False
|
119 |
+
|
120 |
+
return flagged
|
121 |
+
|
122 |
+
|
123 |
+
def pretty_print_semaphore(semaphore):
|
124 |
+
if semaphore is None:
|
125 |
+
return "None"
|
126 |
+
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
|
en.csv
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
key,value
|
2 |
+
size,Size
|
3 |
+
sampler,Sampler
|
4 |
+
prompt,Prompt
|
5 |
+
default prompt,"A cute cat"
|
6 |
+
negative_prompt,Negative Prompt
|
7 |
+
seed,Seed
|
8 |
+
cfg,CFG Scale
|
9 |
+
infer steps,Sampling Steps
|
10 |
+
batch size,Batch Size
|
11 |
+
width cond,Width Cond
|
12 |
+
height cond,Height Cond
|
13 |
+
enhance,Prompt Enhancement
|
14 |
+
run,Submit
|
15 |
+
square,Square(1024x1024)
|
16 |
+
landscape,Landscape(1280x768)
|
17 |
+
portrait,Portrait(768x1280)
|
18 |
+
accordion,Advanced Options
|
19 |
+
generated image,HunYuanDiT Generated Image
|
20 |
+
examples,More Examples
|
21 |
+
title,Hunyuan-DiT
|
22 |
+
desc,A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding
|
environment.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: HunyuanDiT
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
dependencies:
|
6 |
+
- python=3.8.12
|
7 |
+
- pytorch=1.13.1
|
8 |
+
- pip
|
example_prompts.txt
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
一只聪明的狐狸走在阔叶树林里, 旁边是一条小溪, 细节真实, 摄影
|
2 |
+
湖水清澈,天空湛蓝,阳光灿烂。一只优雅的白天鹅在湖边游泳。它周围有几只小鸭子,看起来非常可爱,整个画面给人一种宁静祥和的感觉。
|
3 |
+
太阳微微升起,花园里的玫瑰花瓣上露珠晶莹剔透,一只瓢虫正在爬向露珠,背景是清晨的花园,微距镜头
|
4 |
+
一位女明星,中国人,头发是黑色,衣服是纯白色短袖,人物风格清新,城市背景
|
5 |
+
后印象主义风格,一条古老的石板路上面散落着金黄色的树叶。路旁的风车在静谧地转动,后面竖着两个风车。背景是一片向日葵田,蓝天上飘着几朵白云
|
6 |
+
一幅细致的油画描绘了一只年轻獾轻轻嗅着一朵明亮的黄色玫瑰时错综复杂的皮毛。背景是一棵大树干的粗糙纹理,獾的爪子轻轻地挖进树皮。在柔和的背景中,一个宁静的瀑布倾泻而下,它的水在绿色植物中闪烁着蓝色。
|
7 |
+
渔舟唱晚
|
8 |
+
请将杞人忧天的样子画出来
|
9 |
+
一只长靴猫手持亮银色的宝剑,身着铠甲,眼神坚毅,站在一堆金币上,背景是暗色调的洞穴,图像上有金币的光影点缀。
|
10 |
+
插画风格,一只狐狸和一只刺猬坐在水边的石头上,刺猬手里拿着一杯茶,狐狸旁边放着一个玻璃杯。周围是茂密的绿色植物和树木,阳光透过树叶洒在水面上,画面宁静温馨。
|
11 |
+
泥塑风格,一座五彩斑斓的花园在画面中展现,各种各样的花朵,绿色的叶子和一只正在嬉戏的小猫形成了一幅生动的图像,背景是蓝天和白云
|
12 |
+
枯藤老树昏鸦,小桥流水人家
|
13 |
+
一张细致的照片捕捉到了一尊雕像的形象,这尊雕像酷似一位古代法老,头上出人意料地戴着一副青铜蒸汽朋克护目镜。这座雕像穿着复古时髦,一件清爽的白色T恤和一件合身的黑色皮夹克,与传统的头饰形成鲜明对比。背景是简单的纯色,突出了雕像的非传统服装和蒸汽朋克眼镜的复杂细节。
|
14 |
+
一朵鲜艳的红色玫瑰花,花瓣撒有一些水珠,晶莹剔透,特写镜头,
|
15 |
+
一只可爱的猫, 细节真实, 摄影
|
16 |
+
飞流直下三千尺,疑是银河落九天
|
17 |
+
成语“鲤鱼跃龙门”
|
18 |
+
一颗新鲜的草莓特写,红色的外表,表面布满许多种子,背景是淡绿色的叶子
|
19 |
+
九寨沟
|
20 |
+
摄影风格,在画面中心是一盘热气腾腾的麻婆豆腐,豆腐呈白色,上面撒着一层红色的辣酱,有些许绿色的葱花点缀,背景是深色木质餐桌,桌子上放有辣椒和葱花作为点缀。
|
21 |
+
一位年轻女子站在春季的火车站月台上。她身着蓝灰色长风衣,白色衬衫。她的深棕色头发扎成低马尾,几缕碎发随风飘扬。她的眼神充满期待,阳光洒在她温暖的脸庞上。
|
22 |
+
一只优雅的白鹤在湖边静静地站立,它的身体纯白色,翅膀轻轻展开,背景是湖面和远处的山脉
|
23 |
+
国画风格,苏州园林中的小桥流水,周围是郁郁葱葱的树,池塘里有几朵绽放的荷花,背景是宁静的江南水乡
|
24 |
+
现实主义风格,画面主要描述一个巴洛克风格的花瓶,带有金色的装饰边框,花瓶上盛开着各种色彩鲜艳的花,白色背景
|
25 |
+
醉后不知天在水,满船清梦压星河
|
26 |
+
长城
|
27 |
+
一个亚洲中年男士在夕阳下的公园长椅上静坐。他穿着一件深蓝色的针织毛衣和灰色裤子。他的头发略显花白,手中拿着一本敞开的书。面带微笑,眼神温和,周围是落日余晖和四周的绿树。
|
28 |
+
风格是写实,画面主要描述一个亚洲戏曲艺术家正在表演,她穿着华丽的戏服,脸上戴着精致的面具,身姿优雅,背景是古色古香的舞台,镜头是近景
|
hydit/__init__.py
ADDED
File without changes
|
hydit/config.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
from .constants import *
|
4 |
+
from .modules.models import HUNYUAN_DIT_CONFIG
|
5 |
+
|
6 |
+
|
7 |
+
def get_args(default_args=None):
|
8 |
+
parser = argparse.ArgumentParser()
|
9 |
+
|
10 |
+
# Basic
|
11 |
+
parser.add_argument("--prompt", type=str, default="一只小猫", help="The prompt for generating images.")
|
12 |
+
parser.add_argument("--model-root", type=str, default="ckpts", help="Model root path.")
|
13 |
+
parser.add_argument("--image-size", type=int, nargs='+', default=[1024, 1024],
|
14 |
+
help='Image size (h, w). If a single value is provided, the image will be treated to '
|
15 |
+
'(value, value).')
|
16 |
+
parser.add_argument("--infer-mode", type=str, choices=["fa", "torch", "trt"], default="torch",
|
17 |
+
help="Inference mode")
|
18 |
+
|
19 |
+
# HunYuan-DiT
|
20 |
+
parser.add_argument("--model", type=str, choices=list(HUNYUAN_DIT_CONFIG.keys()), default='DiT-g/2')
|
21 |
+
parser.add_argument("--norm", type=str, default="layer", help="Normalization layer type")
|
22 |
+
parser.add_argument("--load-key", type=str, choices=["ema", "module"], default="ema", help="Load model key for HunYuanDiT checkpoint.")
|
23 |
+
parser.add_argument('--size-cond', type=int, nargs='+', default=[1024, 1024],
|
24 |
+
help="Size condition used in sampling. 2 values are required for height and width. "
|
25 |
+
"If a single value is provided, the image will be treated to (value, value).")
|
26 |
+
parser.add_argument("--cfg-scale", type=float, default=6.0, help="Guidance scale for classifier-free.")
|
27 |
+
|
28 |
+
# Prompt enhancement
|
29 |
+
parser.add_argument("--enhance", action="store_true", help="Enhance prompt with dialoggen.")
|
30 |
+
parser.add_argument("--no-enhance", dest="enhance", action="store_false")
|
31 |
+
parser.set_defaults(enhance=True)
|
32 |
+
|
33 |
+
# Diffusion
|
34 |
+
parser.add_argument("--learn-sigma", action="store_true", help="Learn extra channels for sigma.")
|
35 |
+
parser.add_argument("--no-learn-sigma", dest="learn_sigma", action="store_false")
|
36 |
+
parser.set_defaults(learn_sigma=True)
|
37 |
+
parser.add_argument("--predict-type", type=str, choices=list(PREDICT_TYPE), default="v_prediction",
|
38 |
+
help="Diffusion predict type")
|
39 |
+
parser.add_argument("--noise-schedule", type=str, choices=list(NOISE_SCHEDULES), default="scaled_linear",
|
40 |
+
help="Noise schedule")
|
41 |
+
parser.add_argument("--beta-start", type=float, default=0.00085, help="Beta start value")
|
42 |
+
parser.add_argument("--beta-end", type=float, default=0.03, help="Beta end value")
|
43 |
+
|
44 |
+
# Text condition
|
45 |
+
parser.add_argument("--text-states-dim", type=int, default=1024, help="Hidden size of CLIP text encoder.")
|
46 |
+
parser.add_argument("--text-len", type=int, default=77, help="Token length of CLIP text encoder output.")
|
47 |
+
parser.add_argument("--text-states-dim-t5", type=int, default=2048, help="Hidden size of CLIP text encoder.")
|
48 |
+
parser.add_argument("--text-len-t5", type=int, default=256, help="Token length of T5 text encoder output.")
|
49 |
+
parser.add_argument("--negative", type=str, default=None, help="Negative prompt.")
|
50 |
+
|
51 |
+
# Acceleration
|
52 |
+
parser.add_argument("--use_fp16", action="store_true", help="Use FP16 precision.")
|
53 |
+
parser.add_argument("--no-fp16", dest="use_fp16", action="store_false")
|
54 |
+
parser.set_defaults(use_fp16=True)
|
55 |
+
|
56 |
+
# Sampling
|
57 |
+
parser.add_argument("--batch-size", type=int, default=1, help="Per-GPU batch size")
|
58 |
+
parser.add_argument("--sampler", type=str, choices=SAMPLER_FACTORY, default="ddpm", help="Diffusion sampler")
|
59 |
+
parser.add_argument("--infer-steps", type=int, default=100, help="Inference steps")
|
60 |
+
parser.add_argument('--seed', type=int, default=42, help="A seed for all the prompts.")
|
61 |
+
|
62 |
+
# App
|
63 |
+
parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="Language")
|
64 |
+
|
65 |
+
args = parser.parse_args(default_args)
|
66 |
+
|
67 |
+
return args
|
hydit/constants.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# =======================================================
|
2 |
+
NOISE_SCHEDULES = {
|
3 |
+
"linear",
|
4 |
+
"scaled_linear",
|
5 |
+
"squaredcos_cap_v2",
|
6 |
+
}
|
7 |
+
|
8 |
+
PREDICT_TYPE = {
|
9 |
+
"epsilon",
|
10 |
+
"sample",
|
11 |
+
"v_prediction",
|
12 |
+
}
|
13 |
+
|
14 |
+
# =======================================================
|
15 |
+
NEGATIVE_PROMPT = '错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,'
|
16 |
+
|
17 |
+
|
18 |
+
# =======================================================
|
19 |
+
# Constants about models
|
20 |
+
# =======================================================
|
21 |
+
|
22 |
+
SAMPLER_FACTORY = {
|
23 |
+
'ddpm': {
|
24 |
+
'scheduler': 'DDPMScheduler',
|
25 |
+
'name': 'DDPM',
|
26 |
+
'kwargs': {
|
27 |
+
'steps_offset': 1,
|
28 |
+
'clip_sample': False,
|
29 |
+
'clip_sample_range': 1.0,
|
30 |
+
'beta_schedule': 'scaled_linear',
|
31 |
+
'beta_start': 0.00085,
|
32 |
+
'beta_end': 0.03,
|
33 |
+
'prediction_type': 'v_prediction',
|
34 |
+
}
|
35 |
+
},
|
36 |
+
'ddim': {
|
37 |
+
'scheduler': 'DDIMScheduler',
|
38 |
+
'name': 'DDIM',
|
39 |
+
'kwargs': {
|
40 |
+
'steps_offset': 1,
|
41 |
+
'clip_sample': False,
|
42 |
+
'clip_sample_range': 1.0,
|
43 |
+
'beta_schedule': 'scaled_linear',
|
44 |
+
'beta_start': 0.00085,
|
45 |
+
'beta_end': 0.03,
|
46 |
+
'prediction_type': 'v_prediction',
|
47 |
+
}
|
48 |
+
},
|
49 |
+
'dpmms': {
|
50 |
+
'scheduler': 'DPMSolverMultistepScheduler',
|
51 |
+
'name': 'DPMMS',
|
52 |
+
'kwargs': {
|
53 |
+
'beta_schedule': 'scaled_linear',
|
54 |
+
'beta_start': 0.00085,
|
55 |
+
'beta_end': 0.03,
|
56 |
+
'prediction_type': 'v_prediction',
|
57 |
+
'trained_betas': None,
|
58 |
+
'solver_order': 2,
|
59 |
+
'algorithm_type': 'dpmsolver++',
|
60 |
+
}
|
61 |
+
},
|
62 |
+
}
|
hydit/diffusion/__init__.py
ADDED
File without changes
|
hydit/diffusion/pipeline.py
ADDED
@@ -0,0 +1,830 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
2 |
+
# you may not use this file except in compliance with the License.
|
3 |
+
# You may obtain a copy of the License at
|
4 |
+
#
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
#
|
7 |
+
# Unless required by applicable law or agreed to in writing, software
|
8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
10 |
+
# See the License for the specific language governing permissions and
|
11 |
+
# limitations under the License.
|
12 |
+
|
13 |
+
import inspect
|
14 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
15 |
+
|
16 |
+
import PIL
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import torchvision.transforms as T
|
20 |
+
from diffusers.configuration_utils import FrozenDict
|
21 |
+
from diffusers.image_processor import VaeImageProcessor
|
22 |
+
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
23 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
24 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
26 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
27 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
28 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
29 |
+
from diffusers.utils import (
|
30 |
+
PIL_INTERPOLATION,
|
31 |
+
deprecate,
|
32 |
+
logging,
|
33 |
+
replace_example_docstring,
|
34 |
+
)
|
35 |
+
from diffusers.utils.torch_utils import randn_tensor
|
36 |
+
from transformers import BertModel, BertTokenizer
|
37 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
38 |
+
|
39 |
+
from ..modules.models import HunYuanDiT
|
40 |
+
|
41 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
42 |
+
|
43 |
+
EXAMPLE_DOC_STRING = """
|
44 |
+
Examples:
|
45 |
+
```py
|
46 |
+
>>> import requests
|
47 |
+
>>> import torch
|
48 |
+
>>> from PIL import Image
|
49 |
+
>>> from io import BytesIO
|
50 |
+
|
51 |
+
>>> from diffusers import StableDiffusionImg2ImgPipeline
|
52 |
+
|
53 |
+
>>> device = "cuda"
|
54 |
+
>>> model_id_or_path = "runwayml/stable-diffusion-v1-5"
|
55 |
+
>>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
|
56 |
+
>>> pipe = pipe.to(device)
|
57 |
+
|
58 |
+
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
59 |
+
|
60 |
+
>>> response = requests.get(url)
|
61 |
+
>>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
62 |
+
>>> init_image = init_image.resize((768, 512))
|
63 |
+
|
64 |
+
>>> prompt = "A fantasy landscape, trending on artstation"
|
65 |
+
|
66 |
+
>>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
|
67 |
+
>>> images[0].save("fantasy_landscape.png")
|
68 |
+
```
|
69 |
+
"""
|
70 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
71 |
+
"""
|
72 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
73 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
74 |
+
"""
|
75 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
76 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
77 |
+
# rescale the results from guidance (fixes overexposure)
|
78 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
79 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
80 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
81 |
+
return noise_cfg
|
82 |
+
|
83 |
+
def preprocess(image):
|
84 |
+
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
|
85 |
+
deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
|
86 |
+
if isinstance(image, torch.Tensor):
|
87 |
+
return image
|
88 |
+
elif isinstance(image, PIL.Image.Image):
|
89 |
+
image = [image]
|
90 |
+
|
91 |
+
if isinstance(image[0], PIL.Image.Image):
|
92 |
+
w, h = image[0].size
|
93 |
+
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
|
94 |
+
|
95 |
+
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
|
96 |
+
image = np.concatenate(image, axis=0)
|
97 |
+
image = np.array(image).astype(np.float32) / 255.0
|
98 |
+
image = image.transpose(0, 3, 1, 2)
|
99 |
+
image = 2.0 * image - 1.0
|
100 |
+
image = torch.from_numpy(image)
|
101 |
+
elif isinstance(image[0], torch.Tensor):
|
102 |
+
image = torch.cat(image, dim=0)
|
103 |
+
return image
|
104 |
+
|
105 |
+
|
106 |
+
class StableDiffusionPipeline(
|
107 |
+
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
108 |
+
):
|
109 |
+
r"""
|
110 |
+
Pipeline for text-guided image-to-image generation using Stable Diffusion.
|
111 |
+
|
112 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
113 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
114 |
+
|
115 |
+
The pipeline also inherits the following loading methods:
|
116 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
117 |
+
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
118 |
+
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
119 |
+
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
120 |
+
|
121 |
+
Args:
|
122 |
+
vae ([`AutoencoderKL`]):
|
123 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
124 |
+
text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
|
125 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
126 |
+
tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
|
127 |
+
A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
|
128 |
+
unet (Optional[`HunYuanDiT`, `UNet2DConditionModel`]):
|
129 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
130 |
+
scheduler ([`SchedulerMixin`]):
|
131 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
132 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
133 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
134 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
135 |
+
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
136 |
+
about a model's potential harms.
|
137 |
+
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
138 |
+
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
139 |
+
"""
|
140 |
+
model_cpu_offload_seq = "text_encoder->unet->vae"
|
141 |
+
_optional_components = ["safety_checker", "feature_extractor"]
|
142 |
+
_exclude_from_cpu_offload = ["safety_checker"]
|
143 |
+
|
144 |
+
def __init__(
|
145 |
+
self,
|
146 |
+
vae: AutoencoderKL,
|
147 |
+
text_encoder: Union[BertModel, CLIPTextModel],
|
148 |
+
tokenizer: Union[BertTokenizer, CLIPTokenizer],
|
149 |
+
unet: Union[HunYuanDiT, UNet2DConditionModel],
|
150 |
+
scheduler: KarrasDiffusionSchedulers,
|
151 |
+
safety_checker: StableDiffusionSafetyChecker,
|
152 |
+
feature_extractor: CLIPImageProcessor,
|
153 |
+
requires_safety_checker: bool = True,
|
154 |
+
progress_bar_config: Dict[str, Any] = None,
|
155 |
+
embedder_t5=None,
|
156 |
+
infer_mode='torch',
|
157 |
+
):
|
158 |
+
super().__init__()
|
159 |
+
|
160 |
+
# ========================================================
|
161 |
+
self.embedder_t5 = embedder_t5
|
162 |
+
self.infer_mode = infer_mode
|
163 |
+
|
164 |
+
# ========================================================
|
165 |
+
if progress_bar_config is None:
|
166 |
+
progress_bar_config = {}
|
167 |
+
if not hasattr(self, '_progress_bar_config'):
|
168 |
+
self._progress_bar_config = {}
|
169 |
+
self._progress_bar_config.update(progress_bar_config)
|
170 |
+
# ========================================================
|
171 |
+
|
172 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
173 |
+
deprecation_message = (
|
174 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
175 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
176 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
177 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
178 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
179 |
+
" file"
|
180 |
+
)
|
181 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
182 |
+
new_config = dict(scheduler.config)
|
183 |
+
new_config["steps_offset"] = 1
|
184 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
185 |
+
|
186 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
187 |
+
deprecation_message = (
|
188 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
189 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
190 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
191 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
192 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
193 |
+
)
|
194 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
195 |
+
new_config = dict(scheduler.config)
|
196 |
+
new_config["clip_sample"] = False
|
197 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
198 |
+
|
199 |
+
if safety_checker is None and requires_safety_checker:
|
200 |
+
logger.warning(
|
201 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
202 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
203 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
204 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
205 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
206 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
207 |
+
)
|
208 |
+
|
209 |
+
if safety_checker is not None and feature_extractor is None:
|
210 |
+
raise ValueError(
|
211 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
212 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
213 |
+
)
|
214 |
+
|
215 |
+
self.register_modules(
|
216 |
+
vae=vae,
|
217 |
+
text_encoder=text_encoder,
|
218 |
+
tokenizer=tokenizer,
|
219 |
+
unet=unet,
|
220 |
+
scheduler=scheduler,
|
221 |
+
safety_checker=safety_checker,
|
222 |
+
feature_extractor=feature_extractor,
|
223 |
+
)
|
224 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
225 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
226 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
227 |
+
|
228 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
229 |
+
def _encode_prompt(
|
230 |
+
self,
|
231 |
+
prompt,
|
232 |
+
device,
|
233 |
+
num_images_per_prompt,
|
234 |
+
do_classifier_free_guidance,
|
235 |
+
negative_prompt=None,
|
236 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
237 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
238 |
+
lora_scale: Optional[float] = None,
|
239 |
+
):
|
240 |
+
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
241 |
+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
242 |
+
|
243 |
+
prompt_embeds_tuple = self.encode_prompt(
|
244 |
+
prompt=prompt,
|
245 |
+
device=device,
|
246 |
+
num_images_per_prompt=num_images_per_prompt,
|
247 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
248 |
+
negative_prompt=negative_prompt,
|
249 |
+
prompt_embeds=prompt_embeds,
|
250 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
251 |
+
lora_scale=lora_scale,
|
252 |
+
)
|
253 |
+
|
254 |
+
# concatenate for backwards comp
|
255 |
+
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
256 |
+
|
257 |
+
return prompt_embeds
|
258 |
+
|
259 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
260 |
+
def encode_prompt(
|
261 |
+
self,
|
262 |
+
prompt,
|
263 |
+
device,
|
264 |
+
num_images_per_prompt,
|
265 |
+
do_classifier_free_guidance,
|
266 |
+
negative_prompt=None,
|
267 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
268 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
269 |
+
lora_scale: Optional[float] = None,
|
270 |
+
embedder=None,
|
271 |
+
):
|
272 |
+
r"""
|
273 |
+
Encodes the prompt into text encoder hidden states.
|
274 |
+
|
275 |
+
Args:
|
276 |
+
prompt (`str` or `List[str]`, *optional*):
|
277 |
+
prompt to be encoded
|
278 |
+
device: (`torch.device`):
|
279 |
+
torch device
|
280 |
+
num_images_per_prompt (`int`):
|
281 |
+
number of images that should be generated per prompt
|
282 |
+
do_classifier_free_guidance (`bool`):
|
283 |
+
whether to use classifier free guidance or not
|
284 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
285 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
286 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
287 |
+
less than `1`).
|
288 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
289 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
290 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
291 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
292 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
293 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
294 |
+
argument.
|
295 |
+
lora_scale (`float`, *optional*):
|
296 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
297 |
+
embedder:
|
298 |
+
T5 embedder (including text encoder and tokenizer)
|
299 |
+
"""
|
300 |
+
if embedder is None:
|
301 |
+
text_encoder = self.text_encoder
|
302 |
+
tokenizer = self.tokenizer
|
303 |
+
max_length = self.tokenizer.model_max_length
|
304 |
+
else:
|
305 |
+
text_encoder = embedder.model
|
306 |
+
tokenizer = embedder.tokenizer
|
307 |
+
max_length = embedder.max_length
|
308 |
+
|
309 |
+
# set lora scale so that monkey patched LoRA
|
310 |
+
# function of text encoder can correctly access it
|
311 |
+
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
312 |
+
self._lora_scale = lora_scale
|
313 |
+
|
314 |
+
# dynamically adjust the LoRA scale
|
315 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
316 |
+
|
317 |
+
if prompt is not None and isinstance(prompt, str):
|
318 |
+
batch_size = 1
|
319 |
+
elif prompt is not None and isinstance(prompt, list):
|
320 |
+
batch_size = len(prompt)
|
321 |
+
else:
|
322 |
+
batch_size = prompt_embeds.shape[0]
|
323 |
+
|
324 |
+
if prompt_embeds is None:
|
325 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
326 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
327 |
+
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
328 |
+
|
329 |
+
text_inputs = tokenizer(
|
330 |
+
prompt,
|
331 |
+
padding="max_length",
|
332 |
+
max_length=max_length,
|
333 |
+
truncation=True,
|
334 |
+
return_attention_mask=True,
|
335 |
+
return_tensors="pt",
|
336 |
+
)
|
337 |
+
text_input_ids = text_inputs.input_ids
|
338 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
339 |
+
|
340 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
341 |
+
text_input_ids, untruncated_ids
|
342 |
+
):
|
343 |
+
removed_text = tokenizer.batch_decode(
|
344 |
+
untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
|
345 |
+
)
|
346 |
+
logger.warning(
|
347 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
348 |
+
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
349 |
+
)
|
350 |
+
|
351 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
352 |
+
prompt_embeds = text_encoder(
|
353 |
+
text_input_ids.to(device),
|
354 |
+
attention_mask=attention_mask,
|
355 |
+
)
|
356 |
+
prompt_embeds = prompt_embeds[0]
|
357 |
+
attention_mask = attention_mask.repeat(num_images_per_prompt, 1)
|
358 |
+
else:
|
359 |
+
attention_mask = None
|
360 |
+
|
361 |
+
if text_encoder is not None:
|
362 |
+
prompt_embeds_dtype = text_encoder.dtype
|
363 |
+
elif self.unet is not None:
|
364 |
+
prompt_embeds_dtype = self.unet.dtype
|
365 |
+
else:
|
366 |
+
prompt_embeds_dtype = prompt_embeds.dtype
|
367 |
+
|
368 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
369 |
+
|
370 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
371 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
372 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
373 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
374 |
+
|
375 |
+
# get unconditional embeddings for classifier free guidance
|
376 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
377 |
+
uncond_tokens: List[str]
|
378 |
+
if negative_prompt is None:
|
379 |
+
uncond_tokens = [""] * batch_size
|
380 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
381 |
+
raise TypeError(
|
382 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
383 |
+
f" {type(prompt)}."
|
384 |
+
)
|
385 |
+
elif isinstance(negative_prompt, str):
|
386 |
+
uncond_tokens = [negative_prompt]
|
387 |
+
elif batch_size != len(negative_prompt):
|
388 |
+
raise ValueError(
|
389 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
390 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
391 |
+
" the batch size of `prompt`."
|
392 |
+
)
|
393 |
+
else:
|
394 |
+
uncond_tokens = negative_prompt
|
395 |
+
|
396 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
397 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
398 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
|
399 |
+
|
400 |
+
max_length = prompt_embeds.shape[1]
|
401 |
+
uncond_input = tokenizer(
|
402 |
+
uncond_tokens,
|
403 |
+
padding="max_length",
|
404 |
+
max_length=max_length,
|
405 |
+
truncation=True,
|
406 |
+
return_tensors="pt",
|
407 |
+
)
|
408 |
+
|
409 |
+
uncond_attention_mask = uncond_input.attention_mask.to(device)
|
410 |
+
negative_prompt_embeds = text_encoder(
|
411 |
+
uncond_input.input_ids.to(device),
|
412 |
+
attention_mask=uncond_attention_mask,
|
413 |
+
)
|
414 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
415 |
+
uncond_attention_mask = uncond_attention_mask.repeat(num_images_per_prompt, 1)
|
416 |
+
else:
|
417 |
+
uncond_attention_mask = None
|
418 |
+
|
419 |
+
if do_classifier_free_guidance:
|
420 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
421 |
+
seq_len = negative_prompt_embeds.shape[1]
|
422 |
+
|
423 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
424 |
+
|
425 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
426 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
427 |
+
|
428 |
+
return prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask
|
429 |
+
|
430 |
+
def _convert_to_rgb(self, image):
|
431 |
+
return image.convert('RGB')
|
432 |
+
|
433 |
+
def image_transform(self, image_size=224):
|
434 |
+
transform = T.Compose([
|
435 |
+
T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC),
|
436 |
+
self._convert_to_rgb,
|
437 |
+
T.ToTensor(),
|
438 |
+
T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
439 |
+
])
|
440 |
+
return transform
|
441 |
+
|
442 |
+
def encode_img(self, img, device, do_classifier_free_guidance):
|
443 |
+
# print('len', len(img))
|
444 |
+
# print('img', img.size)
|
445 |
+
img = img[0] # TODO: support batch processing
|
446 |
+
image_preprocess = self.image_transform(224)
|
447 |
+
img_for_clip = image_preprocess(img)
|
448 |
+
# print('img_for_clip', img_for_clip.shape)
|
449 |
+
img_for_clip = img_for_clip.unsqueeze(0)
|
450 |
+
img_clip_embedding = self.img_encoder(img_for_clip.to(device)).to(dtype=torch.float16)
|
451 |
+
# print('img_clip_embedding_1_type', img_clip_embedding.dtype)
|
452 |
+
if do_classifier_free_guidance:
|
453 |
+
negative_img_clip_embedding = torch.zeros_like(img_clip_embedding)
|
454 |
+
return img_clip_embedding, negative_img_clip_embedding
|
455 |
+
|
456 |
+
|
457 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
458 |
+
def run_safety_checker(self, image, device, dtype):
|
459 |
+
if self.safety_checker is None:
|
460 |
+
has_nsfw_concept = None
|
461 |
+
else:
|
462 |
+
if torch.is_tensor(image):
|
463 |
+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
464 |
+
else:
|
465 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
466 |
+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
467 |
+
image, has_nsfw_concept = self.safety_checker(
|
468 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
469 |
+
)
|
470 |
+
return image, has_nsfw_concept
|
471 |
+
|
472 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
473 |
+
def decode_latents(self, latents):
|
474 |
+
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
475 |
+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
476 |
+
|
477 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
478 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
479 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
480 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
481 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
482 |
+
return image
|
483 |
+
|
484 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
485 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
486 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
487 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
488 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
489 |
+
# and should be between [0, 1]
|
490 |
+
|
491 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
492 |
+
extra_step_kwargs = {}
|
493 |
+
if accepts_eta:
|
494 |
+
extra_step_kwargs["eta"] = eta
|
495 |
+
|
496 |
+
# check if the scheduler accepts generator
|
497 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
498 |
+
if accepts_generator:
|
499 |
+
extra_step_kwargs["generator"] = generator
|
500 |
+
return extra_step_kwargs
|
501 |
+
|
502 |
+
def check_inputs(
|
503 |
+
self,
|
504 |
+
prompt,
|
505 |
+
height,
|
506 |
+
width,
|
507 |
+
callback_steps,
|
508 |
+
negative_prompt=None,
|
509 |
+
prompt_embeds=None,
|
510 |
+
negative_prompt_embeds=None,
|
511 |
+
):
|
512 |
+
if height % 8 != 0 or width % 8 != 0:
|
513 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
514 |
+
|
515 |
+
if (callback_steps is None) or (
|
516 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
517 |
+
):
|
518 |
+
raise ValueError(
|
519 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
520 |
+
f" {type(callback_steps)}."
|
521 |
+
)
|
522 |
+
|
523 |
+
if prompt is not None and prompt_embeds is not None:
|
524 |
+
raise ValueError(
|
525 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
526 |
+
" only forward one of the two."
|
527 |
+
)
|
528 |
+
elif prompt is None and prompt_embeds is None:
|
529 |
+
raise ValueError(
|
530 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
531 |
+
)
|
532 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
533 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
534 |
+
|
535 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
536 |
+
raise ValueError(
|
537 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
538 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
539 |
+
)
|
540 |
+
|
541 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
542 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
543 |
+
raise ValueError(
|
544 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
545 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
546 |
+
f" {negative_prompt_embeds.shape}."
|
547 |
+
)
|
548 |
+
|
549 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
550 |
+
# get the original timestep using init_timestep
|
551 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
552 |
+
|
553 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
554 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
555 |
+
|
556 |
+
return timesteps, num_inference_steps - t_start
|
557 |
+
|
558 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
559 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
560 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
561 |
+
raise ValueError(
|
562 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
563 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
564 |
+
)
|
565 |
+
|
566 |
+
if latents is None:
|
567 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
568 |
+
else:
|
569 |
+
latents = latents.to(device)
|
570 |
+
|
571 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
572 |
+
latents = latents * self.scheduler.init_noise_sigma
|
573 |
+
return latents
|
574 |
+
|
575 |
+
@torch.no_grad()
|
576 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
577 |
+
def __call__(
|
578 |
+
self,
|
579 |
+
height: int,
|
580 |
+
width: int,
|
581 |
+
prompt: Union[str, List[str]] = None,
|
582 |
+
num_inference_steps: Optional[int] = 50,
|
583 |
+
guidance_scale: Optional[float] = 7.5,
|
584 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
585 |
+
num_images_per_prompt: Optional[int] = 1,
|
586 |
+
eta: Optional[float] = 0.0,
|
587 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
588 |
+
latents: Optional[torch.FloatTensor] = None,
|
589 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
590 |
+
prompt_embeds_t5: Optional[torch.FloatTensor] = None,
|
591 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
592 |
+
negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
|
593 |
+
output_type: Optional[str] = "pil",
|
594 |
+
return_dict: bool = True,
|
595 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor, torch.FloatTensor], None]] = None,
|
596 |
+
callback_steps: int = 1,
|
597 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
598 |
+
guidance_rescale: float = 0.0,
|
599 |
+
image_meta_size: Optional[torch.LongTensor] = None,
|
600 |
+
style: Optional[torch.LongTensor] = None,
|
601 |
+
progress: bool = True,
|
602 |
+
use_fp16: bool = False,
|
603 |
+
freqs_cis_img: Optional[tuple] = None,
|
604 |
+
learn_sigma: bool = True,
|
605 |
+
):
|
606 |
+
r"""
|
607 |
+
The call function to the pipeline for generation.
|
608 |
+
|
609 |
+
Args:
|
610 |
+
height (`int`):
|
611 |
+
The height in pixels of the generated image.
|
612 |
+
width (`int`):
|
613 |
+
The width in pixels of the generated image.
|
614 |
+
prompt (`str` or `List[str]`, *optional*):
|
615 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
616 |
+
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
617 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
618 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
619 |
+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
620 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
621 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
622 |
+
strength (`float`, *optional*, defaults to 1.0):
|
623 |
+
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
624 |
+
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
625 |
+
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
626 |
+
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
627 |
+
essentially ignores `image`.
|
628 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
629 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
630 |
+
expense of slower inference. This parameter is modulated by `strength`.
|
631 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
632 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
633 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
634 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
635 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
636 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
637 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
638 |
+
The number of images to generate per prompt.
|
639 |
+
eta (`float`, *optional*, defaults to 0.0):
|
640 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
641 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
642 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
643 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
644 |
+
generation deterministic.
|
645 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
646 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
647 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
648 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
649 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
650 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
651 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
652 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
653 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
654 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
655 |
+
plain tuple.
|
656 |
+
callback (`Callable`, *optional*):
|
657 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
658 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor,
|
659 |
+
pred_x0: torch.FloatTensor)`.
|
660 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
661 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
662 |
+
every step.
|
663 |
+
cross_attention_kwargs (`dict`, *optional*):
|
664 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
665 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
666 |
+
|
667 |
+
Examples:
|
668 |
+
|
669 |
+
Returns:
|
670 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
671 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
672 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
673 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
674 |
+
"not-safe-for-work" (nsfw) content.
|
675 |
+
"""
|
676 |
+
# 1. Check inputs. Raise error if not correct
|
677 |
+
self.check_inputs(
|
678 |
+
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
|
679 |
+
)
|
680 |
+
|
681 |
+
# 2. Define call parameters
|
682 |
+
if prompt is not None and isinstance(prompt, str):
|
683 |
+
batch_size = 1
|
684 |
+
elif prompt is not None and isinstance(prompt, list):
|
685 |
+
batch_size = len(prompt)
|
686 |
+
else:
|
687 |
+
batch_size = prompt_embeds.shape[0]
|
688 |
+
|
689 |
+
device = self._execution_device
|
690 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
691 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
692 |
+
# corresponds to doing no classifier free guidance.
|
693 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
694 |
+
|
695 |
+
# 3. Encode input prompt
|
696 |
+
text_encoder_lora_scale = (
|
697 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
698 |
+
)
|
699 |
+
|
700 |
+
prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask = \
|
701 |
+
self.encode_prompt(prompt,
|
702 |
+
device,
|
703 |
+
num_images_per_prompt,
|
704 |
+
do_classifier_free_guidance,
|
705 |
+
negative_prompt,
|
706 |
+
prompt_embeds=prompt_embeds,
|
707 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
708 |
+
lora_scale=text_encoder_lora_scale,
|
709 |
+
)
|
710 |
+
prompt_embeds_t5, negative_prompt_embeds_t5, attention_mask_t5, uncond_attention_mask_t5 = \
|
711 |
+
self.encode_prompt(prompt,
|
712 |
+
device,
|
713 |
+
num_images_per_prompt,
|
714 |
+
do_classifier_free_guidance,
|
715 |
+
negative_prompt,
|
716 |
+
prompt_embeds=prompt_embeds_t5,
|
717 |
+
negative_prompt_embeds=negative_prompt_embeds_t5,
|
718 |
+
lora_scale=text_encoder_lora_scale,
|
719 |
+
embedder=self.embedder_t5,
|
720 |
+
)
|
721 |
+
|
722 |
+
# For classifier free guidance, we need to do two forward passes.
|
723 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
724 |
+
# to avoid doing two forward passes
|
725 |
+
if do_classifier_free_guidance:
|
726 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
727 |
+
attention_mask = torch.cat([uncond_attention_mask, attention_mask])
|
728 |
+
prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5])
|
729 |
+
attention_mask_t5 = torch.cat([uncond_attention_mask_t5, attention_mask_t5])
|
730 |
+
|
731 |
+
# 4. Prepare timesteps
|
732 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
733 |
+
timesteps = self.scheduler.timesteps
|
734 |
+
|
735 |
+
# 6. Prepare latent variables
|
736 |
+
num_channels_latents = self.unet.config.in_channels
|
737 |
+
latents = self.prepare_latents(batch_size * num_images_per_prompt,
|
738 |
+
num_channels_latents,
|
739 |
+
height,
|
740 |
+
width,
|
741 |
+
prompt_embeds.dtype,
|
742 |
+
device,
|
743 |
+
generator,
|
744 |
+
latents,
|
745 |
+
)
|
746 |
+
|
747 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
748 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
749 |
+
|
750 |
+
# 8. Denoising loop
|
751 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
752 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
753 |
+
for i, t in enumerate(timesteps):
|
754 |
+
# expand the latents if we are doing classifier free guidance
|
755 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
756 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
757 |
+
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
|
758 |
+
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=latent_model_input.device)
|
759 |
+
|
760 |
+
if use_fp16:
|
761 |
+
latent_model_input = latent_model_input.half()
|
762 |
+
t_expand = t_expand.half()
|
763 |
+
prompt_embeds = prompt_embeds.half()
|
764 |
+
ims = image_meta_size.half() if image_meta_size is not None else None
|
765 |
+
else:
|
766 |
+
ims = image_meta_size if image_meta_size is not None else None
|
767 |
+
|
768 |
+
# predict the noise residual
|
769 |
+
if self.infer_mode in ["fa", "torch"]:
|
770 |
+
noise_pred = self.unet(
|
771 |
+
latent_model_input,
|
772 |
+
t_expand,
|
773 |
+
encoder_hidden_states=prompt_embeds,
|
774 |
+
text_embedding_mask=attention_mask,
|
775 |
+
encoder_hidden_states_t5=prompt_embeds_t5,
|
776 |
+
text_embedding_mask_t5=attention_mask_t5,
|
777 |
+
image_meta_size=ims,
|
778 |
+
style=style,
|
779 |
+
cos_cis_img=freqs_cis_img[0],
|
780 |
+
sin_cis_img=freqs_cis_img[1],
|
781 |
+
return_dict=False,
|
782 |
+
)
|
783 |
+
elif self.infer_mode == "trt":
|
784 |
+
raise NotImplementedError("TensorRT model is not supported yet.")
|
785 |
+
else:
|
786 |
+
raise ValueError("[ERROR] invalid inference mode! please check your config file")
|
787 |
+
if learn_sigma:
|
788 |
+
noise_pred, _ = noise_pred.chunk(2, dim=1)
|
789 |
+
|
790 |
+
# perform guidance
|
791 |
+
if do_classifier_free_guidance:
|
792 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
793 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
794 |
+
|
795 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
796 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
797 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
798 |
+
|
799 |
+
# compute the previous noisy sample x_t -> x_t-1
|
800 |
+
results = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
|
801 |
+
latents = results.prev_sample
|
802 |
+
pred_x0 = results.pred_original_sample if hasattr(results, 'pred_original_sample') else None
|
803 |
+
|
804 |
+
# call the callback, if provided
|
805 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
806 |
+
progress_bar.update()
|
807 |
+
if callback is not None and i % callback_steps == 0:
|
808 |
+
callback(i, t, latents, pred_x0)
|
809 |
+
|
810 |
+
if not output_type == "latent":
|
811 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
812 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
813 |
+
else:
|
814 |
+
image = latents
|
815 |
+
has_nsfw_concept = None
|
816 |
+
|
817 |
+
if has_nsfw_concept is None:
|
818 |
+
do_denormalize = [True] * image.shape[0]
|
819 |
+
else:
|
820 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
821 |
+
|
822 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
823 |
+
|
824 |
+
# Offload all models
|
825 |
+
self.maybe_free_model_hooks()
|
826 |
+
|
827 |
+
if not return_dict:
|
828 |
+
return (image, has_nsfw_concept)
|
829 |
+
|
830 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
hydit/inference.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import time
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
|
8 |
+
# For reproducibility
|
9 |
+
# torch.backends.cudnn.benchmark = False
|
10 |
+
# torch.backends.cudnn.deterministic = True
|
11 |
+
|
12 |
+
from diffusers import schedulers
|
13 |
+
from diffusers.models import AutoencoderKL
|
14 |
+
from loguru import logger
|
15 |
+
from transformers import BertModel, BertTokenizer
|
16 |
+
from transformers.modeling_utils import logger as tf_logger
|
17 |
+
|
18 |
+
from .constants import SAMPLER_FACTORY, NEGATIVE_PROMPT
|
19 |
+
from .diffusion.pipeline import StableDiffusionPipeline
|
20 |
+
from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
|
21 |
+
from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
|
22 |
+
from .modules.text_encoder import MT5Embedder
|
23 |
+
from .utils.tools import set_seeds
|
24 |
+
|
25 |
+
|
26 |
+
class Resolution:
|
27 |
+
def __init__(self, width, height):
|
28 |
+
self.width = width
|
29 |
+
self.height = height
|
30 |
+
|
31 |
+
def __str__(self):
|
32 |
+
return f'{self.height}x{self.width}'
|
33 |
+
|
34 |
+
|
35 |
+
class ResolutionGroup:
|
36 |
+
def __init__(self):
|
37 |
+
self.data = [
|
38 |
+
Resolution(768, 768), # 1:1
|
39 |
+
Resolution(1024, 1024), # 1:1
|
40 |
+
Resolution(1280, 1280), # 1:1
|
41 |
+
Resolution(1024, 768), # 4:3
|
42 |
+
Resolution(1152, 864), # 4:3
|
43 |
+
Resolution(1280, 960), # 4:3
|
44 |
+
Resolution(768, 1024), # 3:4
|
45 |
+
Resolution(864, 1152), # 3:4
|
46 |
+
Resolution(960, 1280), # 3:4
|
47 |
+
Resolution(1280, 768), # 16:9
|
48 |
+
Resolution(768, 1280), # 9:16
|
49 |
+
]
|
50 |
+
self.supported_sizes = set([(r.width, r.height) for r in self.data])
|
51 |
+
|
52 |
+
def is_valid(self, width, height):
|
53 |
+
return (width, height) in self.supported_sizes
|
54 |
+
|
55 |
+
|
56 |
+
STANDARD_RATIO = np.array([
|
57 |
+
1.0, # 1:1
|
58 |
+
4.0 / 3.0, # 4:3
|
59 |
+
3.0 / 4.0, # 3:4
|
60 |
+
16.0 / 9.0, # 16:9
|
61 |
+
9.0 / 16.0, # 9:16
|
62 |
+
])
|
63 |
+
STANDARD_SHAPE = [
|
64 |
+
[(768, 768), (1024, 1024), (1280, 1280)], # 1:1
|
65 |
+
[(1024, 768), (1152, 864), (1280, 960)], # 4:3
|
66 |
+
[(768, 1024), (864, 1152), (960, 1280)], # 3:4
|
67 |
+
[(1280, 768)], # 16:9
|
68 |
+
[(768, 1280)], # 9:16
|
69 |
+
]
|
70 |
+
STANDARD_AREA = [
|
71 |
+
np.array([w * h for w, h in shapes])
|
72 |
+
for shapes in STANDARD_SHAPE
|
73 |
+
]
|
74 |
+
|
75 |
+
|
76 |
+
def get_standard_shape(target_width, target_height):
|
77 |
+
"""
|
78 |
+
Map image size to standard size.
|
79 |
+
"""
|
80 |
+
target_ratio = target_width / target_height
|
81 |
+
closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
|
82 |
+
closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
|
83 |
+
width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
|
84 |
+
return width, height
|
85 |
+
|
86 |
+
|
87 |
+
def _to_tuple(val):
|
88 |
+
if isinstance(val, (list, tuple)):
|
89 |
+
if len(val) == 1:
|
90 |
+
val = [val[0], val[0]]
|
91 |
+
elif len(val) == 2:
|
92 |
+
val = tuple(val)
|
93 |
+
else:
|
94 |
+
raise ValueError(f"Invalid value: {val}")
|
95 |
+
elif isinstance(val, (int, float)):
|
96 |
+
val = (val, val)
|
97 |
+
else:
|
98 |
+
raise ValueError(f"Invalid value: {val}")
|
99 |
+
return val
|
100 |
+
|
101 |
+
|
102 |
+
def get_pipeline(args, vae, text_encoder, tokenizer, model, device, rank,
|
103 |
+
embedder_t5, infer_mode, sampler=None):
|
104 |
+
"""
|
105 |
+
Get scheduler and pipeline for sampling. The sampler and pipeline are both
|
106 |
+
based on diffusers and make some modifications.
|
107 |
+
|
108 |
+
Returns
|
109 |
+
-------
|
110 |
+
pipeline: StableDiffusionPipeline
|
111 |
+
sampler_name: str
|
112 |
+
"""
|
113 |
+
sampler = sampler or args.sampler
|
114 |
+
|
115 |
+
# Load sampler from factory
|
116 |
+
kwargs = SAMPLER_FACTORY[sampler]['kwargs']
|
117 |
+
scheduler = SAMPLER_FACTORY[sampler]['scheduler']
|
118 |
+
|
119 |
+
# Update sampler according to the arguments
|
120 |
+
kwargs['beta_schedule'] = args.noise_schedule
|
121 |
+
kwargs['beta_start'] = args.beta_start
|
122 |
+
kwargs['beta_end'] = args.beta_end
|
123 |
+
kwargs['prediction_type'] = args.predict_type
|
124 |
+
|
125 |
+
# Build scheduler according to the sampler.
|
126 |
+
scheduler_class = getattr(schedulers, scheduler)
|
127 |
+
scheduler = scheduler_class(**kwargs)
|
128 |
+
|
129 |
+
# Set timesteps for inference steps.
|
130 |
+
scheduler.set_timesteps(args.infer_steps, device)
|
131 |
+
|
132 |
+
# Only enable progress bar for rank 0
|
133 |
+
progress_bar_config = {} if rank == 0 else {'disable': True}
|
134 |
+
|
135 |
+
pipeline = StableDiffusionPipeline(vae=vae,
|
136 |
+
text_encoder=text_encoder,
|
137 |
+
tokenizer=tokenizer,
|
138 |
+
unet=model,
|
139 |
+
scheduler=scheduler,
|
140 |
+
feature_extractor=None,
|
141 |
+
safety_checker=None,
|
142 |
+
requires_safety_checker=False,
|
143 |
+
progress_bar_config=progress_bar_config,
|
144 |
+
embedder_t5=embedder_t5,
|
145 |
+
infer_mode=infer_mode,
|
146 |
+
)
|
147 |
+
|
148 |
+
pipeline = pipeline.to(device)
|
149 |
+
|
150 |
+
return pipeline, sampler
|
151 |
+
|
152 |
+
|
153 |
+
class End2End(object):
|
154 |
+
def __init__(self, args, models_root_path):
|
155 |
+
self.args = args
|
156 |
+
|
157 |
+
# Check arguments
|
158 |
+
t2i_root_path = Path(models_root_path) / "t2i"
|
159 |
+
self.root = t2i_root_path
|
160 |
+
logger.info(f"Got text-to-image model root path: {t2i_root_path}")
|
161 |
+
|
162 |
+
# Set device and disable gradient
|
163 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
164 |
+
torch.set_grad_enabled(False)
|
165 |
+
# Disable BertModel logging checkpoint info
|
166 |
+
tf_logger.setLevel('ERROR')
|
167 |
+
|
168 |
+
# ========================================================================
|
169 |
+
model_dir = self.root / "model"
|
170 |
+
|
171 |
+
# ========================================================================
|
172 |
+
logger.info(f"Loading CLIP Text Encoder...")
|
173 |
+
text_encoder_path = self.root / "clip_text_encoder"
|
174 |
+
self.clip_text_encoder = BertModel.from_pretrained(str(text_encoder_path), False, revision=None).to(self.device)
|
175 |
+
logger.info(f"Loading CLIP Text Encoder finished")
|
176 |
+
|
177 |
+
# ========================================================================
|
178 |
+
logger.info(f"Loading CLIP Tokenizer...")
|
179 |
+
tokenizer_path = self.root / "tokenizer"
|
180 |
+
self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
|
181 |
+
logger.info(f"Loading CLIP Tokenizer finished")
|
182 |
+
|
183 |
+
# ========================================================================
|
184 |
+
logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...")
|
185 |
+
t5_text_encoder_path = self.root / 'mt5'
|
186 |
+
embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
|
187 |
+
self.embedder_t5 = embedder_t5
|
188 |
+
logger.info(f"Loading t5_text_encoder and t5_tokenizer finished")
|
189 |
+
|
190 |
+
# ========================================================================
|
191 |
+
logger.info(f"Loading VAE...")
|
192 |
+
vae_path = self.root / "sdxl-vae-fp16-fix"
|
193 |
+
self.vae = AutoencoderKL.from_pretrained(str(vae_path)).to(self.device)
|
194 |
+
logger.info(f"Loading VAE finished")
|
195 |
+
|
196 |
+
# ========================================================================
|
197 |
+
# Create model structure and load the checkpoint
|
198 |
+
logger.info(f"Building HunYuan-DiT model...")
|
199 |
+
model_config = HUNYUAN_DIT_CONFIG[self.args.model]
|
200 |
+
self.patch_size = model_config['patch_size']
|
201 |
+
self.head_size = model_config['hidden_size'] // model_config['num_heads']
|
202 |
+
self.resolutions, self.freqs_cis_img = self.standard_shapes() # Used for TensorRT models
|
203 |
+
self.image_size = _to_tuple(self.args.image_size)
|
204 |
+
latent_size = (self.image_size[0] // 8, self.image_size[1] // 8)
|
205 |
+
|
206 |
+
self.infer_mode = self.args.infer_mode
|
207 |
+
if self.infer_mode in ['fa', 'torch']:
|
208 |
+
model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt"
|
209 |
+
if not model_path.exists():
|
210 |
+
raise ValueError(f"model_path not exists: {model_path}")
|
211 |
+
# Build model structure
|
212 |
+
self.model = HunYuanDiT(self.args,
|
213 |
+
input_size=latent_size,
|
214 |
+
**model_config,
|
215 |
+
log_fn=logger.info,
|
216 |
+
).half().to(self.device) # Force to use fp16
|
217 |
+
# Load model checkpoint
|
218 |
+
logger.info(f"Loading model checkpoint {model_path}...")
|
219 |
+
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
|
220 |
+
self.model.load_state_dict(state_dict)
|
221 |
+
self.model.eval()
|
222 |
+
elif self.infer_mode == 'trt':
|
223 |
+
raise NotImplementedError("TensorRT model is not supported yet.")
|
224 |
+
else:
|
225 |
+
raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
|
226 |
+
|
227 |
+
# ========================================================================
|
228 |
+
# Build inference pipeline. We use a customized StableDiffusionPipeline.
|
229 |
+
logger.info(f"Loading inference pipeline...")
|
230 |
+
self.pipeline, self.sampler = self.load_sampler()
|
231 |
+
logger.info(f'Loading pipeline finished')
|
232 |
+
|
233 |
+
# ========================================================================
|
234 |
+
self.default_negative_prompt = NEGATIVE_PROMPT
|
235 |
+
logger.info("==================================================")
|
236 |
+
logger.info(f" Model is ready. ")
|
237 |
+
logger.info("==================================================")
|
238 |
+
|
239 |
+
def load_sampler(self, sampler=None):
|
240 |
+
pipeline, sampler = get_pipeline(self.args,
|
241 |
+
self.vae,
|
242 |
+
self.clip_text_encoder,
|
243 |
+
self.tokenizer,
|
244 |
+
self.model,
|
245 |
+
device=self.device,
|
246 |
+
rank=0,
|
247 |
+
embedder_t5=self.embedder_t5,
|
248 |
+
infer_mode=self.infer_mode,
|
249 |
+
sampler=sampler,
|
250 |
+
)
|
251 |
+
return pipeline, sampler
|
252 |
+
|
253 |
+
def calc_rope(self, height, width):
|
254 |
+
th = height // 8 // self.patch_size
|
255 |
+
tw = width // 8 // self.patch_size
|
256 |
+
base_size = 512 // 8 // self.patch_size
|
257 |
+
start, stop = get_fill_resize_and_crop((th, tw), base_size)
|
258 |
+
sub_args = [start, stop, (th, tw)]
|
259 |
+
rope = get_2d_rotary_pos_embed(self.head_size, *sub_args)
|
260 |
+
return rope
|
261 |
+
|
262 |
+
def standard_shapes(self):
|
263 |
+
resolutions = ResolutionGroup()
|
264 |
+
freqs_cis_img = {}
|
265 |
+
for reso in resolutions.data:
|
266 |
+
freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width)
|
267 |
+
return resolutions, freqs_cis_img
|
268 |
+
|
269 |
+
def predict(self,
|
270 |
+
user_prompt,
|
271 |
+
height=1024,
|
272 |
+
width=1024,
|
273 |
+
seed=None,
|
274 |
+
enhanced_prompt=None,
|
275 |
+
negative_prompt=None,
|
276 |
+
infer_steps=100,
|
277 |
+
guidance_scale=6,
|
278 |
+
batch_size=1,
|
279 |
+
src_size_cond=(1024, 1024),
|
280 |
+
sampler=None,
|
281 |
+
):
|
282 |
+
# ========================================================================
|
283 |
+
# Arguments: seed
|
284 |
+
# ========================================================================
|
285 |
+
if seed is None:
|
286 |
+
seed = random.randint(0, 1_000_000)
|
287 |
+
if not isinstance(seed, int):
|
288 |
+
raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
|
289 |
+
generator = set_seeds(seed)
|
290 |
+
|
291 |
+
# ========================================================================
|
292 |
+
# Arguments: target_width, target_height
|
293 |
+
# ========================================================================
|
294 |
+
if width <= 0 or height <= 0:
|
295 |
+
raise ValueError(f"`height` and `width` must be positive integers, got height={height}, width={width}")
|
296 |
+
logger.info(f"Input (height, width) = ({height}, {width})")
|
297 |
+
if self.infer_mode in ['fa', 'torch']:
|
298 |
+
# We must force height and width to align to 16 and to be an integer.
|
299 |
+
target_height = int((height // 16) * 16)
|
300 |
+
target_width = int((width // 16) * 16)
|
301 |
+
logger.info(f"Align to 16: (height, width) = ({target_height}, {target_width})")
|
302 |
+
elif self.infer_mode == 'trt':
|
303 |
+
target_width, target_height = get_standard_shape(width, height)
|
304 |
+
logger.info(f"Align to standard shape: (height, width) = ({target_height}, {target_width})")
|
305 |
+
else:
|
306 |
+
raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
|
307 |
+
|
308 |
+
# ========================================================================
|
309 |
+
# Arguments: prompt, new_prompt, negative_prompt
|
310 |
+
# ========================================================================
|
311 |
+
if not isinstance(user_prompt, str):
|
312 |
+
raise TypeError(f"`user_prompt` must be a string, but got {type(user_prompt)}")
|
313 |
+
user_prompt = user_prompt.strip()
|
314 |
+
prompt = user_prompt
|
315 |
+
|
316 |
+
if enhanced_prompt is not None:
|
317 |
+
if not isinstance(enhanced_prompt, str):
|
318 |
+
raise TypeError(f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}")
|
319 |
+
enhanced_prompt = enhanced_prompt.strip()
|
320 |
+
prompt = enhanced_prompt
|
321 |
+
|
322 |
+
# negative prompt
|
323 |
+
if negative_prompt is None or negative_prompt == '':
|
324 |
+
negative_prompt = self.default_negative_prompt
|
325 |
+
if not isinstance(negative_prompt, str):
|
326 |
+
raise TypeError(f"`negative_prompt` must be a string, but got {type(negative_prompt)}")
|
327 |
+
|
328 |
+
# ========================================================================
|
329 |
+
# Arguments: style. (A fixed argument. Don't Change it.)
|
330 |
+
# ========================================================================
|
331 |
+
style = torch.as_tensor([0, 0] * batch_size, device=self.device)
|
332 |
+
|
333 |
+
# ========================================================================
|
334 |
+
# Inner arguments: image_meta_size (Please refer to SDXL.)
|
335 |
+
# ========================================================================
|
336 |
+
if isinstance(src_size_cond, int):
|
337 |
+
src_size_cond = [src_size_cond, src_size_cond]
|
338 |
+
if not isinstance(src_size_cond, (list, tuple)):
|
339 |
+
raise TypeError(f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}")
|
340 |
+
if len(src_size_cond) != 2:
|
341 |
+
raise ValueError(f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)}")
|
342 |
+
size_cond = list(src_size_cond) + [target_width, target_height, 0, 0]
|
343 |
+
image_meta_size = torch.as_tensor([size_cond] * 2 * batch_size, device=self.device)
|
344 |
+
|
345 |
+
# ========================================================================
|
346 |
+
start_time = time.time()
|
347 |
+
logger.debug(f"""
|
348 |
+
prompt: {user_prompt}
|
349 |
+
enhanced prompt: {enhanced_prompt}
|
350 |
+
seed: {seed}
|
351 |
+
(height, width): {(target_height, target_width)}
|
352 |
+
negative_prompt: {negative_prompt}
|
353 |
+
batch_size: {batch_size}
|
354 |
+
guidance_scale: {guidance_scale}
|
355 |
+
infer_steps: {infer_steps}
|
356 |
+
image_meta_size: {size_cond}
|
357 |
+
""")
|
358 |
+
reso = f'{target_height}x{target_width}'
|
359 |
+
if reso in self.freqs_cis_img:
|
360 |
+
freqs_cis_img = self.freqs_cis_img[reso]
|
361 |
+
else:
|
362 |
+
freqs_cis_img = self.calc_rope(target_height, target_width)
|
363 |
+
|
364 |
+
if sampler is not None and sampler != self.sampler:
|
365 |
+
self.pipeline, self.sampler = self.load_sampler(sampler)
|
366 |
+
|
367 |
+
samples = self.pipeline(
|
368 |
+
height=target_height,
|
369 |
+
width=target_width,
|
370 |
+
prompt=prompt,
|
371 |
+
negative_prompt=negative_prompt,
|
372 |
+
num_images_per_prompt=batch_size,
|
373 |
+
guidance_scale=guidance_scale,
|
374 |
+
num_inference_steps=infer_steps,
|
375 |
+
image_meta_size=image_meta_size,
|
376 |
+
style=style,
|
377 |
+
return_dict=False,
|
378 |
+
generator=generator,
|
379 |
+
freqs_cis_img=freqs_cis_img,
|
380 |
+
use_fp16=self.args.use_fp16,
|
381 |
+
learn_sigma=self.args.learn_sigma,
|
382 |
+
)[0]
|
383 |
+
gen_time = time.time() - start_time
|
384 |
+
logger.debug(f"Success, time: {gen_time}")
|
385 |
+
|
386 |
+
return {
|
387 |
+
'images': samples,
|
388 |
+
'seed': seed,
|
389 |
+
}
|
hydit/modules/__init__.py
ADDED
File without changes
|
hydit/modules/attn_layers.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from typing import Tuple, Union, Optional
|
4 |
+
|
5 |
+
try:
|
6 |
+
import flash_attn
|
7 |
+
if hasattr(flash_attn, '__version__') and int(flash_attn.__version__[0]) == 2:
|
8 |
+
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
|
9 |
+
from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention
|
10 |
+
else:
|
11 |
+
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
|
12 |
+
from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention
|
13 |
+
except Exception as e:
|
14 |
+
print(f'flash_attn import failed: {e}')
|
15 |
+
|
16 |
+
|
17 |
+
def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
|
18 |
+
"""
|
19 |
+
Reshape frequency tensor for broadcasting it with another tensor.
|
20 |
+
|
21 |
+
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
|
22 |
+
for the purpose of broadcasting the frequency tensor during element-wise operations.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
|
26 |
+
x (torch.Tensor): Target tensor for broadcasting compatibility.
|
27 |
+
head_first (bool): head dimension first (except batch dim) or not.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
torch.Tensor: Reshaped frequency tensor.
|
31 |
+
|
32 |
+
Raises:
|
33 |
+
AssertionError: If the frequency tensor doesn't match the expected shape.
|
34 |
+
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
|
35 |
+
"""
|
36 |
+
ndim = x.ndim
|
37 |
+
assert 0 <= 1 < ndim
|
38 |
+
|
39 |
+
if isinstance(freqs_cis, tuple):
|
40 |
+
# freqs_cis: (cos, sin) in real space
|
41 |
+
if head_first:
|
42 |
+
assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
|
43 |
+
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
44 |
+
else:
|
45 |
+
assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
|
46 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
47 |
+
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
48 |
+
else:
|
49 |
+
# freqs_cis: values in complex space
|
50 |
+
if head_first:
|
51 |
+
assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
|
52 |
+
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
53 |
+
else:
|
54 |
+
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
|
55 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
56 |
+
return freqs_cis.view(*shape)
|
57 |
+
|
58 |
+
|
59 |
+
def rotate_half(x):
|
60 |
+
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
61 |
+
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
62 |
+
|
63 |
+
|
64 |
+
def apply_rotary_emb(
|
65 |
+
xq: torch.Tensor,
|
66 |
+
xk: Optional[torch.Tensor],
|
67 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
68 |
+
head_first: bool = False,
|
69 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
70 |
+
"""
|
71 |
+
Apply rotary embeddings to input tensors using the given frequency tensor.
|
72 |
+
|
73 |
+
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
|
74 |
+
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
|
75 |
+
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
|
76 |
+
returned as real tensors.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
|
80 |
+
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
|
81 |
+
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
|
82 |
+
head_first (bool): head dimension first (except batch dim) or not.
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
86 |
+
|
87 |
+
"""
|
88 |
+
xk_out = None
|
89 |
+
if isinstance(freqs_cis, tuple):
|
90 |
+
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
91 |
+
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
92 |
+
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
93 |
+
if xk is not None:
|
94 |
+
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
95 |
+
else:
|
96 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
97 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
|
98 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
99 |
+
if xk is not None:
|
100 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
101 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
102 |
+
|
103 |
+
return xq_out, xk_out
|
104 |
+
|
105 |
+
|
106 |
+
class FlashSelfMHAModified(nn.Module):
|
107 |
+
"""
|
108 |
+
Use QK Normalization.
|
109 |
+
"""
|
110 |
+
def __init__(self,
|
111 |
+
dim,
|
112 |
+
num_heads,
|
113 |
+
qkv_bias=True,
|
114 |
+
qk_norm=False,
|
115 |
+
attn_drop=0.0,
|
116 |
+
proj_drop=0.0,
|
117 |
+
device=None,
|
118 |
+
dtype=None,
|
119 |
+
norm_layer=nn.LayerNorm,
|
120 |
+
):
|
121 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
122 |
+
super().__init__()
|
123 |
+
self.dim = dim
|
124 |
+
self.num_heads = num_heads
|
125 |
+
assert self.dim % num_heads == 0, "self.kdim must be divisible by num_heads"
|
126 |
+
self.head_dim = self.dim // num_heads
|
127 |
+
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
|
128 |
+
|
129 |
+
self.Wqkv = nn.Linear(dim, 3 * dim, bias=qkv_bias, **factory_kwargs)
|
130 |
+
# TODO: eps should be 1 / 65530 if using fp16
|
131 |
+
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
132 |
+
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
133 |
+
self.inner_attn = FlashSelfAttention(attention_dropout=attn_drop)
|
134 |
+
self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
|
135 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
136 |
+
|
137 |
+
def forward(self, x, freqs_cis_img=None):
|
138 |
+
"""
|
139 |
+
Parameters
|
140 |
+
----------
|
141 |
+
x: torch.Tensor
|
142 |
+
(batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
|
143 |
+
freqs_cis_img: torch.Tensor
|
144 |
+
(batch, hidden_dim // 2), RoPE for image
|
145 |
+
"""
|
146 |
+
b, s, d = x.shape
|
147 |
+
|
148 |
+
qkv = self.Wqkv(x)
|
149 |
+
qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, h, d]
|
150 |
+
q, k, v = qkv.unbind(dim=2) # [b, s, h, d]
|
151 |
+
q = self.q_norm(q).half() # [b, s, h, d]
|
152 |
+
k = self.k_norm(k).half()
|
153 |
+
|
154 |
+
# Apply RoPE if needed
|
155 |
+
if freqs_cis_img is not None:
|
156 |
+
qq, kk = apply_rotary_emb(q, k, freqs_cis_img)
|
157 |
+
assert qq.shape == q.shape and kk.shape == k.shape, f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
|
158 |
+
q, k = qq, kk
|
159 |
+
|
160 |
+
qkv = torch.stack([q, k, v], dim=2) # [b, s, 3, h, d]
|
161 |
+
context = self.inner_attn(qkv)
|
162 |
+
out = self.out_proj(context.view(b, s, d))
|
163 |
+
out = self.proj_drop(out)
|
164 |
+
|
165 |
+
out_tuple = (out,)
|
166 |
+
|
167 |
+
return out_tuple
|
168 |
+
|
169 |
+
|
170 |
+
class FlashCrossMHAModified(nn.Module):
|
171 |
+
"""
|
172 |
+
Use QK Normalization.
|
173 |
+
"""
|
174 |
+
def __init__(self,
|
175 |
+
qdim,
|
176 |
+
kdim,
|
177 |
+
num_heads,
|
178 |
+
qkv_bias=True,
|
179 |
+
qk_norm=False,
|
180 |
+
attn_drop=0.0,
|
181 |
+
proj_drop=0.0,
|
182 |
+
device=None,
|
183 |
+
dtype=None,
|
184 |
+
norm_layer=nn.LayerNorm,
|
185 |
+
):
|
186 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
187 |
+
super().__init__()
|
188 |
+
self.qdim = qdim
|
189 |
+
self.kdim = kdim
|
190 |
+
self.num_heads = num_heads
|
191 |
+
assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
|
192 |
+
self.head_dim = self.qdim // num_heads
|
193 |
+
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
|
194 |
+
|
195 |
+
self.scale = self.head_dim ** -0.5
|
196 |
+
|
197 |
+
self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
|
198 |
+
self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
|
199 |
+
|
200 |
+
# TODO: eps should be 1 / 65530 if using fp16
|
201 |
+
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
202 |
+
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
203 |
+
|
204 |
+
self.inner_attn = FlashCrossAttention(attention_dropout=attn_drop)
|
205 |
+
self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
|
206 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
207 |
+
|
208 |
+
def forward(self, x, y, freqs_cis_img=None):
|
209 |
+
"""
|
210 |
+
Parameters
|
211 |
+
----------
|
212 |
+
x: torch.Tensor
|
213 |
+
(batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim)
|
214 |
+
y: torch.Tensor
|
215 |
+
(batch, seqlen2, hidden_dim2)
|
216 |
+
freqs_cis_img: torch.Tensor
|
217 |
+
(batch, hidden_dim // num_heads), RoPE for image
|
218 |
+
"""
|
219 |
+
b, s1, _ = x.shape # [b, s1, D]
|
220 |
+
_, s2, _ = y.shape # [b, s2, 1024]
|
221 |
+
|
222 |
+
q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
|
223 |
+
kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d]
|
224 |
+
k, v = kv.unbind(dim=2) # [b, s2, h, d]
|
225 |
+
q = self.q_norm(q).half() # [b, s1, h, d]
|
226 |
+
k = self.k_norm(k).half() # [b, s2, h, d]
|
227 |
+
|
228 |
+
# Apply RoPE if needed
|
229 |
+
if freqs_cis_img is not None:
|
230 |
+
qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
|
231 |
+
assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
|
232 |
+
q = qq # [b, s1, h, d]
|
233 |
+
kv = torch.stack([k, v], dim=2) # [b, s1, 2, h, d]
|
234 |
+
context = self.inner_attn(q, kv) # [b, s1, h, d]
|
235 |
+
context = context.view(b, s1, -1) # [b, s1, D]
|
236 |
+
|
237 |
+
out = self.out_proj(context)
|
238 |
+
out = self.proj_drop(out)
|
239 |
+
|
240 |
+
out_tuple = (out,)
|
241 |
+
|
242 |
+
return out_tuple
|
243 |
+
|
244 |
+
|
245 |
+
class CrossAttention(nn.Module):
|
246 |
+
"""
|
247 |
+
Use QK Normalization.
|
248 |
+
"""
|
249 |
+
def __init__(self,
|
250 |
+
qdim,
|
251 |
+
kdim,
|
252 |
+
num_heads,
|
253 |
+
qkv_bias=True,
|
254 |
+
qk_norm=False,
|
255 |
+
attn_drop=0.0,
|
256 |
+
proj_drop=0.0,
|
257 |
+
device=None,
|
258 |
+
dtype=None,
|
259 |
+
norm_layer=nn.LayerNorm,
|
260 |
+
):
|
261 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
262 |
+
super().__init__()
|
263 |
+
self.qdim = qdim
|
264 |
+
self.kdim = kdim
|
265 |
+
self.num_heads = num_heads
|
266 |
+
assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
|
267 |
+
self.head_dim = self.qdim // num_heads
|
268 |
+
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
|
269 |
+
self.scale = self.head_dim ** -0.5
|
270 |
+
|
271 |
+
self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
|
272 |
+
self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
|
273 |
+
|
274 |
+
# TODO: eps should be 1 / 65530 if using fp16
|
275 |
+
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
276 |
+
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
277 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
278 |
+
self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
|
279 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
280 |
+
|
281 |
+
def forward(self, x, y, freqs_cis_img=None):
|
282 |
+
"""
|
283 |
+
Parameters
|
284 |
+
----------
|
285 |
+
x: torch.Tensor
|
286 |
+
(batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
|
287 |
+
y: torch.Tensor
|
288 |
+
(batch, seqlen2, hidden_dim2)
|
289 |
+
freqs_cis_img: torch.Tensor
|
290 |
+
(batch, hidden_dim // 2), RoPE for image
|
291 |
+
"""
|
292 |
+
b, s1, c = x.shape # [b, s1, D]
|
293 |
+
_, s2, c = y.shape # [b, s2, 1024]
|
294 |
+
|
295 |
+
q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
|
296 |
+
kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d]
|
297 |
+
k, v = kv.unbind(dim=2) # [b, s, h, d]
|
298 |
+
q = self.q_norm(q)
|
299 |
+
k = self.k_norm(k)
|
300 |
+
|
301 |
+
# Apply RoPE if needed
|
302 |
+
if freqs_cis_img is not None:
|
303 |
+
qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
|
304 |
+
assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
|
305 |
+
q = qq
|
306 |
+
|
307 |
+
q = q * self.scale
|
308 |
+
q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
|
309 |
+
k = k.permute(0, 2, 3, 1).contiguous() # k -> B, L2, H, C - B, H, C, L2
|
310 |
+
attn = q @ k # attn -> B, H, L1, L2
|
311 |
+
attn = attn.softmax(dim=-1) # attn -> B, H, L1, L2
|
312 |
+
attn = self.attn_drop(attn)
|
313 |
+
x = attn @ v.transpose(-2, -3) # v -> B, L2, H, C - B, H, L2, C x-> B, H, L1, C
|
314 |
+
context = x.transpose(1, 2) # context -> B, H, L1, C - B, L1, H, C
|
315 |
+
|
316 |
+
context = context.contiguous().view(b, s1, -1)
|
317 |
+
|
318 |
+
out = self.out_proj(context) # context.reshape - B, L1, -1
|
319 |
+
out = self.proj_drop(out)
|
320 |
+
|
321 |
+
out_tuple = (out,)
|
322 |
+
|
323 |
+
return out_tuple
|
324 |
+
|
325 |
+
|
326 |
+
class Attention(nn.Module):
|
327 |
+
"""
|
328 |
+
We rename some layer names to align with flash attention
|
329 |
+
"""
|
330 |
+
def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0.,
|
331 |
+
norm_layer=nn.LayerNorm,
|
332 |
+
):
|
333 |
+
super().__init__()
|
334 |
+
self.dim = dim
|
335 |
+
self.num_heads = num_heads
|
336 |
+
assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
|
337 |
+
self.head_dim = self.dim // num_heads
|
338 |
+
# This assertion is aligned with flash attention
|
339 |
+
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
|
340 |
+
self.scale = self.head_dim ** -0.5
|
341 |
+
|
342 |
+
# qkv --> Wqkv
|
343 |
+
self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
344 |
+
# TODO: eps should be 1 / 65530 if using fp16
|
345 |
+
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
346 |
+
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
347 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
348 |
+
self.out_proj = nn.Linear(dim, dim)
|
349 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
350 |
+
|
351 |
+
def forward(self, x, freqs_cis_img=None):
|
352 |
+
B, N, C = x.shape
|
353 |
+
qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [3, b, h, s, d]
|
354 |
+
q, k, v = qkv.unbind(0) # [b, h, s, d]
|
355 |
+
q = self.q_norm(q) # [b, h, s, d]
|
356 |
+
k = self.k_norm(k) # [b, h, s, d]
|
357 |
+
|
358 |
+
# Apply RoPE if needed
|
359 |
+
if freqs_cis_img is not None:
|
360 |
+
qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
|
361 |
+
assert qq.shape == q.shape and kk.shape == k.shape, \
|
362 |
+
f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
|
363 |
+
q, k = qq, kk
|
364 |
+
|
365 |
+
q = q * self.scale
|
366 |
+
attn = q @ k.transpose(-2, -1) # [b, h, s, d] @ [b, h, d, s]
|
367 |
+
attn = attn.softmax(dim=-1) # [b, h, s, s]
|
368 |
+
attn = self.attn_drop(attn)
|
369 |
+
x = attn @ v # [b, h, s, d]
|
370 |
+
|
371 |
+
x = x.transpose(1, 2).reshape(B, N, C) # [b, s, h, d]
|
372 |
+
x = self.out_proj(x)
|
373 |
+
x = self.proj_drop(x)
|
374 |
+
|
375 |
+
out_tuple = (x,)
|
376 |
+
|
377 |
+
return out_tuple
|
hydit/modules/embedders.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from einops import repeat
|
5 |
+
|
6 |
+
from timm.models.layers import to_2tuple
|
7 |
+
|
8 |
+
|
9 |
+
class PatchEmbed(nn.Module):
|
10 |
+
""" 2D Image to Patch Embedding
|
11 |
+
|
12 |
+
Image to Patch Embedding using Conv2d
|
13 |
+
|
14 |
+
A convolution based approach to patchifying a 2D image w/ embedding projection.
|
15 |
+
|
16 |
+
Based on the impl in https://github.com/google-research/vision_transformer
|
17 |
+
|
18 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
19 |
+
|
20 |
+
Remove the _assert function in forward function to be compatible with multi-resolution images.
|
21 |
+
"""
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
img_size=224,
|
25 |
+
patch_size=16,
|
26 |
+
in_chans=3,
|
27 |
+
embed_dim=768,
|
28 |
+
norm_layer=None,
|
29 |
+
flatten=True,
|
30 |
+
bias=True,
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
if isinstance(img_size, int):
|
34 |
+
img_size = to_2tuple(img_size)
|
35 |
+
elif isinstance(img_size, (tuple, list)) and len(img_size) == 2:
|
36 |
+
img_size = tuple(img_size)
|
37 |
+
else:
|
38 |
+
raise ValueError(f"img_size must be int or tuple/list of length 2. Got {img_size}")
|
39 |
+
patch_size = to_2tuple(patch_size)
|
40 |
+
self.img_size = img_size
|
41 |
+
self.patch_size = patch_size
|
42 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
43 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
44 |
+
self.flatten = flatten
|
45 |
+
|
46 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
47 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
48 |
+
|
49 |
+
def update_image_size(self, img_size):
|
50 |
+
self.img_size = img_size
|
51 |
+
self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1])
|
52 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
# B, C, H, W = x.shape
|
56 |
+
# _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
|
57 |
+
# _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
|
58 |
+
x = self.proj(x)
|
59 |
+
if self.flatten:
|
60 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
61 |
+
x = self.norm(x)
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
|
66 |
+
"""
|
67 |
+
Create sinusoidal timestep embeddings.
|
68 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
69 |
+
These may be fractional.
|
70 |
+
:param dim: the dimension of the output.
|
71 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
72 |
+
:return: an (N, D) Tensor of positional embeddings.
|
73 |
+
"""
|
74 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
75 |
+
if not repeat_only:
|
76 |
+
half = dim // 2
|
77 |
+
freqs = torch.exp(
|
78 |
+
-math.log(max_period)
|
79 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
80 |
+
/ half
|
81 |
+
).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线
|
82 |
+
args = t[:, None].float() * freqs[None]
|
83 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
84 |
+
if dim % 2:
|
85 |
+
embedding = torch.cat(
|
86 |
+
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
87 |
+
)
|
88 |
+
else:
|
89 |
+
embedding = repeat(t, "b -> b d", d=dim)
|
90 |
+
return embedding
|
91 |
+
|
92 |
+
|
93 |
+
class TimestepEmbedder(nn.Module):
|
94 |
+
"""
|
95 |
+
Embeds scalar timesteps into vector representations.
|
96 |
+
"""
|
97 |
+
def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None):
|
98 |
+
super().__init__()
|
99 |
+
if out_size is None:
|
100 |
+
out_size = hidden_size
|
101 |
+
self.mlp = nn.Sequential(
|
102 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
103 |
+
nn.SiLU(),
|
104 |
+
nn.Linear(hidden_size, out_size, bias=True),
|
105 |
+
)
|
106 |
+
self.frequency_embedding_size = frequency_embedding_size
|
107 |
+
|
108 |
+
def forward(self, t):
|
109 |
+
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
|
110 |
+
t_emb = self.mlp(t_freq)
|
111 |
+
return t_emb
|
hydit/modules/models.py
ADDED
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
5 |
+
from diffusers.models import ModelMixin
|
6 |
+
from timm.models.vision_transformer import Mlp
|
7 |
+
|
8 |
+
from .attn_layers import Attention, FlashCrossMHAModified, FlashSelfMHAModified, CrossAttention
|
9 |
+
from .embedders import TimestepEmbedder, PatchEmbed, timestep_embedding
|
10 |
+
from .norm_layers import RMSNorm
|
11 |
+
from .poolers import AttentionPool
|
12 |
+
|
13 |
+
|
14 |
+
def modulate(x, shift, scale):
|
15 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
16 |
+
|
17 |
+
|
18 |
+
class FP32_Layernorm(nn.LayerNorm):
|
19 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
20 |
+
origin_dtype = inputs.dtype
|
21 |
+
return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(),
|
22 |
+
self.eps).to(origin_dtype)
|
23 |
+
|
24 |
+
|
25 |
+
class FP32_SiLU(nn.SiLU):
|
26 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
27 |
+
return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype)
|
28 |
+
|
29 |
+
|
30 |
+
class HunYuanDiTBlock(nn.Module):
|
31 |
+
"""
|
32 |
+
A HunYuanDiT block with `add` conditioning.
|
33 |
+
"""
|
34 |
+
def __init__(self,
|
35 |
+
hidden_size,
|
36 |
+
c_emb_size,
|
37 |
+
num_heads,
|
38 |
+
mlp_ratio=4.0,
|
39 |
+
text_states_dim=1024,
|
40 |
+
use_flash_attn=False,
|
41 |
+
qk_norm=False,
|
42 |
+
norm_type="layer",
|
43 |
+
skip=False,
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
self.use_flash_attn = use_flash_attn
|
47 |
+
use_ele_affine = True
|
48 |
+
|
49 |
+
if norm_type == "layer":
|
50 |
+
norm_layer = FP32_Layernorm
|
51 |
+
elif norm_type == "rms":
|
52 |
+
norm_layer = RMSNorm
|
53 |
+
else:
|
54 |
+
raise ValueError(f"Unknown norm_type: {norm_type}")
|
55 |
+
|
56 |
+
# ========================= Self-Attention =========================
|
57 |
+
self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
|
58 |
+
if use_flash_attn:
|
59 |
+
self.attn1 = FlashSelfMHAModified(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm)
|
60 |
+
else:
|
61 |
+
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm)
|
62 |
+
|
63 |
+
# ========================= FFN =========================
|
64 |
+
self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
|
65 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
66 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
67 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
68 |
+
|
69 |
+
# ========================= Add =========================
|
70 |
+
# Simply use add like SDXL.
|
71 |
+
self.default_modulation = nn.Sequential(
|
72 |
+
FP32_SiLU(),
|
73 |
+
nn.Linear(c_emb_size, hidden_size, bias=True)
|
74 |
+
)
|
75 |
+
|
76 |
+
# ========================= Cross-Attention =========================
|
77 |
+
if use_flash_attn:
|
78 |
+
self.attn2 = FlashCrossMHAModified(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
|
79 |
+
qk_norm=qk_norm)
|
80 |
+
else:
|
81 |
+
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
|
82 |
+
qk_norm=qk_norm)
|
83 |
+
self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
|
84 |
+
|
85 |
+
# ========================= Skip Connection =========================
|
86 |
+
if skip:
|
87 |
+
self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6)
|
88 |
+
self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
|
89 |
+
else:
|
90 |
+
self.skip_linear = None
|
91 |
+
|
92 |
+
def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
|
93 |
+
# Long Skip Connection
|
94 |
+
if self.skip_linear is not None:
|
95 |
+
cat = torch.cat([x, skip], dim=-1)
|
96 |
+
cat = self.skip_norm(cat)
|
97 |
+
x = self.skip_linear(cat)
|
98 |
+
|
99 |
+
# Self-Attention
|
100 |
+
shift_msa = self.default_modulation(c).unsqueeze(dim=1)
|
101 |
+
attn_inputs = (
|
102 |
+
self.norm1(x) + shift_msa, freq_cis_img,
|
103 |
+
)
|
104 |
+
x = x + self.attn1(*attn_inputs)[0]
|
105 |
+
|
106 |
+
# Cross-Attention
|
107 |
+
cross_inputs = (
|
108 |
+
self.norm3(x), text_states, freq_cis_img
|
109 |
+
)
|
110 |
+
x = x + self.attn2(*cross_inputs)[0]
|
111 |
+
|
112 |
+
# FFN Layer
|
113 |
+
mlp_inputs = self.norm2(x)
|
114 |
+
x = x + self.mlp(mlp_inputs)
|
115 |
+
|
116 |
+
return x
|
117 |
+
|
118 |
+
|
119 |
+
class FinalLayer(nn.Module):
|
120 |
+
"""
|
121 |
+
The final layer of HunYuanDiT.
|
122 |
+
"""
|
123 |
+
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
|
124 |
+
super().__init__()
|
125 |
+
self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
|
126 |
+
self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
|
127 |
+
self.adaLN_modulation = nn.Sequential(
|
128 |
+
FP32_SiLU(),
|
129 |
+
nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
|
130 |
+
)
|
131 |
+
|
132 |
+
def forward(self, x, c):
|
133 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
134 |
+
x = modulate(self.norm_final(x), shift, scale)
|
135 |
+
x = self.linear(x)
|
136 |
+
return x
|
137 |
+
|
138 |
+
|
139 |
+
class HunYuanDiT(ModelMixin, ConfigMixin):
|
140 |
+
"""
|
141 |
+
HunYuanDiT: Diffusion model with a Transformer backbone.
|
142 |
+
|
143 |
+
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
|
144 |
+
|
145 |
+
Parameters
|
146 |
+
----------
|
147 |
+
args: argparse.Namespace
|
148 |
+
The arguments parsed by argparse.
|
149 |
+
input_size: tuple
|
150 |
+
The size of the input image.
|
151 |
+
patch_size: int
|
152 |
+
The size of the patch.
|
153 |
+
in_channels: int
|
154 |
+
The number of input channels.
|
155 |
+
hidden_size: int
|
156 |
+
The hidden size of the transformer backbone.
|
157 |
+
depth: int
|
158 |
+
The number of transformer blocks.
|
159 |
+
num_heads: int
|
160 |
+
The number of attention heads.
|
161 |
+
mlp_ratio: float
|
162 |
+
The ratio of the hidden size of the MLP in the transformer block.
|
163 |
+
log_fn: callable
|
164 |
+
The logging function.
|
165 |
+
"""
|
166 |
+
@register_to_config
|
167 |
+
def __init__(
|
168 |
+
self, args,
|
169 |
+
input_size=(32, 32),
|
170 |
+
patch_size=2,
|
171 |
+
in_channels=4,
|
172 |
+
hidden_size=1152,
|
173 |
+
depth=28,
|
174 |
+
num_heads=16,
|
175 |
+
mlp_ratio=4.0,
|
176 |
+
log_fn=print,
|
177 |
+
):
|
178 |
+
super().__init__()
|
179 |
+
self.args = args
|
180 |
+
self.log_fn = log_fn
|
181 |
+
self.depth = depth
|
182 |
+
self.learn_sigma = args.learn_sigma
|
183 |
+
self.in_channels = in_channels
|
184 |
+
self.out_channels = in_channels * 2 if args.learn_sigma else in_channels
|
185 |
+
self.patch_size = patch_size
|
186 |
+
self.num_heads = num_heads
|
187 |
+
self.hidden_size = hidden_size
|
188 |
+
self.text_states_dim = args.text_states_dim
|
189 |
+
self.text_states_dim_t5 = args.text_states_dim_t5
|
190 |
+
self.text_len = args.text_len
|
191 |
+
self.text_len_t5 = args.text_len_t5
|
192 |
+
self.norm = args.norm
|
193 |
+
|
194 |
+
use_flash_attn = args.infer_mode == 'fa'
|
195 |
+
if use_flash_attn:
|
196 |
+
log_fn(f" Enable Flash Attention.")
|
197 |
+
qk_norm = True # See http://arxiv.org/abs/2302.05442 for details.
|
198 |
+
|
199 |
+
self.mlp_t5 = nn.Sequential(
|
200 |
+
nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True),
|
201 |
+
FP32_SiLU(),
|
202 |
+
nn.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True),
|
203 |
+
)
|
204 |
+
# learnable replace
|
205 |
+
self.text_embedding_padding = nn.Parameter(
|
206 |
+
torch.randn(self.text_len + self.text_len_t5, self.text_states_dim, dtype=torch.float32))
|
207 |
+
|
208 |
+
# Attention pooling
|
209 |
+
self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=1024)
|
210 |
+
|
211 |
+
# Here we use a default learned embedder layer for future extension.
|
212 |
+
self.style_embedder = nn.Embedding(1, hidden_size)
|
213 |
+
|
214 |
+
# Image size and crop size conditions
|
215 |
+
self.extra_in_dim = 256 * 6 + hidden_size
|
216 |
+
|
217 |
+
# Text embedding for `add`
|
218 |
+
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size)
|
219 |
+
self.t_embedder = TimestepEmbedder(hidden_size)
|
220 |
+
self.extra_in_dim += 1024
|
221 |
+
self.extra_embedder = nn.Sequential(
|
222 |
+
nn.Linear(self.extra_in_dim, hidden_size * 4),
|
223 |
+
FP32_SiLU(),
|
224 |
+
nn.Linear(hidden_size * 4, hidden_size, bias=True),
|
225 |
+
)
|
226 |
+
|
227 |
+
# Image embedding
|
228 |
+
num_patches = self.x_embedder.num_patches
|
229 |
+
log_fn(f" Number of tokens: {num_patches}")
|
230 |
+
|
231 |
+
# HUnYuanDiT Blocks
|
232 |
+
self.blocks = nn.ModuleList([
|
233 |
+
HunYuanDiTBlock(hidden_size=hidden_size,
|
234 |
+
c_emb_size=hidden_size,
|
235 |
+
num_heads=num_heads,
|
236 |
+
mlp_ratio=mlp_ratio,
|
237 |
+
text_states_dim=self.text_states_dim,
|
238 |
+
use_flash_attn=use_flash_attn,
|
239 |
+
qk_norm=qk_norm,
|
240 |
+
norm_type=self.norm,
|
241 |
+
skip=layer > depth // 2,
|
242 |
+
)
|
243 |
+
for layer in range(depth)
|
244 |
+
])
|
245 |
+
|
246 |
+
self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels)
|
247 |
+
self.unpatchify_channels = self.out_channels
|
248 |
+
|
249 |
+
self.initialize_weights()
|
250 |
+
|
251 |
+
def forward(self,
|
252 |
+
x,
|
253 |
+
t,
|
254 |
+
encoder_hidden_states=None,
|
255 |
+
text_embedding_mask=None,
|
256 |
+
encoder_hidden_states_t5=None,
|
257 |
+
text_embedding_mask_t5=None,
|
258 |
+
image_meta_size=None,
|
259 |
+
style=None,
|
260 |
+
cos_cis_img=None,
|
261 |
+
sin_cis_img=None,
|
262 |
+
return_dict=True,
|
263 |
+
):
|
264 |
+
"""
|
265 |
+
Forward pass of the encoder.
|
266 |
+
|
267 |
+
Parameters
|
268 |
+
----------
|
269 |
+
x: torch.Tensor
|
270 |
+
(B, D, H, W)
|
271 |
+
t: torch.Tensor
|
272 |
+
(B)
|
273 |
+
encoder_hidden_states: torch.Tensor
|
274 |
+
CLIP text embedding, (B, L_clip, D)
|
275 |
+
text_embedding_mask: torch.Tensor
|
276 |
+
CLIP text embedding mask, (B, L_clip)
|
277 |
+
encoder_hidden_states_t5: torch.Tensor
|
278 |
+
T5 text embedding, (B, L_t5, D)
|
279 |
+
text_embedding_mask_t5: torch.Tensor
|
280 |
+
T5 text embedding mask, (B, L_t5)
|
281 |
+
image_meta_size: torch.Tensor
|
282 |
+
(B, 6)
|
283 |
+
style: torch.Tensor
|
284 |
+
(B)
|
285 |
+
cos_cis_img: torch.Tensor
|
286 |
+
sin_cis_img: torch.Tensor
|
287 |
+
return_dict: bool
|
288 |
+
Whether to return a dictionary.
|
289 |
+
"""
|
290 |
+
|
291 |
+
text_states = encoder_hidden_states # 2,77,1024
|
292 |
+
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
|
293 |
+
text_states_mask = text_embedding_mask.bool() # 2,77
|
294 |
+
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
|
295 |
+
b_t5, l_t5, c_t5 = text_states_t5.shape
|
296 |
+
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5))
|
297 |
+
text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1) # 2,205,1024
|
298 |
+
clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
|
299 |
+
|
300 |
+
clip_t5_mask = clip_t5_mask
|
301 |
+
text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states))
|
302 |
+
|
303 |
+
_, _, oh, ow = x.shape
|
304 |
+
th, tw = oh // self.patch_size, ow // self.patch_size
|
305 |
+
|
306 |
+
# ========================= Build time and image embedding =========================
|
307 |
+
t = self.t_embedder(t)
|
308 |
+
x = self.x_embedder(x)
|
309 |
+
|
310 |
+
# Get image RoPE embedding according to `reso`lution.
|
311 |
+
freqs_cis_img = (cos_cis_img, sin_cis_img)
|
312 |
+
|
313 |
+
# ========================= Concatenate all extra vectors =========================
|
314 |
+
# Build text tokens with pooling
|
315 |
+
extra_vec = self.pooler(encoder_hidden_states_t5)
|
316 |
+
|
317 |
+
# Build image meta size tokens
|
318 |
+
image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256]
|
319 |
+
if self.args.use_fp16:
|
320 |
+
image_meta_size = image_meta_size.half()
|
321 |
+
image_meta_size = image_meta_size.view(-1, 6 * 256)
|
322 |
+
extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
|
323 |
+
|
324 |
+
# Build style tokens
|
325 |
+
style_embedding = self.style_embedder(style)
|
326 |
+
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
|
327 |
+
|
328 |
+
# Concatenate all extra vectors
|
329 |
+
c = t + self.extra_embedder(extra_vec) # [B, D]
|
330 |
+
|
331 |
+
# ========================= Forward pass through HunYuanDiT blocks =========================
|
332 |
+
skips = []
|
333 |
+
for layer, block in enumerate(self.blocks):
|
334 |
+
if layer > self.depth // 2:
|
335 |
+
skip = skips.pop()
|
336 |
+
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
337 |
+
else:
|
338 |
+
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
|
339 |
+
|
340 |
+
if layer < (self.depth // 2 - 1):
|
341 |
+
skips.append(x)
|
342 |
+
|
343 |
+
# ========================= Final layer =========================
|
344 |
+
x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels)
|
345 |
+
x = self.unpatchify(x, th, tw) # (N, out_channels, H, W)
|
346 |
+
|
347 |
+
if return_dict:
|
348 |
+
return {'x': x}
|
349 |
+
return x
|
350 |
+
|
351 |
+
def initialize_weights(self):
|
352 |
+
# Initialize transformer layers:
|
353 |
+
def _basic_init(module):
|
354 |
+
if isinstance(module, nn.Linear):
|
355 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
356 |
+
if module.bias is not None:
|
357 |
+
nn.init.constant_(module.bias, 0)
|
358 |
+
self.apply(_basic_init)
|
359 |
+
|
360 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
361 |
+
w = self.x_embedder.proj.weight.data
|
362 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
363 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
364 |
+
|
365 |
+
# Initialize label embedding table:
|
366 |
+
nn.init.normal_(self.extra_embedder[0].weight, std=0.02)
|
367 |
+
nn.init.normal_(self.extra_embedder[2].weight, std=0.02)
|
368 |
+
|
369 |
+
# Initialize timestep embedding MLP:
|
370 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
371 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
372 |
+
|
373 |
+
# Zero-out adaLN modulation layers in HunYuanDiT blocks:
|
374 |
+
for block in self.blocks:
|
375 |
+
nn.init.constant_(block.default_modulation[-1].weight, 0)
|
376 |
+
nn.init.constant_(block.default_modulation[-1].bias, 0)
|
377 |
+
|
378 |
+
# Zero-out output layers:
|
379 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
380 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
381 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
382 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
383 |
+
|
384 |
+
def unpatchify(self, x, h, w):
|
385 |
+
"""
|
386 |
+
x: (N, T, patch_size**2 * C)
|
387 |
+
imgs: (N, H, W, C)
|
388 |
+
"""
|
389 |
+
c = self.unpatchify_channels
|
390 |
+
p = self.x_embedder.patch_size[0]
|
391 |
+
# h = w = int(x.shape[1] ** 0.5)
|
392 |
+
assert h * w == x.shape[1]
|
393 |
+
|
394 |
+
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
395 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
396 |
+
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
397 |
+
return imgs
|
398 |
+
|
399 |
+
|
400 |
+
#################################################################################
|
401 |
+
# HunYuanDiT Configs #
|
402 |
+
#################################################################################
|
403 |
+
|
404 |
+
HUNYUAN_DIT_CONFIG = {
|
405 |
+
'DiT-g/2': {'depth': 40, 'hidden_size': 1408, 'patch_size': 2, 'num_heads': 16, 'mlp_ratio': 4.3637},
|
406 |
+
'DiT-XL/2': {'depth': 28, 'hidden_size': 1152, 'patch_size': 2, 'num_heads': 16},
|
407 |
+
'DiT-L/2': {'depth': 24, 'hidden_size': 1024, 'patch_size': 2, 'num_heads': 16},
|
408 |
+
'DiT-B/2': {'depth': 12, 'hidden_size': 768, 'patch_size': 2, 'num_heads': 12},
|
409 |
+
}
|
hydit/modules/norm_layers.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class RMSNorm(nn.Module):
|
6 |
+
def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6):
|
7 |
+
"""
|
8 |
+
Initialize the RMSNorm normalization layer.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
dim (int): The dimension of the input tensor.
|
12 |
+
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
13 |
+
|
14 |
+
Attributes:
|
15 |
+
eps (float): A small value added to the denominator for numerical stability.
|
16 |
+
weight (nn.Parameter): Learnable scaling parameter.
|
17 |
+
|
18 |
+
"""
|
19 |
+
super().__init__()
|
20 |
+
self.eps = eps
|
21 |
+
if elementwise_affine:
|
22 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
23 |
+
|
24 |
+
def _norm(self, x):
|
25 |
+
"""
|
26 |
+
Apply the RMSNorm normalization to the input tensor.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
x (torch.Tensor): The input tensor.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
torch.Tensor: The normalized tensor.
|
33 |
+
|
34 |
+
"""
|
35 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
"""
|
39 |
+
Forward pass through the RMSNorm layer.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
x (torch.Tensor): The input tensor.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
torch.Tensor: The output tensor after applying RMSNorm.
|
46 |
+
|
47 |
+
"""
|
48 |
+
output = self._norm(x.float()).type_as(x)
|
49 |
+
if hasattr(self, "weight"):
|
50 |
+
output = output * self.weight
|
51 |
+
return output
|
52 |
+
|
53 |
+
|
54 |
+
class GroupNorm32(nn.GroupNorm):
|
55 |
+
def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None):
|
56 |
+
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
y = super().forward(x).to(x.dtype)
|
60 |
+
return y
|
61 |
+
|
62 |
+
def normalization(channels, dtype=None):
|
63 |
+
"""
|
64 |
+
Make a standard normalization layer.
|
65 |
+
:param channels: number of input channels.
|
66 |
+
:return: an nn.Module for normalization.
|
67 |
+
"""
|
68 |
+
return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype)
|
hydit/modules/poolers.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class AttentionPool(nn.Module):
|
7 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
8 |
+
super().__init__()
|
9 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
|
10 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
11 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
12 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
13 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
14 |
+
self.num_heads = num_heads
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
x = x.permute(1, 0, 2) # NLC -> LNC
|
18 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
|
19 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
|
20 |
+
x, _ = F.multi_head_attention_forward(
|
21 |
+
query=x[:1], key=x, value=x,
|
22 |
+
embed_dim_to_check=x.shape[-1],
|
23 |
+
num_heads=self.num_heads,
|
24 |
+
q_proj_weight=self.q_proj.weight,
|
25 |
+
k_proj_weight=self.k_proj.weight,
|
26 |
+
v_proj_weight=self.v_proj.weight,
|
27 |
+
in_proj_weight=None,
|
28 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
29 |
+
bias_k=None,
|
30 |
+
bias_v=None,
|
31 |
+
add_zero_attn=False,
|
32 |
+
dropout_p=0,
|
33 |
+
out_proj_weight=self.c_proj.weight,
|
34 |
+
out_proj_bias=self.c_proj.bias,
|
35 |
+
use_separate_proj_weight=True,
|
36 |
+
training=self.training,
|
37 |
+
need_weights=False
|
38 |
+
)
|
39 |
+
return x.squeeze(0)
|
hydit/modules/posemb_layers.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
|
6 |
+
def _to_tuple(x):
|
7 |
+
if isinstance(x, int):
|
8 |
+
return x, x
|
9 |
+
else:
|
10 |
+
return x
|
11 |
+
|
12 |
+
|
13 |
+
def get_fill_resize_and_crop(src, tgt): # src 来源的分辨率 tgt base 分辨率
|
14 |
+
th, tw = _to_tuple(tgt)
|
15 |
+
h, w = _to_tuple(src)
|
16 |
+
|
17 |
+
tr = th / tw # base 分辨率
|
18 |
+
r = h / w # 目标分辨率
|
19 |
+
|
20 |
+
# resize
|
21 |
+
if r > tr:
|
22 |
+
resize_height = th
|
23 |
+
resize_width = int(round(th / h * w))
|
24 |
+
else:
|
25 |
+
resize_width = tw
|
26 |
+
resize_height = int(round(tw / w * h)) # 根据base分辨率,将目标分辨率resize下来
|
27 |
+
|
28 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
29 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
30 |
+
|
31 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
32 |
+
|
33 |
+
|
34 |
+
def get_meshgrid(start, *args):
|
35 |
+
if len(args) == 0:
|
36 |
+
# start is grid_size
|
37 |
+
num = _to_tuple(start)
|
38 |
+
start = (0, 0)
|
39 |
+
stop = num
|
40 |
+
elif len(args) == 1:
|
41 |
+
# start is start, args[0] is stop, step is 1
|
42 |
+
start = _to_tuple(start)
|
43 |
+
stop = _to_tuple(args[0])
|
44 |
+
num = (stop[0] - start[0], stop[1] - start[1])
|
45 |
+
elif len(args) == 2:
|
46 |
+
# start is start, args[0] is stop, args[1] is num
|
47 |
+
start = _to_tuple(start) # 左上角 eg: 12,0
|
48 |
+
stop = _to_tuple(args[0]) # 右下角 eg: 20,32
|
49 |
+
num = _to_tuple(args[1]) # 目标大小 eg: 32,124
|
50 |
+
else:
|
51 |
+
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
52 |
+
|
53 |
+
grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) # 12-20 中间差值32份 0-32 中间差值124份
|
54 |
+
grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
|
55 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
56 |
+
grid = np.stack(grid, axis=0) # [2, W, H]
|
57 |
+
return grid
|
58 |
+
|
59 |
+
#################################################################################
|
60 |
+
# Sine/Cosine Positional Embedding Functions #
|
61 |
+
#################################################################################
|
62 |
+
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
63 |
+
|
64 |
+
def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
|
65 |
+
"""
|
66 |
+
grid_size: int of the grid height and width
|
67 |
+
return:
|
68 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
69 |
+
"""
|
70 |
+
grid = get_meshgrid(start, *args) # [2, H, w]
|
71 |
+
# grid_h = np.arange(grid_size, dtype=np.float32)
|
72 |
+
# grid_w = np.arange(grid_size, dtype=np.float32)
|
73 |
+
# grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
74 |
+
# grid = np.stack(grid, axis=0) # [2, W, H]
|
75 |
+
|
76 |
+
grid = grid.reshape([2, 1, *grid.shape[1:]])
|
77 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
78 |
+
if cls_token and extra_tokens > 0:
|
79 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
80 |
+
return pos_embed
|
81 |
+
|
82 |
+
|
83 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
84 |
+
assert embed_dim % 2 == 0
|
85 |
+
|
86 |
+
# use half of dimensions to encode grid_h
|
87 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
88 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
89 |
+
|
90 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
91 |
+
return emb
|
92 |
+
|
93 |
+
|
94 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
95 |
+
"""
|
96 |
+
embed_dim: output dimension for each position
|
97 |
+
pos: a list of positions to be encoded: size (W,H)
|
98 |
+
out: (M, D)
|
99 |
+
"""
|
100 |
+
assert embed_dim % 2 == 0
|
101 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
102 |
+
omega /= embed_dim / 2.
|
103 |
+
omega = 1. / 10000**omega # (D/2,)
|
104 |
+
|
105 |
+
pos = pos.reshape(-1) # (M,)
|
106 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
107 |
+
|
108 |
+
emb_sin = np.sin(out) # (M, D/2)
|
109 |
+
emb_cos = np.cos(out) # (M, D/2)
|
110 |
+
|
111 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
112 |
+
return emb
|
113 |
+
|
114 |
+
|
115 |
+
#################################################################################
|
116 |
+
# Rotary Positional Embedding Functions #
|
117 |
+
#################################################################################
|
118 |
+
# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
|
119 |
+
|
120 |
+
def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
|
121 |
+
"""
|
122 |
+
This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
|
123 |
+
|
124 |
+
Parameters
|
125 |
+
----------
|
126 |
+
embed_dim: int
|
127 |
+
embedding dimension size
|
128 |
+
start: int or tuple of int
|
129 |
+
If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
|
130 |
+
If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
131 |
+
use_real: bool
|
132 |
+
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
133 |
+
|
134 |
+
Returns
|
135 |
+
-------
|
136 |
+
pos_embed: torch.Tensor
|
137 |
+
[HW, D/2]
|
138 |
+
"""
|
139 |
+
grid = get_meshgrid(start, *args) # [2, H, w]
|
140 |
+
grid = grid.reshape([2, 1, *grid.shape[1:]]) # 返回一个采样矩阵 分辨率与目标分辨率一致
|
141 |
+
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
|
142 |
+
return pos_embed
|
143 |
+
|
144 |
+
|
145 |
+
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
|
146 |
+
assert embed_dim % 4 == 0
|
147 |
+
|
148 |
+
# use half of dimensions to encode grid_h
|
149 |
+
emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
|
150 |
+
emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
|
151 |
+
|
152 |
+
if use_real:
|
153 |
+
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
|
154 |
+
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
|
155 |
+
return cos, sin
|
156 |
+
else:
|
157 |
+
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
|
158 |
+
return emb
|
159 |
+
|
160 |
+
|
161 |
+
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
|
162 |
+
"""
|
163 |
+
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
164 |
+
|
165 |
+
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
|
166 |
+
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
167 |
+
The returned tensor contains complex values in complex64 data type.
|
168 |
+
|
169 |
+
Args:
|
170 |
+
dim (int): Dimension of the frequency tensor.
|
171 |
+
pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
|
172 |
+
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
173 |
+
use_real (bool, optional): If True, return real part and imaginary part separately.
|
174 |
+
Otherwise, return complex numbers.
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
178 |
+
|
179 |
+
"""
|
180 |
+
if isinstance(pos, int):
|
181 |
+
pos = np.arange(pos)
|
182 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
|
183 |
+
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
|
184 |
+
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
|
185 |
+
if use_real:
|
186 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
187 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
188 |
+
return freqs_cos, freqs_sin
|
189 |
+
else:
|
190 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
191 |
+
return freqs_cis
|
192 |
+
|
193 |
+
|
194 |
+
|
195 |
+
def calc_sizes(rope_img, patch_size, th, tw):
|
196 |
+
""" 计算 RoPE 的尺寸. """
|
197 |
+
if rope_img == 'extend':
|
198 |
+
# 拓展模式
|
199 |
+
sub_args = [(th, tw)]
|
200 |
+
elif rope_img.startswith('base'):
|
201 |
+
# 基于一个尺寸, 其他尺寸插值获得.
|
202 |
+
base_size = int(rope_img[4:]) // 8 // patch_size # 基于512作为base,其他根据512差值得到
|
203 |
+
start, stop = get_fill_resize_and_crop((th, tw), base_size) # 需要在32x32里面 crop的左上角和右下角
|
204 |
+
sub_args = [start, stop, (th, tw)]
|
205 |
+
else:
|
206 |
+
raise ValueError(f"Unknown rope_img: {rope_img}")
|
207 |
+
return sub_args
|
208 |
+
|
209 |
+
|
210 |
+
def init_image_posemb(rope_img,
|
211 |
+
resolutions,
|
212 |
+
patch_size,
|
213 |
+
hidden_size,
|
214 |
+
num_heads,
|
215 |
+
log_fn,
|
216 |
+
rope_real=True,
|
217 |
+
):
|
218 |
+
freqs_cis_img = {}
|
219 |
+
for reso in resolutions:
|
220 |
+
th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
|
221 |
+
sub_args = calc_sizes(rope_img, patch_size, th, tw) # [左上角, 右下角, 目标高宽] 需要在32x32里面 crop的左上角和右下角
|
222 |
+
freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
|
223 |
+
log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
|
224 |
+
f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
|
225 |
+
return freqs_cis_img
|