pan-yl committed
Commit
2a00960
1 Parent(s): 03b9c21

update file

__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import modules
app.py CHANGED
@@ -1,64 +1,1199 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
61
-
62
-
63
- if __name__ == "__main__":
64
  demo.launch()
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import argparse
4
+ import base64
5
+ import copy
6
+ import glob
7
+ import io
8
+ import os
9
+ import random
10
+ import re
11
+ import string
12
+ import threading
13
+
14
+ import cv2
15
  import gradio as gr
16
+ import numpy as np
17
+ import torch
18
+ import transformers
19
+ from diffusers import CogVideoXImageToVideoPipeline
20
+ from diffusers.utils import export_to_video
21
+ from gradio_imageslider import ImageSlider
22
+ from PIL import Image
23
+ from transformers import AutoModel, AutoTokenizer
24
+
25
+ from scepter.modules.utils.config import Config
26
+ from scepter.modules.utils.directory import get_md5
27
+ from scepter.modules.utils.file_system import FS
28
+ from scepter.studio.utils.env import init_env
29
+
30
+ from .infer import ACEInference
31
+ from .example import get_examples
32
+ from .utils import load_image
33
+
34
+
35
+ refresh_sty = '\U0001f504' # 🔄
36
+ clear_sty = '\U0001f5d1' # 🗑️
37
+ upload_sty = '\U0001f5bc' # 🖼️
38
+ sync_sty = '\U0001f4be' # 💾
39
+ chat_sty = '\U0001F4AC' # 💬
40
+ video_sty = '\U0001f3a5' # 🎥
41
+
42
+ lock = threading.Lock()
43
+
44
+
45
+ class ChatBotUI(object):
46
+ def __init__(self,
47
+ cfg_general_file,
48
+ root_work_dir='./'):
49
+
50
+ cfg = Config(cfg_file=cfg_general_file)
51
+ cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
52
+ if not FS.exists(cfg.WORK_DIR):
53
+ FS.make_dir(cfg.WORK_DIR)
54
+ cfg = init_env(cfg)
55
+ self.cache_dir = cfg.WORK_DIR
56
+ self.chatbot_examples = get_examples(self.cache_dir)
57
+ self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR
58
+ self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,
59
+ '*.yaml'))
60
+ self.model_choices = dict()
61
+ for i in self.model_yamls:
62
+ model_name = '.'.join(i.split('/')[-1].split('.')[:-1])
63
+ self.model_choices[model_name] = i
64
+ print('Models: ', self.model_choices)
65
+
66
+ self.model_name = cfg.MODEL.EDIT_MODEL.DEFAULT
67
+ assert self.model_name in self.model_choices
68
+ model_cfg = Config(load=True,
69
+ cfg_file=self.model_choices[self.model_name])
70
+ self.pipe = ACEInference()
71
+ self.pipe.init_from_cfg(model_cfg)
72
+ self.retry_msg = ''
73
+ self.max_msgs = 20
74
+
75
+ self.enable_i2v = cfg.get('ENABLE_I2V', False)
76
+ if self.enable_i2v:
77
+ self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR
78
+ self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME
79
+ if self.i2v_model_name == 'CogVideoX-5b-I2V':
80
+ with FS.get_dir_to_local_dir(self.i2v_model_dir) as local_dir:
81
+ self.i2v_pipe = CogVideoXImageToVideoPipeline.from_pretrained(
82
+ local_dir, torch_dtype=torch.bfloat16).cuda()
83
+ else:
84
+ raise NotImplementedError
85
+
86
+ with FS.get_dir_to_local_dir(
87
+ cfg.MODEL.CAPTIONER.MODEL_DIR) as local_dir:
88
+ self.captioner = AutoModel.from_pretrained(
89
+ local_dir,
90
+ torch_dtype=torch.bfloat16,
91
+ low_cpu_mem_usage=True,
92
+ use_flash_attn=True,
93
+ trust_remote_code=True).eval().cuda()
94
+ self.llm_tokenizer = AutoTokenizer.from_pretrained(
95
+ local_dir, trust_remote_code=True, use_fast=False)
96
+ self.llm_generation_config = dict(max_new_tokens=1024,
97
+ do_sample=True)
98
+ self.llm_prompt = cfg.LLM.PROMPT
99
+ self.llm_max_num = 2
100
+
101
+ with FS.get_dir_to_local_dir(
102
+ cfg.MODEL.ENHANCER.MODEL_DIR) as local_dir:
103
+ self.enhancer = transformers.pipeline(
104
+ 'text-generation',
105
+ model=local_dir,
106
+ model_kwargs={'torch_dtype': torch.bfloat16},
107
+ device_map='auto',
108
+ )
109
+
110
+ sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
111
+
112
+ For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output a video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
113
+ There are a few rules to follow:
114
+
115
+ You will only ever output a single video description per user request.
116
+
117
+ When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
118
+ Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
119
+
120
+ Video descriptions must have the same num of words as examples below. Extra words will be ignored.
121
+ """
122
+ self.enhance_ctx = [
123
+ {
124
+ 'role': 'system',
125
+ 'content': sys_prompt
126
+ },
127
+ {
128
+ 'role':
129
+ 'user',
130
+ 'content':
131
+ 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
132
+ },
133
+ {
134
+ 'role':
135
+ 'assistant',
136
+ 'content':
137
+ "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
138
+ },
139
+ {
140
+ 'role':
141
+ 'user',
142
+ 'content':
143
+ 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
144
+ },
145
+ {
146
+ 'role':
147
+ 'assistant',
148
+ 'content':
149
+ "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
150
+ },
151
+ {
152
+ 'role':
153
+ 'user',
154
+ 'content':
155
+ 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
156
+ },
157
+ {
158
+ 'role':
159
+ 'assistant',
160
+ 'content':
161
+ 'A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.',
162
+ },
163
+ ]
164
+
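# A minimal sketch of how `enhance_ctx` is consumed further down in `generate_video`
# and `run_chat`: the few-shot context is deep-copied, the user's request is appended,
# and the last generated turn becomes the enhanced caption (this assumes `self.enhancer`
# is the `transformers` text-generation pipeline built in `__init__` above):
#
#   messages = copy.deepcopy(self.enhance_ctx)
#   messages.append({
#       'role': 'user',
#       'content': 'Create an imaginative video descriptive caption or modify an '
#                  f'earlier caption in ENGLISH for the user input: "{prompt}"',
#   })
#   outputs = self.enhancer(messages, max_new_tokens=200)
#   prompt = outputs[0]['generated_text'][-1]['content']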
165
+ def create_ui(self):
166
+ css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'
167
+ with gr.Blocks(css=css,
168
+ title='Chatbot',
169
+ head='Chatbot',
170
+ analytics_enabled=False):
171
+ self.history = gr.State(value=[])
172
+ self.images = gr.State(value={})
173
+ self.history_result = gr.State(value={})
174
+ with gr.Group():
175
+ with gr.Row(equal_height=True):
176
+ with gr.Column(visible=True) as self.chat_page:
177
+ self.chatbot = gr.Chatbot(
178
+ height=600,
179
+ value=[],
180
+ bubble_full_width=False,
181
+ show_copy_button=True,
182
+ container=False,
183
+ placeholder='<strong>Chat Box</strong>')
184
+ with gr.Row():
185
+ self.clear_btn = gr.Button(clear_sty +
186
+ ' Clear Chat',
187
+ size='sm')
188
+
189
+ with gr.Column(visible=False) as self.editor_page:
190
+ with gr.Tabs():
191
+ with gr.Tab(id='ImageUploader',
192
+ label='Image Uploader',
193
+ visible=True) as self.upload_tab:
194
+ self.image_uploader = gr.Image(
195
+ height=550,
196
+ interactive=True,
197
+ type='pil',
198
+ image_mode='RGB',
199
+ sources='upload',
200
+ elem_id='image_uploader',
201
+ format='png')
202
+ with gr.Row():
203
+ self.sub_btn_1 = gr.Button(
204
+ value='Submit',
205
+ elem_id='upload_submit')
206
+ self.ext_btn_1 = gr.Button(value='Exit')
207
+
208
+ with gr.Tab(id='ImageEditor',
209
+ label='Image Editor',
210
+ visible=False) as self.edit_tab:
211
+ self.mask_type = gr.Dropdown(
212
+ label='Mask Type',
213
+ choices=[
214
+ 'Background', 'Composite',
215
+ 'Outpainting'
216
+ ],
217
+ value='Background')
218
+ self.mask_type_info = gr.HTML(
219
+ value=
220
+ "<div style='background-color: white; padding-left: 15px; color: grey;'>Background mode will not erase the visual content in the mask area</div>"
221
+ )
222
+ with gr.Accordion(
223
+ label='Outpainting Setting',
224
+ open=True,
225
+ visible=False) as self.outpaint_tab:
226
+ with gr.Row(variant='panel'):
227
+ self.top_ext = gr.Slider(
228
+ show_label=True,
229
+ label='Top Extend Ratio',
230
+ minimum=0.0,
231
+ maximum=2.0,
232
+ step=0.1,
233
+ value=0.25)
234
+ self.bottom_ext = gr.Slider(
235
+ show_label=True,
236
+ label='Bottom Extend Ratio',
237
+ minimum=0.0,
238
+ maximum=2.0,
239
+ step=0.1,
240
+ value=0.25)
241
+ with gr.Row(variant='panel'):
242
+ self.left_ext = gr.Slider(
243
+ show_label=True,
244
+ label='Left Extend Ratio',
245
+ minimum=0.0,
246
+ maximum=2.0,
247
+ step=0.1,
248
+ value=0.25)
249
+ self.right_ext = gr.Slider(
250
+ show_label=True,
251
+ label='Right Extend Ratio',
252
+ minimum=0.0,
253
+ maximum=2.0,
254
+ step=0.1,
255
+ value=0.25)
256
+ with gr.Row(variant='panel'):
257
+ self.img_pad_btn = gr.Button(
258
+ value='Pad Image')
259
+
260
+ self.image_editor = gr.ImageMask(
261
+ value=None,
262
+ sources=[],
263
+ layers=False,
264
+ label='Edit Image',
265
+ elem_id='image_editor',
266
+ format='png')
267
+ with gr.Row():
268
+ self.sub_btn_2 = gr.Button(
269
+ value='Submit', elem_id='edit_submit')
270
+ self.ext_btn_2 = gr.Button(value='Exit')
271
+
272
+ with gr.Tab(id='ImageViewer',
273
+ label='Image Viewer',
274
+ visible=False) as self.image_view_tab:
275
+ self.image_viewer = ImageSlider(
276
+ label='Image',
277
+ type='pil',
278
+ show_download_button=True,
279
+ elem_id='image_viewer')
280
+
281
+ self.ext_btn_3 = gr.Button(value='Exit')
282
+
283
+ with gr.Tab(id='VideoViewer',
284
+ label='Video Viewer',
285
+ visible=False) as self.video_view_tab:
286
+ self.video_viewer = gr.Video(
287
+ label='Video',
288
+ interactive=False,
289
+ sources=[],
290
+ format='mp4',
291
+ show_download_button=True,
292
+ elem_id='video_viewer',
293
+ loop=True,
294
+ autoplay=True)
295
+
296
+ self.ext_btn_4 = gr.Button(value='Exit')
297
+
298
+ with gr.Accordion(label='Setting', open=False):
299
+ with gr.Row():
300
+ self.model_name_dd = gr.Dropdown(
301
+ choices=self.model_choices,
302
+ value=self.model_name,
303
+ label='Model Version')
304
+
305
+ with gr.Row():
306
+ self.negative_prompt = gr.Textbox(
307
+ value='',
308
+ placeholder=
309
+ 'Negative prompt used for Classifier-Free Guidance',
310
+ label='Negative Prompt',
311
+ container=False)
312
+
313
+ with gr.Row():
314
+ with gr.Column(scale=8, min_width=500):
315
+ with gr.Row():
316
+ self.step = gr.Slider(minimum=1,
317
+ maximum=1000,
318
+ value=20,
319
+ label='Sample Step')
320
+ self.cfg_scale = gr.Slider(
321
+ minimum=1.0,
322
+ maximum=20.0,
323
+ value=4.5,
324
+ label='Guidance Scale')
325
+ self.rescale = gr.Slider(minimum=0.0,
326
+ maximum=1.0,
327
+ value=0.5,
328
+ label='Rescale')
329
+ self.seed = gr.Slider(minimum=-1,
330
+ maximum=10000000,
331
+ value=-1,
332
+ label='Seed')
333
+ self.output_height = gr.Slider(
334
+ minimum=256,
335
+ maximum=1024,
336
+ value=512,
337
+ label='Output Height')
338
+ self.output_width = gr.Slider(
339
+ minimum=256,
340
+ maximum=1024,
341
+ value=512,
342
+ label='Output Width')
343
+ with gr.Column(scale=1, min_width=50):
344
+ self.use_history = gr.Checkbox(value=False,
345
+ label='Use History')
346
+ self.video_auto = gr.Checkbox(
347
+ value=False,
348
+ label='Auto Gen Video',
349
+ visible=self.enable_i2v)
350
+
351
+ with gr.Row(variant='panel',
352
+ equal_height=True,
353
+ visible=self.enable_i2v):
354
+ self.video_fps = gr.Slider(minimum=1,
355
+ maximum=16,
356
+ value=8,
357
+ label='Video FPS',
358
+ visible=True)
359
+ self.video_frames = gr.Slider(minimum=8,
360
+ maximum=49,
361
+ value=49,
362
+ label='Video Frame Num',
363
+ visible=True)
364
+ self.video_step = gr.Slider(minimum=1,
365
+ maximum=1000,
366
+ value=50,
367
+ label='Video Sample Step',
368
+ visible=True)
369
+ self.video_cfg_scale = gr.Slider(
370
+ minimum=1.0,
371
+ maximum=20.0,
372
+ value=6.0,
373
+ label='Video Guidance Scale',
374
+ visible=True)
375
+ self.video_seed = gr.Slider(minimum=-1,
376
+ maximum=10000000,
377
+ value=-1,
378
+ label='Video Seed',
379
+ visible=True)
380
+
381
+ with gr.Row(variant='panel',
382
+ equal_height=True,
383
+ show_progress=False):
384
+ with gr.Column(scale=1, min_width=100):
385
+ self.upload_btn = gr.Button(value=upload_sty +
386
+ ' Upload',
387
+ variant='secondary')
388
+ with gr.Column(scale=5, min_width=500):
389
+ self.text = gr.Textbox(
390
+ placeholder='Input "@" find history of image',
391
+ label='Instruction',
392
+ container=False)
393
+ with gr.Column(scale=1, min_width=100):
394
+ self.chat_btn = gr.Button(value=chat_sty + ' Chat',
395
+ variant='primary')
396
+ with gr.Column(scale=1, min_width=100):
397
+ self.retry_btn = gr.Button(value=refresh_sty +
398
+ ' Retry',
399
+ variant='secondary')
400
+ with gr.Column(scale=(1 if self.enable_i2v else 0),
401
+ min_width=0):
402
+ self.video_gen_btn = gr.Button(value=video_sty +
403
+ ' Gen Video',
404
+ variant='secondary',
405
+ visible=self.enable_i2v)
406
+ with gr.Column(scale=(1 if self.enable_i2v else 0),
407
+ min_width=0):
408
+ self.extend_prompt = gr.Checkbox(
409
+ value=True,
410
+ label='Extend Prompt',
411
+ visible=self.enable_i2v)
412
+
413
+ with gr.Row():
414
+ self.gallery = gr.Gallery(visible=False,
415
+ label='History',
416
+ columns=10,
417
+ allow_preview=False,
418
+ interactive=False)
419
+
420
+ self.eg = gr.Column(visible=True)
421
+
422
+ def set_callbacks(self, *args, **kwargs):
423
+
424
+ ########################################
425
+ def change_model(model_name):
426
+ if model_name not in self.model_choices:
427
+ gr.Info('The provided model name is not a valid choice!')
428
+ return model_name, gr.update(), gr.update()
429
+
430
+ if model_name != self.model_name:
431
+ lock.acquire()
432
+ del self.pipe
433
+ torch.cuda.empty_cache()
434
+ model_cfg = Config(load=True,
435
+ cfg_file=self.model_choices[model_name])
436
+ self.pipe = ACEInference()
437
+ self.pipe.init_from_cfg(model_cfg)
438
+ self.model_name = model_name
439
+ lock.release()
440
+
441
+ return model_name, gr.update(), gr.update()
442
+
443
+ self.model_name_dd.change(
444
+ change_model,
445
+ inputs=[self.model_name_dd],
446
+ outputs=[self.model_name_dd, self.chatbot, self.text])
447
+
448
+ ########################################
449
+ def generate_gallery(text, images):
450
+ if text.endswith(' '):
451
+ return gr.update(), gr.update(visible=False)
452
+ elif text.endswith('@'):
453
+ gallery_info = []
454
+ for image_id, image_meta in images.items():
455
+ thumbnail_path = image_meta['thumbnail']
456
+ gallery_info.append((thumbnail_path, image_id))
457
+ return gr.update(), gr.update(visible=True, value=gallery_info)
458
+ else:
459
+ gallery_info = []
460
+ match = re.search('@([^@ ]+)$', text)
461
+ if match:
462
+ prefix = match.group(1)
463
+ for image_id, image_meta in images.items():
464
+ if not image_id.startswith(prefix):
465
+ continue
466
+ thumbnail_path = image_meta['thumbnail']
467
+ gallery_info.append((thumbnail_path, image_id))
468
+
469
+ if len(gallery_info) > 0:
470
+ return gr.update(), gr.update(visible=True,
471
+ value=gallery_info)
472
+ else:
473
+ return gr.update(), gr.update(visible=False)
474
+ else:
475
+ return gr.update(), gr.update(visible=False)
476
+
477
+ self.text.input(generate_gallery,
478
+ inputs=[self.text, self.images],
479
+ outputs=[self.text, self.gallery],
480
+ show_progress='hidden')
481
+
482
+ ########################################
483
+ def select_image(text, evt: gr.SelectData):
484
+ image_id = evt.value['caption']
485
+ text = '@'.join(text.split('@')[:-1]) + f'@{image_id} '
486
+ return gr.update(value=text), gr.update(visible=False, value=None)
487
+
488
+ self.gallery.select(select_image,
489
+ inputs=self.text,
490
+ outputs=[self.text, self.gallery])
491
+
492
+ ########################################
493
+ def generate_video(message,
494
+ extend_prompt,
495
+ history,
496
+ images,
497
+ num_steps,
498
+ num_frames,
499
+ cfg_scale,
500
+ fps,
501
+ seed,
502
+ progress=gr.Progress(track_tqdm=True)):
503
+ generator = torch.Generator(device='cuda').manual_seed(seed)
504
+ img_ids = re.findall('@(.*?)[ ,;.?$]', message)
505
+ if len(img_ids) == 0:
506
+ history.append((
507
+ message,
508
+ 'Sorry, no images were found in the prompt to be used as the first frame of the video.'
509
+ ))
510
+ while len(history) >= self.max_msgs:
511
+ history.pop(0)
512
+ return history, self.get_history(
513
+ history), gr.update(), gr.update(visible=False)
514
+
515
+ img_id = img_ids[0]
516
+ prompt = re.sub(rf'@{img_id}\s+', '', message)
517
+
518
+ if extend_prompt:
519
+ messages = copy.deepcopy(self.enhance_ctx)
520
+ messages.append({
521
+ 'role':
522
+ 'user',
523
+ 'content':
524
+ f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"',
525
+ })
526
+ lock.acquire()
527
+ outputs = self.enhancer(
528
+ messages,
529
+ max_new_tokens=200,
530
+ )
531
+
532
+ prompt = outputs[0]['generated_text'][-1]['content']
533
+ print(prompt)
534
+ lock.release()
535
+
536
+ img_meta = images[img_id]
537
+ img_path = img_meta['image']
538
+ image = Image.open(img_path).convert('RGB')
539
+
540
+ lock.acquire()
541
+ video = self.i2v_pipe(
542
+ prompt=prompt,
543
+ image=image,
544
+ num_videos_per_prompt=1,
545
+ num_inference_steps=num_steps,
546
+ num_frames=num_frames,
547
+ guidance_scale=cfg_scale,
548
+ generator=generator,
549
+ ).frames[0]
550
+ lock.release()
551
+
552
+ out_video_path = export_to_video(video, fps=fps)
553
+ history.append((
554
+ f"Based on first frame @{img_id} and description '{prompt}', generate a video",
555
+ 'This is generated video:'))
556
+ history.append((None, out_video_path))
557
+ while len(history) >= self.max_msgs:
558
+ history.pop(0)
559
+
560
+ return history, self.get_history(history), gr.update(
561
+ value=''), gr.update(visible=False)
562
+
563
+ self.video_gen_btn.click(
564
+ generate_video,
565
+ inputs=[
566
+ self.text, self.extend_prompt, self.history, self.images,
567
+ self.video_step, self.video_frames, self.video_cfg_scale,
568
+ self.video_fps, self.video_seed
569
+ ],
570
+ outputs=[self.history, self.chatbot, self.text, self.gallery])
571
+
572
+ ########################################
573
+ def run_chat(message,
574
+ extend_prompt,
575
+ history,
576
+ images,
577
+ use_history,
578
+ history_result,
579
+ negative_prompt,
580
+ cfg_scale,
581
+ rescale,
582
+ step,
583
+ seed,
584
+ output_h,
585
+ output_w,
586
+ video_auto,
587
+ video_steps,
588
+ video_frames,
589
+ video_cfg_scale,
590
+ video_fps,
591
+ video_seed,
592
+ progress=gr.Progress(track_tqdm=True)):
593
+ self.retry_msg = message
594
+ gen_id = get_md5(message)[:12]
595
+ save_path = os.path.join(self.cache_dir, f'{gen_id}.png')
596
+
597
+ img_ids = re.findall('@(.*?)[ ,;.?$]', message)
598
+ history_io = None
599
+ new_message = message
600
+
601
+ if len(img_ids) > 0:
602
+ edit_image, edit_image_mask, edit_task = [], [], []
603
+ for i, img_id in enumerate(img_ids):
604
+ if img_id not in images:
605
+ gr.Info(
606
+ f'The input image ID {img_id} does not exist. Skipping this image.'
607
+ )
608
+ continue
609
+ placeholder = '{image}' if i == 0 else '{' + f'image{i}' + '}'
610
+ new_message = re.sub(f'@{img_id}', placeholder,
611
+ new_message)
612
+ img_meta = images[img_id]
613
+ img_path = img_meta['image']
614
+ img_mask = img_meta['mask']
615
+ img_mask_type = img_meta['mask_type']
616
+ if img_mask_type is not None and img_mask_type == 'Composite':
617
+ task = 'inpainting'
618
+ else:
619
+ task = ''
620
+ edit_image.append(Image.open(img_path).convert('RGB'))
621
+ edit_image_mask.append(
622
+ Image.open(img_mask).
623
+ convert('L') if img_mask is not None else None)
624
+ edit_task.append(task)
625
+
626
+ if use_history and (img_id in history_result):
627
+ history_io = history_result[img_id]
628
+
629
+ buffered = io.BytesIO()
630
+ edit_image[0].save(buffered, format='PNG')
631
+ img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
632
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
633
+ pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{img_ids[0]} is:\n {img_str}'
634
+ else:
635
+ pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
636
+ edit_image = None
637
+ edit_image_mask = None
638
+ edit_task = ''
639
+
640
+ print(new_message)
641
+ imgs = self.pipe(
642
+ input_image=edit_image,
643
+ input_mask=edit_image_mask,
644
+ task=edit_task,
645
+ prompt=[new_message] *
646
+ len(edit_image) if edit_image is not None else [new_message],
647
+ negative_prompt=[negative_prompt] * len(edit_image)
648
+ if edit_image is not None else [negative_prompt],
649
+ history_io=history_io,
650
+ output_height=output_h,
651
+ output_width=output_w,
652
+ sampler='ddim',
653
+ sample_steps=step,
654
+ guide_scale=cfg_scale,
655
+ guide_rescale=rescale,
656
+ seed=seed,
657
+ )
658
+
659
+ img = imgs[0]
660
+ img.save(save_path, format='PNG')
661
+
662
+ if history_io:
663
+ history_io_new = copy.deepcopy(history_io)
664
+ history_io_new['image'] += edit_image[:1]
665
+ history_io_new['mask'] += edit_image_mask[:1]
666
+ history_io_new['task'] += edit_task[:1]
667
+ history_io_new['prompt'] += [new_message]
668
+ history_io_new['image'] = history_io_new['image'][-5:]
669
+ history_io_new['mask'] = history_io_new['mask'][-5:]
670
+ history_io_new['task'] = history_io_new['task'][-5:]
671
+ history_io_new['prompt'] = history_io_new['prompt'][-5:]
672
+ history_result[gen_id] = history_io_new
673
+ elif edit_image is not None and len(edit_image) > 0:
674
+ history_io_new = {
675
+ 'image': edit_image[:1],
676
+ 'mask': edit_image_mask[:1],
677
+ 'task': edit_task[:1],
678
+ 'prompt': [new_message]
679
+ }
680
+ history_result[gen_id] = history_io_new
681
+
682
+ w, h = img.size
683
+ if w > h:
684
+ tb_w = 128
685
+ tb_h = int(h * tb_w / w)
686
+ else:
687
+ tb_h = 128
688
+ tb_w = int(w * tb_h / h)
689
+
690
+ thumbnail_path = os.path.join(self.cache_dir,
691
+ f'{gen_id}_thumbnail.jpg')
692
+ thumbnail = img.resize((tb_w, tb_h))
693
+ thumbnail.save(thumbnail_path, format='JPEG')
694
+
695
+ images[gen_id] = {
696
+ 'image': save_path,
697
+ 'mask': None,
698
+ 'mask_type': None,
699
+ 'thumbnail': thumbnail_path
700
+ }
701
+
702
+ buffered = io.BytesIO()
703
+ img.convert('RGB').save(buffered, format='PNG')
704
+ img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
705
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
706
+
707
+ history.append(
708
+ (message,
709
+ f'{pre_info} The generated image @{gen_id} is:\n {img_str}'))
710
+
711
+ if video_auto:
712
+ if video_seed is None or video_seed == -1:
713
+ video_seed = random.randint(0, 10000000)
714
+
715
+ lock.acquire()
716
+ generator = torch.Generator(
717
+ device='cuda').manual_seed(video_seed)
718
+ pixel_values = load_image(img.convert('RGB'),
719
+ max_num=self.llm_max_num).to(
720
+ torch.bfloat16).cuda()
721
+ prompt = self.captioner.chat(self.llm_tokenizer, pixel_values,
722
+ self.llm_prompt,
723
+ self.llm_generation_config)
724
+ print(prompt)
725
+ lock.release()
726
+
727
+ if extend_prompt:
728
+ messages = copy.deepcopy(self.enhance_ctx)
729
+ messages.append({
730
+ 'role':
731
+ 'user',
732
+ 'content':
733
+ f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"',
734
+ })
735
+ lock.acquire()
736
+ outputs = self.enhancer(
737
+ messages,
738
+ max_new_tokens=200,
739
+ )
740
+ prompt = outputs[0]['generated_text'][-1]['content']
741
+ print(prompt)
742
+ lock.release()
743
+
744
+ lock.acquire()
745
+ video = self.i2v_pipe(
746
+ prompt=prompt,
747
+ image=img,
748
+ num_videos_per_prompt=1,
749
+ num_inference_steps=video_steps,
750
+ num_frames=video_frames,
751
+ guidance_scale=video_cfg_scale,
752
+ generator=generator,
753
+ ).frames[0]
754
+ lock.release()
755
+
756
+ out_video_path = export_to_video(video, fps=video_fps)
757
+ history.append((
758
+ f"Based on first frame @{gen_id} and description '{prompt}', generate a video",
759
+ 'This is generated video:'))
760
+ history.append((None, out_video_path))
761
+
762
+ while len(history) >= self.max_msgs:
763
+ history.pop(0)
764
+
765
+ return history, images, history_result, self.get_history(
766
+ history), gr.update(value=''), gr.update(visible=False)
767
+
768
+ chat_inputs = [
769
+ self.extend_prompt, self.history, self.images, self.use_history,
770
+ self.history_result, self.negative_prompt, self.cfg_scale,
771
+ self.rescale, self.step, self.seed, self.output_height,
772
+ self.output_width, self.video_auto, self.video_step,
773
+ self.video_frames, self.video_cfg_scale, self.video_fps,
774
+ self.video_seed
775
+ ]
776
+
777
+ chat_outputs = [
778
+ self.history, self.images, self.history_result, self.chatbot,
779
+ self.text, self.gallery
780
+ ]
781
+
782
+ self.chat_btn.click(run_chat,
783
+ inputs=[self.text] + chat_inputs,
784
+ outputs=chat_outputs)
785
+
786
+ self.text.submit(run_chat,
787
+ inputs=[self.text] + chat_inputs,
788
+ outputs=chat_outputs)
789
+
790
+ ########################################
791
+ def retry_chat(*args):
792
+ return run_chat(self.retry_msg, *args)
793
+
794
+ self.retry_btn.click(retry_chat,
795
+ inputs=chat_inputs,
796
+ outputs=chat_outputs)
797
+
798
+ ########################################
799
+ def run_example(task, img, img_mask, ref1, prompt, seed):
800
+ edit_image, edit_image_mask, edit_task = [], [], []
801
+ if img is not None:
802
+ w, h = img.size
803
+ if w > 2048:
804
+ ratio = w / 2048.
805
+ w = 2048
806
+ h = int(h / ratio)
807
+ if h > 2048:
808
+ ratio = h / 2048.
809
+ h = 2048
810
+ w = int(w / ratio)
811
+ img = img.resize((w, h))
812
+ edit_image.append(img)
813
+ edit_image_mask.append(
814
+ img_mask if img_mask is not None else None)
815
+ edit_task.append(task)
816
+ if ref1 is not None:
817
+ edit_image.append(ref1)
818
+ edit_image_mask.append(None)
819
+ edit_task.append('')
820
+
821
+ buffered = io.BytesIO()
822
+ img.save(buffered, format='PNG')
823
+ img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
824
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
825
+ pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}'
826
+ else:
827
+ pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
828
+ edit_image = None
829
+ edit_image_mask = None
830
+ edit_task = ''
831
+
832
+ img_num = len(edit_image) if edit_image is not None else 1
833
+ imgs = self.pipe(
834
+ input_image=edit_image,
835
+ input_mask=edit_image_mask,
836
+ task=edit_task,
837
+ prompt=[prompt] * img_num,
838
+ negative_prompt=[''] * img_num,
839
+ seed=seed,
840
+ )
841
+
842
+ img = imgs[0]
843
+ buffered = io.BytesIO()
844
+ img.convert('RGB').save(buffered, format='PNG')
845
+ img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
846
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
847
+ history = [(prompt,
848
+ f'{pre_info} The generated image is:\n {img_str}')]
849
+ return self.get_history(history), gr.update(value=''), gr.update(
850
+ visible=False)
851
+
852
+ with self.eg:
853
+ self.example_task = gr.Text(label='Task Name',
854
+ value='',
855
+ visible=False)
856
+ self.example_image = gr.Image(label='Edit Image',
857
+ type='pil',
858
+ image_mode='RGB',
859
+ visible=False)
860
+ self.example_mask = gr.Image(label='Edit Image Mask',
861
+ type='pil',
862
+ image_mode='L',
863
+ visible=False)
864
+ self.example_ref_im1 = gr.Image(label='Ref Image',
865
+ type='pil',
866
+ image_mode='RGB',
867
+ visible=False)
868
+
869
+ self.examples = gr.Examples(
870
+ fn=run_example,
871
+ examples=self.chatbot_examples,
872
+ inputs=[
873
+ self.example_task, self.example_image, self.example_mask,
874
+ self.example_ref_im1, self.text, self.seed
875
+ ],
876
+ outputs=[self.chatbot, self.text, self.gallery],
877
+ run_on_click=True)
878
+
879
+ ########################################
880
+ def upload_image():
881
+ return (gr.update(visible=True,
882
+ scale=1), gr.update(visible=True, scale=1),
883
+ gr.update(visible=True), gr.update(visible=False),
884
+ gr.update(visible=False), gr.update(visible=False))
885
+
886
+ self.upload_btn.click(upload_image,
887
+ inputs=[],
888
+ outputs=[
889
+ self.chat_page, self.editor_page,
890
+ self.upload_tab, self.edit_tab,
891
+ self.image_view_tab, self.video_view_tab
892
+ ])
893
+
894
+ ########################################
895
+ def edit_image(evt: gr.SelectData):
896
+ if isinstance(evt.value, str):
897
+ img_b64s = re.findall(
898
+ '<img src="data:image/png;base64,(.*?)" style="pointer-events: none;">',
899
+ evt.value)
900
+ imgs = [
901
+ Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))
902
+ for i in img_b64s
903
+ ]
904
+ if len(imgs) > 0:
905
+ if len(imgs) == 2:
906
+ view_img = copy.deepcopy(imgs)
907
+ edit_img = copy.deepcopy(imgs[-1])
908
+ else:
909
+ view_img = [
910
+ copy.deepcopy(imgs[-1]),
911
+ copy.deepcopy(imgs[-1])
912
+ ]
913
+ edit_img = copy.deepcopy(imgs[-1])
914
+
915
+ return (gr.update(visible=True,
916
+ scale=1), gr.update(visible=True,
917
+ scale=1),
918
+ gr.update(visible=False), gr.update(visible=True),
919
+ gr.update(visible=True), gr.update(visible=False),
920
+ gr.update(value=edit_img),
921
+ gr.update(value=view_img), gr.update(value=None))
922
+ else:
923
+ return (gr.update(), gr.update(), gr.update(), gr.update(),
924
+ gr.update(), gr.update(), gr.update(), gr.update(),
925
+ gr.update())
926
+ elif isinstance(evt.value, dict) and evt.value.get(
927
+ 'component', '') == 'video':
928
+ value = evt.value['value']['video']['path']
929
+ return (gr.update(visible=True,
930
+ scale=1), gr.update(visible=True, scale=1),
931
+ gr.update(visible=False), gr.update(visible=False),
932
+ gr.update(visible=False), gr.update(visible=True),
933
+ gr.update(), gr.update(), gr.update(value=value))
934
+ else:
935
+ return (gr.update(), gr.update(), gr.update(), gr.update(),
936
+ gr.update(), gr.update(), gr.update(), gr.update(),
937
+ gr.update())
938
+
939
+ self.chatbot.select(edit_image,
940
+ outputs=[
941
+ self.chat_page, self.editor_page,
942
+ self.upload_tab, self.edit_tab,
943
+ self.image_view_tab, self.video_view_tab,
944
+ self.image_editor, self.image_viewer,
945
+ self.video_viewer
946
+ ])
947
+
948
+ self.image_viewer.change(lambda x: x,
949
+ inputs=self.image_viewer,
950
+ outputs=self.image_viewer)
951
+
952
+ ########################################
953
+ def submit_upload_image(image, history, images):
954
+ history, images = self.add_uploaded_image_to_history(
955
+ image, history, images)
956
+ return gr.update(visible=False), gr.update(
957
+ visible=True), gr.update(
958
+ value=self.get_history(history)), history, images
959
+
960
+ self.sub_btn_1.click(
961
+ submit_upload_image,
962
+ inputs=[self.image_uploader, self.history, self.images],
963
+ outputs=[
964
+ self.editor_page, self.chat_page, self.chatbot, self.history,
965
+ self.images
966
+ ])
967
+
968
+ ########################################
969
+ def submit_edit_image(imagemask, mask_type, history, images):
970
+ history, images = self.add_edited_image_to_history(
971
+ imagemask, mask_type, history, images)
972
+ return gr.update(visible=False), gr.update(
973
+ visible=True), gr.update(
974
+ value=self.get_history(history)), history, images
975
+
976
+ self.sub_btn_2.click(submit_edit_image,
977
+ inputs=[
978
+ self.image_editor, self.mask_type,
979
+ self.history, self.images
980
+ ],
981
+ outputs=[
982
+ self.editor_page, self.chat_page,
983
+ self.chatbot, self.history, self.images
984
+ ])
985
+
986
+ ########################################
987
+ def exit_edit():
988
+ return gr.update(visible=False), gr.update(visible=True, scale=3)
989
+
990
+ self.ext_btn_1.click(exit_edit,
991
+ outputs=[self.editor_page, self.chat_page])
992
+ self.ext_btn_2.click(exit_edit,
993
+ outputs=[self.editor_page, self.chat_page])
994
+ self.ext_btn_3.click(exit_edit,
995
+ outputs=[self.editor_page, self.chat_page])
996
+ self.ext_btn_4.click(exit_edit,
997
+ outputs=[self.editor_page, self.chat_page])
998
+
999
+ ########################################
1000
+ def update_mask_type_info(mask_type):
1001
+ if mask_type == 'Background':
1002
+ info = 'Background mode will not erase the visual content in the mask area'
1003
+ visible = False
1004
+ elif mask_type == 'Composite':
1005
+ info = 'Composite mode will erase the visual content in the mask area'
1006
+ visible = False
1007
+ elif mask_type == 'Outpainting':
1008
+ info = 'Outpaint mode is used for preparing input image for outpainting task'
1009
+ visible = True
1010
+ return (gr.update(
1011
+ visible=True,
1012
+ value=
1013
+ f"<div style='background-color: white; padding-left: 15px; color: grey;'>{info}</div>"
1014
+ ), gr.update(visible=visible))
1015
+
1016
+ self.mask_type.change(update_mask_type_info,
1017
+ inputs=self.mask_type,
1018
+ outputs=[self.mask_type_info, self.outpaint_tab])
1019
+
1020
+ ########################################
1021
+ def extend_image(top_ratio, bottom_ratio, left_ratio, right_ratio,
1022
+ image):
1023
+ img = cv2.cvtColor(image['background'], cv2.COLOR_RGBA2RGB)
1024
+ h, w = img.shape[:2]
1025
+ new_h = int(h * (top_ratio + bottom_ratio + 1))
1026
+ new_w = int(w * (left_ratio + right_ratio + 1))
1027
+ start_h = int(h * top_ratio)
1028
+ start_w = int(w * left_ratio)
1029
+ new_img = np.zeros((new_h, new_w, 3), dtype=np.uint8)
1030
+ new_mask = np.ones((new_h, new_w, 1), dtype=np.uint8) * 255
1031
+ new_img[start_h:start_h + h, start_w:start_w + w, :] = img
1032
+ new_mask[start_h:start_h + h, start_w:start_w + w] = 0
1033
+ layer = np.concatenate([new_img, new_mask], axis=2)
1034
+ value = {
1035
+ 'background': new_img,
1036
+ 'composite': new_img,
1037
+ 'layers': [layer]
1038
+ }
1039
+ return gr.update(value=value)
1040
+
1041
+ self.img_pad_btn.click(extend_image,
1042
+ inputs=[
1043
+ self.top_ext, self.bottom_ext,
1044
+ self.left_ext, self.right_ext,
1045
+ self.image_editor
1046
+ ],
1047
+ outputs=self.image_editor)
1048
+
1049
+ ########################################
1050
+ def clear_chat(history, images, history_result):
1051
+ history.clear()
1052
+ images.clear()
1053
+ history_result.clear()
1054
+ return history, images, history_result, self.get_history(history)
1055
+
1056
+ self.clear_btn.click(
1057
+ clear_chat,
1058
+ inputs=[self.history, self.images, self.history_result],
1059
+ outputs=[
1060
+ self.history, self.images, self.history_result, self.chatbot
1061
+ ])
1062
+
1063
+ def get_history(self, history):
1064
+ info = []
1065
+ for item in history:
1066
+ new_item = [None, None]
1067
+ if isinstance(item[0], str) and item[0].endswith('.mp4'):
1068
+ new_item[0] = gr.Video(item[0], format='mp4')
1069
+ else:
1070
+ new_item[0] = item[0]
1071
+ if isinstance(item[1], str) and item[1].endswith('.mp4'):
1072
+ new_item[1] = gr.Video(item[1], format='mp4')
1073
+ else:
1074
+ new_item[1] = item[1]
1075
+ info.append(new_item)
1076
+ return info
1077
+
1078
+ def generate_random_string(self, length=20):
1079
+ letters_and_digits = string.ascii_letters + string.digits
1080
+ random_string = ''.join(
1081
+ random.choice(letters_and_digits) for i in range(length))
1082
+ return random_string
1083
+
1084
+ def add_edited_image_to_history(self, image, mask_type, history, images):
1085
+ if mask_type == 'Composite':
1086
+ img = Image.fromarray(image['composite'])
1087
+ else:
1088
+ img = Image.fromarray(image['background'])
1089
+
1090
+ img_id = get_md5(self.generate_random_string())[:12]
1091
+ save_path = os.path.join(self.cache_dir, f'{img_id}.png')
1092
+ img.convert('RGB').save(save_path)
1093
+
1094
+ mask = image['layers'][0][:, :, 3]
1095
+ mask = Image.fromarray(mask).convert('RGB')
1096
+ mask_path = os.path.join(self.cache_dir, f'{img_id}_mask.png')
1097
+ mask.save(mask_path)
1098
+
1099
+ w, h = img.size
1100
+ if w > h:
1101
+ tb_w = 128
1102
+ tb_h = int(h * tb_w / w)
1103
+ else:
1104
+ tb_h = 128
1105
+ tb_w = int(w * tb_h / h)
1106
+
1107
+ if mask_type == 'Background':
1108
+ comp_mask = np.array(mask, dtype=np.uint8)
1109
+ mask_alpha = (comp_mask[:, :, 0:1].astype(np.float32) *
1110
+ 0.6).astype(np.uint8)
1111
+ comp_mask = np.concatenate([comp_mask, mask_alpha], axis=2)
1112
+ thumbnail = Image.alpha_composite(
1113
+ img.convert('RGBA'),
1114
+ Image.fromarray(comp_mask).convert('RGBA')).convert('RGB')
1115
+ else:
1116
+ thumbnail = img.convert('RGB')
1117
+
1118
+ thumbnail_path = os.path.join(self.cache_dir,
1119
+ f'{img_id}_thumbnail.jpg')
1120
+ thumbnail = thumbnail.resize((tb_w, tb_h))
1121
+ thumbnail.save(thumbnail_path, format='JPEG')
1122
+
1123
+ buffered = io.BytesIO()
1124
+ img.convert('RGB').save(buffered, format='PNG')
1125
+ img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1126
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1127
+
1128
+ buffered = io.BytesIO()
1129
+ mask.convert('RGB').save(buffered, format='PNG')
1130
+ mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1131
+ mask_str = f'<img src="data:image/png;base64,{mask_b64}" style="pointer-events: none;">'
1132
+
1133
+ images[img_id] = {
1134
+ 'image': save_path,
1135
+ 'mask': mask_path,
1136
+ 'mask_type': mask_type,
1137
+ 'thumbnail': thumbnail_path
1138
+ }
1139
+ history.append((
1140
+ None,
1141
+ f'This is edited image and mask:\n {img_str} {mask_str} image ID is: {img_id}'
1142
+ ))
1143
+ return history, images
1144
+
1145
+ def add_uploaded_image_to_history(self, img, history, images):
1146
+ img_id = get_md5(self.generate_random_string())[:12]
1147
+ save_path = os.path.join(self.cache_dir, f'{img_id}.png')
1148
+ w, h = img.size
1149
+ if w > 2048:
1150
+ ratio = w / 2048.
1151
+ w = 2048
1152
+ h = int(h / ratio)
1153
+ if h > 2048:
1154
+ ratio = h / 2048.
1155
+ h = 2048
1156
+ w = int(w / ratio)
1157
+ img = img.resize((w, h))
1158
+ img.save(save_path)
1159
+
1160
+ w, h = img.size
1161
+ if w > h:
1162
+ tb_w = 128
1163
+ tb_h = int(h * tb_w / w)
1164
+ else:
1165
+ tb_h = 128
1166
+ tb_w = int(w * tb_h / h)
1167
+ thumbnail_path = os.path.join(self.cache_dir,
1168
+ f'{img_id}_thumbnail.jpg')
1169
+ thumbnail = img.resize((tb_w, tb_h))
1170
+ thumbnail.save(thumbnail_path, format='JPEG')
1171
+
1172
+ images[img_id] = {
1173
+ 'image': save_path,
1174
+ 'mask': None,
1175
+ 'mask_type': None,
1176
+ 'thumbnail': thumbnail_path
1177
+ }
1178
+
1179
+ buffered = io.BytesIO()
1180
+ img.convert('RGB').save(buffered, format='PNG')
1181
+ img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1182
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1183
+
1184
+ history.append(
1185
+ (None,
1186
+ f'This is uploaded image:\n {img_str} image ID is: {img_id}'))
1187
+ return history, images
1188
+
1189
+
1190
+
1191
+ if __name__ == '__main__':
1192
+ cfg_file = 'config/chatbot_ui.yaml'
1193
+
1194
+ with gr.Blocks() as demo:
1195
+ chatbot = ChatBotUI(cfg_file)
1196
+ chatbot.create_ui()
1197
+ chatbot.set_callbacks()
1198
+
1199
  demo.launch()
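Below is a minimal, self-contained sketch of the `@image_id` referencing convention that `run_chat` and `generate_gallery` implement above: image IDs typed with an `@` prefix are looked up in the session's image dict and rewritten into the `{image}` / `{imageN}` placeholders that the ACE prompt format expects. The regular expressions are copied from app.py; the `images` dict and ID below are illustrative only.

import re

def resolve_image_refs(message, images):
    """Map '@<image_id>' references to ACE placeholders, mirroring run_chat above."""
    img_ids = re.findall('@(.*?)[ ,;.?$]', message)   # same pattern as in run_chat
    new_message, resolved = message, []
    for i, img_id in enumerate(img_ids):
        if img_id not in images:
            continue  # run_chat shows a gr.Info warning and skips unknown IDs
        placeholder = '{image}' if i == 0 else '{' + f'image{i}' + '}'
        new_message = re.sub(f'@{img_id}', placeholder, new_message)
        resolved.append(images[img_id])
    return new_message, resolved

# Illustrative call (the ID and path are made up):
images = {'a1b2c3d4e5f6': {'image': 'cache/a1b2c3d4e5f6.png', 'mask': None}}
print(resolve_image_refs('@a1b2c3d4e5f6 add a red hat to the cat', images))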
config/chatbot_ui.yaml ADDED
@@ -0,0 +1,25 @@
1
+ WORK_DIR: ./cache/chatbot
2
+ FILE_SYSTEM:
3
+ - NAME: LocalFs
4
+ TEMP_DIR: ./cache
5
+ - NAME: ModelscopeFs
6
+ TEMP_DIR: ./cache
7
+ - NAME: HuggingfaceFs
8
+ TEMP_DIR: ./cache
9
+ #
10
+ ENABLE_I2V: False
11
+ #
12
+ MODEL:
13
+ EDIT_MODEL:
14
+ MODEL_CFG_DIR: config/models/
15
+ DEFAULT: ace_0.6b_512
16
+ I2V:
17
+ MODEL_NAME: CogVideoX-5b-I2V
18
+ MODEL_DIR: ms://ZhipuAI/CogVideoX-5b-I2V/
19
+ CAPTIONER:
20
+ MODEL_NAME: InternVL2-2B
21
+ MODEL_DIR: ms://OpenGVLab/InternVL2-2B/
22
+ PROMPT: '<image>\nThis image is the first frame of a video. Based on this image, please imagine what changes may occur in the next few seconds of the video. Please output brief description, such as "a dog running" or "a person turns to left". No more than 30 words.'
23
+ ENHANCER:
24
+ MODEL_NAME: Meta-Llama-3.1-8B-Instruct
25
+ MODEL_DIR: ms://LLM-Research/Meta-Llama-3.1-8B-Instruct/
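For reference, a short sketch (assuming scepter is installed and this file is saved as config/chatbot_ui.yaml) of how ChatBotUI in app.py consumes these keys; every attribute access below mirrors a call that appears in the diff above.

from scepter.modules.utils.config import Config

cfg = Config(cfg_file='config/chatbot_ui.yaml')
print(cfg.WORK_DIR)                         # ./cache/chatbot
print(cfg.get('ENABLE_I2V', False))         # False -> video widgets stay hidden
print(cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR)   # config/models/, scanned for *.yaml
print(cfg.MODEL.EDIT_MODEL.DEFAULT)         # ace_0.6b_512, must match a file in that dir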
config/models/ace_0.6b_512.yaml ADDED
@@ -0,0 +1,127 @@
1
+ NAME: ACE_0.6B_512
2
+ IS_DEFAULT: False
3
+ DEFAULT_PARAS:
4
+ PARAS:
5
+ #
6
+ INPUT:
7
+ INPUT_IMAGE:
8
+ INPUT_MASK:
9
+ TASK:
10
+ PROMPT: ""
11
+ NEGATIVE_PROMPT: ""
12
+ OUTPUT_HEIGHT: 512
13
+ OUTPUT_WIDTH: 512
14
+ SAMPLER: ddim
15
+ SAMPLE_STEPS: 20
16
+ GUIDE_SCALE: 4.5
17
+ GUIDE_RESCALE: 0.5
18
+ SEED: -1
19
+ TAR_INDEX: 0
20
+ OUTPUT:
21
+ LATENT:
22
+ IMAGES:
23
+ SEED:
24
+ MODULES_PARAS:
25
+ FIRST_STAGE_MODEL:
26
+ FUNCTION:
27
+ - NAME: encode
28
+ DTYPE: float16
29
+ INPUT: ["IMAGE"]
30
+ - NAME: decode
31
+ DTYPE: float16
32
+ INPUT: ["LATENT"]
33
+ #
34
+ DIFFUSION_MODEL:
35
+ FUNCTION:
36
+ - NAME: forward
37
+ DTYPE: float16
38
+ INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"]
39
+ #
40
+ COND_STAGE_MODEL:
41
+ FUNCTION:
42
+ - NAME: encode_list
43
+ DTYPE: bfloat16
44
+ INPUT: ["PROMPT"]
45
+ #
46
+ MODEL:
47
+ NAME: LdmACE
48
+ PRETRAINED_MODEL:
49
+ IGNORE_KEYS: [ ]
50
+ SCALE_FACTOR: 0.18215
51
+ SIZE_FACTOR: 8
52
+ DECODER_BIAS: 0.5
53
+ DEFAULT_N_PROMPT: ""
54
+ TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
55
+ USE_TEXT_POS_EMBEDDINGS: True
56
+ #
57
+ DIFFUSION:
58
+ NAME: ACEDiffusion
59
+ PREDICTION_TYPE: eps
60
+ MIN_SNR_GAMMA:
61
+ NOISE_SCHEDULER:
62
+ NAME: LinearScheduler
63
+ NUM_TIMESTEPS: 1000
64
+ BETA_MIN: 0.0001
65
+ BETA_MAX: 0.02
66
+ #
67
+ DIFFUSION_MODEL:
68
+ NAME: DiTACE
69
+ PRETRAINED_MODEL: hf://scepter-studio/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth
70
+ IGNORE_KEYS: [ ]
71
+ PATCH_SIZE: 2
72
+ IN_CHANNELS: 4
73
+ HIDDEN_SIZE: 1152
74
+ DEPTH: 28
75
+ NUM_HEADS: 16
76
+ MLP_RATIO: 4.0
77
+ PRED_SIGMA: True
78
+ DROP_PATH: 0.0
79
+ WINDOW_DIZE: 0
80
+ Y_CHANNELS: 4096
81
+ MAX_SEQ_LEN: 1024
82
+ QK_NORM: True
83
+ USE_GRAD_CHECKPOINT: True
84
+ ATTENTION_BACKEND: flash_attn
85
+ #
86
+ FIRST_STAGE_MODEL:
87
+ NAME: AutoencoderKL
88
+ EMBED_DIM: 4
89
+ PRETRAINED_MODEL: hf://scepter-studio/ACE-0.6B-512px@models/vae/vae.bin
90
+ IGNORE_KEYS: []
91
+ #
92
+ ENCODER:
93
+ NAME: Encoder
94
+ CH: 128
95
+ OUT_CH: 3
96
+ NUM_RES_BLOCKS: 2
97
+ IN_CHANNELS: 3
98
+ ATTN_RESOLUTIONS: [ ]
99
+ CH_MULT: [ 1, 2, 4, 4 ]
100
+ Z_CHANNELS: 4
101
+ DOUBLE_Z: True
102
+ DROPOUT: 0.0
103
+ RESAMP_WITH_CONV: True
104
+ #
105
+ DECODER:
106
+ NAME: Decoder
107
+ CH: 128
108
+ OUT_CH: 3
109
+ NUM_RES_BLOCKS: 2
110
+ IN_CHANNELS: 3
111
+ ATTN_RESOLUTIONS: [ ]
112
+ CH_MULT: [ 1, 2, 4, 4 ]
113
+ Z_CHANNELS: 4
114
+ DROPOUT: 0.0
115
+ RESAMP_WITH_CONV: True
116
+ GIVE_PRE_END: False
117
+ TANH_OUT: False
118
+ #
119
+ COND_STAGE_MODEL:
120
+ NAME: ACETextEmbedder
121
+ PRETRAINED_MODEL: hf://scepter-studio/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/
122
+ TOKENIZER_PATH: hf://scepter-studio/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl
123
+ LENGTH: 120
124
+ T5_DTYPE: bfloat16
125
+ ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
126
+ CLEAN: whitespace
127
+ USE_GRAD: False
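Each file under config/models/ becomes one entry in the Model Version dropdown. A sketch of the loading path used by __init__ and change_model in app.py follows; the ACEInference lines are left commented out because they pull the hf://scepter-studio/ACE-0.6B-512px weights and need a GPU.

import glob
import os

from scepter.modules.utils.config import Config
# from .infer import ACEInference  # imported this way in app.py

model_cfg_dir = 'config/models/'
model_yamls = glob.glob(os.path.join(model_cfg_dir, '*.yaml'))
model_choices = {os.path.splitext(os.path.basename(p))[0]: p for p in model_yamls}

model_cfg = Config(load=True, cfg_file=model_choices['ace_0.6b_512'])
# pipe = ACEInference()
# pipe.init_from_cfg(model_cfg)   # downloads the DiT, VAE and T5 weights listed above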
example.py ADDED
@@ -0,0 +1,339 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import os
4
+
5
+ from scepter.modules.utils.file_system import FS
6
+
7
+
8
+ def download_image(image, local_path=None):
9
+ if not FS.exists(local_path):
10
+ local_path = FS.get_from(image, local_path=local_path)
11
+ return local_path
12
+
13
+
14
+ def get_examples(cache_dir):
15
+ print('Downloading Examples ...')
16
+ examples = [
17
+ [
18
+ 'Image Segmentation',
19
+ download_image(
20
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/db3ebaa81899.png?raw=true',
21
+ os.path.join(cache_dir, 'examples/db3ebaa81899.png')), None,
22
+ None, '{image} Segmentation', 6666
23
+ ],
24
+ [
25
+ 'Depth Estimation',
26
+ download_image(
27
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f1927c4692ba.png?raw=true',
28
+ os.path.join(cache_dir, 'examples/f1927c4692ba.png')), None,
29
+ None, '{image} Depth Estimation', 6666
30
+ ],
31
+ [
32
+ 'Pose Estimation',
33
+ download_image(
34
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/014e5bf3b4d1.png?raw=true',
35
+ os.path.join(cache_dir, 'examples/014e5bf3b4d1.png')), None,
36
+ None, '{image} distinguish the poses of the figures', 999999
37
+ ],
38
+ [
39
+ 'Scribble Extraction',
40
+ download_image(
41
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5f59a202f8ac.png?raw=true',
42
+ os.path.join(cache_dir, 'examples/5f59a202f8ac.png')), None,
43
+ None, 'Generate a scribble of {image}, please.', 6666
44
+ ],
45
+ [
46
+ 'Mosaic',
47
+ download_image(
48
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a2f52361eea.png?raw=true',
49
+ os.path.join(cache_dir, 'examples/3a2f52361eea.png')), None,
50
+ None, 'Adapt {image} into a mosaic representation.', 6666
51
+ ],
52
+ [
53
+ 'Edge map Extraction',
54
+ download_image(
55
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/b9d1e519d6e5.png?raw=true',
56
+ os.path.join(cache_dir, 'examples/b9d1e519d6e5.png')), None,
57
+ None, 'Get the edge-enhanced result for {image}.', 6666
58
+ ],
59
+ [
60
+ 'Grayscale',
61
+ download_image(
62
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4ebbe2ba29b.png?raw=true',
63
+ os.path.join(cache_dir, 'examples/c4ebbe2ba29b.png')), None,
64
+ None, 'transform {image} into a black and white one', 6666
65
+ ],
66
+ [
67
+ 'Contour Extraction',
68
+ download_image(
69
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/19652d0f6c4b.png?raw=true',
70
+ os.path.join(cache_dir,
71
+ 'examples/19652d0f6c4b.png')), None, None,
72
+ 'Would you be able to make a contour picture from {image} for me?',
73
+ 6666
74
+ ],
75
+ [
76
+ 'Controllable Generation',
77
+ download_image(
78
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/249cda2844b7.png?raw=true',
79
+ os.path.join(cache_dir,
80
+ 'examples/249cda2844b7.png')), None, None,
81
+ 'Following the segmentation outcome in mask of {image}, develop a real-life image using the explanatory note in "a mighty cat lying on the bed”.',
82
+ 6666
83
+ ],
84
+ [
85
+ 'Controllable Generation',
86
+ download_image(
87
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/411f6c4b8e6c.png?raw=true',
88
+ os.path.join(cache_dir,
89
+ 'examples/411f6c4b8e6c.png')), None, None,
90
+ 'use the depth map {image} and the text caption "a cut white cat" to create a corresponding graphic image',
91
+ 999999
92
+ ],
93
+ [
94
+ 'Controllable Generation',
95
+ download_image(
96
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a35c96ed137a.png?raw=true',
97
+ os.path.join(cache_dir,
98
+ 'examples/a35c96ed137a.png')), None, None,
99
+ 'help translate this posture schema {image} into a colored image based on the context I provided "A beautiful woman Climbing the climbing wall, wearing a harness and climbing gear, skillfully maneuvering up the wall with her back to the camera, with a safety rope."',
100
+ 3599999
101
+ ],
102
+ [
103
+ 'Controllable Generation',
104
+ download_image(
105
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/dcb2fc86f1ce.png?raw=true',
106
+ os.path.join(cache_dir,
107
+ 'examples/dcb2fc86f1ce.png')), None, None,
108
+ 'Transform and generate an image using mosaic {image} and "Monarch butterflies gracefully perch on vibrant purple flowers, showcasing their striking orange and black wings in a lush garden setting." description',
109
+ 6666
110
+ ],
111
+ [
112
+ 'Controllable Generation',
113
+ download_image(
114
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/4cd4ee494962.png?raw=true',
115
+ os.path.join(cache_dir,
116
+ 'examples/4cd4ee494962.png')), None, None,
117
+ 'make this {image} colorful as per the "beautiful sunflowers"',
118
+ 6666
119
+ ],
120
+ [
121
+ 'Controllable Generation',
122
+ download_image(
123
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a47e3a9cd166.png?raw=true',
124
+ os.path.join(cache_dir,
125
+ 'examples/a47e3a9cd166.png')), None, None,
126
+ 'Take the edge conscious {image} and the written guideline "A whimsical animated character is depicted holding a delectable cake adorned with blue and white frosting and a drizzle of chocolate. The character wears a yellow headband with a bow, matching a cozy yellow sweater. Her dark hair is styled in a braid, tied with a yellow ribbon. With a golden fork in hand, she stands ready to enjoy a slice, exuding an air of joyful anticipation. The scene is creatively rendered with a charming and playful aesthetic." and produce a realistic image.',
127
+ 613725
128
+ ],
129
+ [
130
+ 'Controllable Generation',
131
+ download_image(
132
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d890ed8a3ac2.png?raw=true',
133
+ os.path.join(cache_dir,
134
+ 'examples/d890ed8a3ac2.png')), None, None,
135
+ 'creating a vivid image based on {image} and description "This image features a delicious rectangular tart with a flaky, golden-brown crust. The tart is topped with evenly sliced tomatoes, layered over a creamy cheese filling. Aromatic herbs are sprinkled on top, adding a touch of green and enhancing the visual appeal. The background includes a soft, textured fabric and scattered white flowers, creating an elegant and inviting presentation. Bright red tomatoes in the upper right corner hint at the fresh ingredients used in the dish."',
136
+ 6666
137
+ ],
138
+ [
139
+ 'Controllable Generation',
140
+ download_image(
141
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/131ca90fd2a9.png?raw=true',
142
+ os.path.join(cache_dir,
143
+ 'examples/131ca90fd2a9.png')), None, None,
144
+ '"A person sits contemplatively on the ground, surrounded by falling autumn leaves. Dressed in a green sweater and dark blue pants, they rest their chin on their hand, exuding a relaxed demeanor. Their stylish checkered slip-on shoes add a touch of flair, while a black purse lies in their lap. The backdrop of muted brown enhances the warm, cozy atmosphere of the scene." , generate the image that corresponds to the given scribble {image}.',
145
+ 613725
146
+ ],
147
+ [
148
+ 'Image Denoising',
149
+ download_image(
150
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/0844a686a179.png?raw=true',
151
+ os.path.join(cache_dir,
152
+ 'examples/0844a686a179.png')), None, None,
153
+ 'Eliminate noise interference in {image} and maximize the crispness to obtain superior high-definition quality',
154
+ 6666
155
+ ],
156
+ [
157
+ 'Inpainting',
158
+ download_image(
159
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b.png?raw=true',
160
+ os.path.join(cache_dir, 'examples/fa91b6b7e59b.png')),
161
+ download_image(
162
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b_mask.png?raw=true',
163
+ os.path.join(cache_dir,
164
+ 'examples/fa91b6b7e59b_mask.png')), None,
165
+ 'Ensure to overhaul the parts of the {image} indicated by the mask.',
166
+ 6666
167
+ ],
168
+ [
169
+ 'Inpainting',
170
+ download_image(
171
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26.png?raw=true',
172
+ os.path.join(cache_dir, 'examples/632899695b26.png')),
173
+ download_image(
174
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26_mask.png?raw=true',
175
+ os.path.join(cache_dir,
176
+ 'examples/632899695b26_mask.png')), None,
177
+ 'Refashion the mask portion of {image} in accordance with "A yellow egg with a smiling face painted on it"',
178
+ 6666
179
+ ],
180
+ [
181
+ 'Outpainting',
182
+ download_image(
183
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f.png?raw=true',
184
+ os.path.join(cache_dir, 'examples/f2b22c08be3f.png')),
185
+ download_image(
186
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f_mask.png?raw=true',
187
+ os.path.join(cache_dir,
188
+ 'examples/f2b22c08be3f_mask.png')), None,
189
+ 'Could the {image} be widened within the space designated by mask, while retaining the original?',
190
+ 6666
191
+ ],
192
+ [
193
+ 'General Editing',
194
+ download_image(
195
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/354d17594afe.png?raw=true',
196
+ os.path.join(cache_dir,
197
+ 'examples/354d17594afe.png')), None, None,
198
+ '{image} change the dog\'s posture to walking in the water, and change the background to green plants and a pond.',
199
+ 6666
200
+ ],
201
+ [
202
+ 'General Editing',
203
+ download_image(
204
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/38946455752b.png?raw=true',
205
+ os.path.join(cache_dir,
206
+ 'examples/38946455752b.png')), None, None,
207
+ '{image} change the color of the dress from white to red and the model\'s hair color from red brown to blonde. Other parts remain unchanged',
208
+ 6669
209
+ ],
210
+ [
211
+ 'Facial Editing',
212
+ download_image(
213
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3ba5202f0cd8.png?raw=true',
214
+ os.path.join(cache_dir,
215
+ 'examples/3ba5202f0cd8.png')), None, None,
216
+ 'Keep the same facial feature in @3ba5202f0cd8, change the woman\'s clothing from a Blue denim jacket to a white turtleneck sweater and adjust her posture so that she is supporting her chin with both hands. Other aspects, such as background, hairstyle, facial expression, etc, remain unchanged.',
217
+ 99999
218
+ ],
219
+ [
220
+ 'Facial Editing',
221
+ download_image(
222
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/369365b94725.png?raw=true',
223
+ os.path.join(cache_dir, 'examples/369365b94725.png')), None,
224
+ None, '{image} Make her looking at the camera', 6666
225
+ ],
226
+ [
227
+ 'Facial Editing',
228
+ download_image(
229
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/92751f2e4a0e.png?raw=true',
230
+ os.path.join(cache_dir, 'examples/92751f2e4a0e.png')), None,
231
+ None, '{image} Remove the smile from his face', 9899999
232
+ ],
233
+ [
234
+ 'Render Text',
235
+ download_image(
236
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48.png?raw=true',
237
+ os.path.join(cache_dir, 'examples/33e9f27c2c48.png')),
238
+ download_image(
239
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48_mask.png?raw=true',
240
+ os.path.join(cache_dir,
241
+ 'examples/33e9f27c2c48_mask.png')), None,
242
+ 'Put the text "C A T" at the position marked by mask in the {image}',
243
+ 6666
244
+ ],
245
+ [
246
+ 'Remove Text',
247
+ download_image(
248
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/8530a6711b2e.png?raw=true',
249
+ os.path.join(cache_dir, 'examples/8530a6711b2e.png')), None,
250
+ None, 'Aim to remove any textual element in {image}', 6666
251
+ ],
252
+ [
253
+ 'Remove Text',
254
+ download_image(
255
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6.png?raw=true',
256
+ os.path.join(cache_dir, 'examples/c4d7fb28f8f6.png')),
257
+ download_image(
258
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6_mask.png?raw=true',
259
+ os.path.join(cache_dir,
260
+ 'examples/c4d7fb28f8f6_mask.png')), None,
261
+ 'Rub out any text found in the mask sector of the {image}.', 6666
262
+ ],
263
+ [
264
+ 'Remove Object',
265
+ download_image(
266
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e2f318fa5e5b.png?raw=true',
267
+ os.path.join(cache_dir,
268
+ 'examples/e2f318fa5e5b.png')), None, None,
269
+ 'Remove the unicorn in this {image}, ensuring a smooth edit.',
270
+ 99999
271
+ ],
272
+ [
273
+ 'Remove Object',
274
+ download_image(
275
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00.png?raw=true',
276
+ os.path.join(cache_dir, 'examples/1ae96d8aca00.png')),
277
+ download_image(
278
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00_mask.png?raw=true',
279
+ os.path.join(cache_dir, 'examples/1ae96d8aca00_mask.png')),
280
+ None, 'Discard the contents of the mask area from {image}.', 99999
281
+ ],
282
+ [
283
+ 'Add Object',
284
+ download_image(
285
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511.png?raw=true',
286
+ os.path.join(cache_dir, 'examples/80289f48e511.png')),
287
+ download_image(
288
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511_mask.png?raw=true',
289
+ os.path.join(cache_dir,
290
+ 'examples/80289f48e511_mask.png')), None,
291
+ 'add a Hot Air Balloon into the {image}, per the mask', 613725
292
+ ],
293
+ [
294
+ 'Style Transfer',
295
+ download_image(
296
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d725cb2009e8.png?raw=true',
297
+ os.path.join(cache_dir, 'examples/d725cb2009e8.png')), None,
298
+ None, 'Change the style of {image} to colored pencil style', 99999
299
+ ],
300
+ [
301
+ 'Style Transfer',
302
+ download_image(
303
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e0f48b3fd010.png?raw=true',
304
+ os.path.join(cache_dir, 'examples/e0f48b3fd010.png')), None,
305
+ None, 'make {image} to Walt Disney Animation style', 99999
306
+ ],
307
+ [
308
+ 'Style Transfer',
309
+ download_image(
310
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/9e73e7eeef55.png?raw=true',
311
+ os.path.join(cache_dir, 'examples/9e73e7eeef55.png')), None,
312
+ download_image(
313
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/2e02975293d6.png?raw=true',
314
+ os.path.join(cache_dir, 'examples/2e02975293d6.png')),
315
+ 'edit {image} based on the style of {image1} ', 99999
316
+ ],
317
+ [
318
+ 'Try On',
319
+ download_image(
320
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96.png?raw=true',
321
+ os.path.join(cache_dir, 'examples/ee4ca60b8c96.png')),
322
+ download_image(
323
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96_mask.png?raw=true',
324
+ os.path.join(cache_dir, 'examples/ee4ca60b8c96_mask.png')),
325
+ download_image(
326
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ebe825bbfe3c.png?raw=true',
327
+ os.path.join(cache_dir, 'examples/ebe825bbfe3c.png')),
328
+ 'Change the cloth in {image} to the one in {image1}', 99999
329
+ ],
330
+ [
331
+ 'Workflow',
332
+ download_image(
333
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/cb85353c004b.png?raw=true',
334
+ os.path.join(cache_dir, 'examples/cb85353c004b.png')), None,
335
+ None, '<workflow> ice cream {image}', 99999
336
+ ],
337
+ ]
338
+ print('Finished. Start building UI ...')
339
+ return examples
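
Each example row above follows the same layout: task name, input image, optional mask, optional reference image, prompt (with {image}/{image1} placeholders), and seed. The download_image helper is defined earlier in app.py; purely for orientation, here is a minimal sketch of what such a cache-and-return helper could look like (the name download_image_sketch and its behaviour are assumptions, not the repository's actual implementation):

import os
import urllib.request

def download_image_sketch(url, local_path):
    # Hypothetical stand-in for app.py's download_image: fetch the remote example
    # asset once, cache it at local_path, and return the local path for the UI.
    if not os.path.exists(local_path):
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        urllib.request.urlretrieve(url, local_path)
    return local_path
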
infer.py ADDED
@@ -0,0 +1,378 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import copy
4
+ import math
5
+ import random
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torchvision.transforms.functional as TF
13
+
14
+ from scepter.modules.model.registry import DIFFUSIONS
15
+ from scepter.modules.model.utils.basic_utils import (
16
+ check_list_of_list,
17
+ pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
18
+ to_device,
19
+ unpack_tensor_into_imagelist
20
+ )
21
+ from scepter.modules.utils.distribute import we
22
+ from scepter.modules.utils.logger import get_logger
23
+
24
+ from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
25
+
26
+
27
+ def process_edit_image(images,
28
+ masks,
29
+ tasks,
30
+ max_seq_len=1024,
31
+ max_aspect_ratio=4,
32
+ d=16,
33
+ **kwargs):
34
+
35
+ if not isinstance(images, list):
36
+ images = [images]
37
+ if not isinstance(masks, list):
38
+ masks = [masks]
39
+ if not isinstance(tasks, list):
40
+ tasks = [tasks]
41
+
42
+ img_tensors = []
43
+ mask_tensors = []
44
+ for img, mask, task in zip(images, masks, tasks):
45
+ if mask is None or mask == '':
46
+ mask = Image.new('L', img.size, 0)
47
+ W, H = img.size
48
+ if H / W > max_aspect_ratio:
49
+ img = TF.center_crop(img, [int(max_aspect_ratio * W), W])
50
+ mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W])
51
+ elif W / H > max_aspect_ratio:
52
+ img = TF.center_crop(img, [H, int(max_aspect_ratio * H)])
53
+ mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)])
54
+
55
+ H, W = img.height, img.width
56
+ scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
57
+ rH = int(H * scale) // d * d # ensure divisible by d
58
+ rW = int(W * scale) // d * d
59
+
60
+ img = TF.resize(img, (rH, rW),
61
+ interpolation=TF.InterpolationMode.BICUBIC)
62
+ mask = TF.resize(mask, (rH, rW),
63
+ interpolation=TF.InterpolationMode.NEAREST_EXACT)
64
+
65
+ mask = np.asarray(mask)
66
+ mask = np.where(mask > 128, 1, 0)
67
+ mask = mask.astype(
68
+ np.float32) if np.any(mask) else np.ones_like(mask).astype(
69
+ np.float32)
70
+
71
+ img_tensor = TF.to_tensor(img).to(we.device_id)
72
+ img_tensor = TF.normalize(img_tensor,
73
+ mean=[0.5, 0.5, 0.5],
74
+ std=[0.5, 0.5, 0.5])
75
+ mask_tensor = TF.to_tensor(mask).to(we.device_id)
76
+ if task in ['inpainting', 'Try On', 'Inpainting']:
77
+ mask_indicator = mask_tensor.repeat(3, 1, 1)
78
+ img_tensor[mask_indicator == 1] = -1.0
79
+ img_tensors.append(img_tensor)
80
+ mask_tensors.append(mask_tensor)
81
+ return img_tensors, mask_tensors
82
+
83
+
84
+ class TextEmbedding(nn.Module):
85
+ def __init__(self, embedding_shape):
86
+ super().__init__()
87
+ self.pos = nn.Parameter(data=torch.zeros(embedding_shape))
88
+
89
+
90
+ class ACEInference(DiffusionInference):
91
+ def __init__(self, logger=None):
92
+ if logger is None:
93
+ logger = get_logger(name='scepter')
94
+ self.logger = logger
95
+ self.loaded_model = {}
96
+ self.loaded_model_name = [
97
+ 'diffusion_model', 'first_stage_model', 'cond_stage_model'
98
+ ]
99
+
100
+ def init_from_cfg(self, cfg):
101
+ self.name = cfg.NAME
102
+ self.is_default = cfg.get('IS_DEFAULT', False)
103
+ module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
104
+ assert cfg.have('MODEL')
105
+
106
+ self.diffusion_model = self.infer_model(
107
+ cfg.MODEL.DIFFUSION_MODEL, module_paras.get(
108
+ 'DIFFUSION_MODEL',
109
+ None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
110
+ self.first_stage_model = self.infer_model(
111
+ cfg.MODEL.FIRST_STAGE_MODEL,
112
+ module_paras.get(
113
+ 'FIRST_STAGE_MODEL',
114
+ None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
115
+ self.cond_stage_model = self.infer_model(
116
+ cfg.MODEL.COND_STAGE_MODEL,
117
+ module_paras.get(
118
+ 'COND_STAGE_MODEL',
119
+ None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None
120
+ self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
121
+ logger=self.logger)
122
+
123
+ self.interpolate_func = lambda x: (F.interpolate(
124
+ x.unsqueeze(0),
125
+ scale_factor=1 / self.size_factor,
126
+ mode='nearest-exact') if x is not None else None)
127
+ self.text_identifiers = cfg.MODEL.get('TEXT_IDENTIFIER', [])
128
+ self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS',
129
+ False)
130
+ if self.use_text_pos_embeddings:
131
+ self.text_position_embeddings = TextEmbedding(
132
+ (10, 4096)).eval().requires_grad_(False).to(we.device_id)
133
+ else:
134
+ self.text_position_embeddings = None
135
+
136
+ self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN
137
+ self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215)
138
+ self.size_factor = cfg.get('SIZE_FACTOR', 8)
139
+ self.decoder_bias = cfg.get('DECODER_BIAS', 0)
140
+ self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
141
+
142
+ @torch.no_grad()
143
+ def encode_first_stage(self, x, **kwargs):
144
+ _, dtype = self.get_function_info(self.first_stage_model, 'encode')
145
+ with torch.autocast('cuda',
146
+ enabled=(dtype != 'float32'),
147
+ dtype=getattr(torch, dtype)):
148
+ z = [
149
+ self.scale_factor * get_model(self.first_stage_model)._encode(
150
+ i.unsqueeze(0).to(getattr(torch, dtype))) for i in x
151
+ ]
152
+ return z
153
+
154
+ @torch.no_grad()
155
+ def decode_first_stage(self, z):
156
+ _, dtype = self.get_function_info(self.first_stage_model, 'decode')
157
+ with torch.autocast('cuda',
158
+ enabled=(dtype != 'float32'),
159
+ dtype=getattr(torch, dtype)):
160
+ x = [
161
+ get_model(self.first_stage_model)._decode(
162
+ 1. / self.scale_factor * i.to(getattr(torch, dtype)))
163
+ for i in z
164
+ ]
165
+ return x
166
+
167
+ @torch.no_grad()
168
+ def __call__(self,
169
+ image=None,
170
+ mask=None,
171
+ prompt='',
172
+ task=None,
173
+ negative_prompt='',
174
+ output_height=512,
175
+ output_width=512,
176
+ sampler='ddim',
177
+ sample_steps=20,
178
+ guide_scale=4.5,
179
+ guide_rescale=0.5,
180
+ seed=-1,
181
+ history_io=None,
182
+ tar_index=0,
183
+ **kwargs):
184
+ input_image, input_mask = image, mask
185
+ g = torch.Generator(device=we.device_id)
186
+ seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
187
+ g.manual_seed(int(seed))
188
+
189
+ if input_image is not None:
190
+ assert isinstance(input_image, list) and isinstance(
191
+ input_mask, list)
192
+ if task is None:
193
+ task = [''] * len(input_image)
194
+ if not isinstance(prompt, list):
195
+ prompt = [prompt] * len(input_image)
196
+ if history_io is not None and len(history_io) > 0:
197
+ his_image, his_masks, his_prompt, his_task = history_io[
198
+ 'image'], history_io['mask'], history_io[
199
+ 'prompt'], history_io['task']
200
+ assert len(his_image) == len(his_masks) == len(
201
+ his_prompt) == len(his_task)
202
+ input_image = his_image + input_image
203
+ input_mask = his_masks + input_mask
204
+ task = his_task + task
205
+ prompt = his_prompt + [prompt[-1]]
206
+ prompt = [
207
+ pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
208
+ for i, pp in enumerate(prompt)
209
+ ]
210
+
211
+ edit_image, edit_image_mask = process_edit_image(
212
+ input_image, input_mask, task, max_seq_len=self.max_seq_len)
213
+
214
+ image, image_mask = edit_image[tar_index], edit_image_mask[
215
+ tar_index]
216
+ edit_image, edit_image_mask = [edit_image], [edit_image_mask]
217
+
218
+ else:
219
+ edit_image = edit_image_mask = [[]]
220
+ image = torch.zeros(
221
+ size=[3, int(output_height),
222
+ int(output_width)])
223
+ image_mask = torch.ones(
224
+ size=[1, int(output_height),
225
+ int(output_width)])
226
+ if not isinstance(prompt, list):
227
+ prompt = [prompt]
228
+
229
+ image, image_mask, prompt = [image], [image_mask], [prompt]
230
+ assert check_list_of_list(prompt) and check_list_of_list(
231
+ edit_image) and check_list_of_list(edit_image_mask)
232
+ # Assign Negative Prompt
233
+ if isinstance(negative_prompt, list):
234
+ negative_prompt = negative_prompt[0]
235
+ assert isinstance(negative_prompt, str)
236
+
237
+ n_prompt = copy.deepcopy(prompt)
238
+ for nn_p_id, nn_p in enumerate(n_prompt):
239
+ assert isinstance(nn_p, list)
240
+ n_prompt[nn_p_id][-1] = negative_prompt
241
+
242
+ ctx, null_ctx = {}, {}
243
+
244
+ # Get Noise Shape
245
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
246
+ image = to_device(image)
247
+ x = self.encode_first_stage(image)
248
+ self.dynamic_unload(self.first_stage_model,
249
+ 'first_stage_model',
250
+ skip_loaded=True)
251
+ noise = [
252
+ torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
253
+ for i in x
254
+ ]
255
+ noise, x_shapes = pack_imagelist_into_tensor(noise)
256
+ ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes
257
+
258
+ image_mask = to_device(image_mask, strict=False)
259
+ cond_mask = [self.interpolate_func(i) for i in image_mask
260
+ ] if image_mask is not None else [None] * len(image)
261
+ ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
262
+
263
+ # Encode Prompt
264
+ self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
265
+ function_name, dtype = self.get_function_info(self.cond_stage_model)
266
+ cont, cont_mask = getattr(get_model(self.cond_stage_model),
267
+ function_name)(prompt)
268
+ cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
269
+ cont_mask)
270
+ null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model),
271
+ function_name)(n_prompt)
272
+ null_cont, null_cont_mask = self.cond_stage_embeddings(
273
+ prompt, edit_image, null_cont, null_cont_mask)
274
+ self.dynamic_unload(self.cond_stage_model,
275
+ 'cond_stage_model',
276
+ skip_loaded=False)
277
+ ctx['crossattn'] = cont
278
+ null_ctx['crossattn'] = null_cont
279
+
280
+ # Encode Edit Images
281
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
282
+ edit_image = [to_device(i, strict=False) for i in edit_image]
283
+ edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
284
+ e_img, e_mask = [], []
285
+ for u, m in zip(edit_image, edit_image_mask):
286
+ if u is None:
287
+ continue
288
+ if m is None:
289
+ m = [None] * len(u)
290
+ e_img.append(self.encode_first_stage(u, **kwargs))
291
+ e_mask.append([self.interpolate_func(i) for i in m])
292
+ self.dynamic_unload(self.first_stage_model,
293
+ 'first_stage_model',
294
+ skip_loaded=True)
295
+ null_ctx['edit'] = ctx['edit'] = e_img
296
+ null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
297
+
298
+ # Diffusion Process
299
+ self.dynamic_load(self.diffusion_model, 'diffusion_model')
300
+ function_name, dtype = self.get_function_info(self.diffusion_model)
301
+ with torch.autocast('cuda',
302
+ enabled=dtype in ('float16', 'bfloat16'),
303
+ dtype=getattr(torch, dtype)):
304
+ latent = self.diffusion.sample(
305
+ noise=noise,
306
+ sampler=sampler,
307
+ model=get_model(self.diffusion_model),
308
+ model_kwargs=[{
309
+ 'cond':
310
+ ctx,
311
+ 'mask':
312
+ cont_mask,
313
+ 'text_position_embeddings':
314
+ self.text_position_embeddings.pos if hasattr(
315
+ self.text_position_embeddings, 'pos') else None
316
+ }, {
317
+ 'cond':
318
+ null_ctx,
319
+ 'mask':
320
+ null_cont_mask,
321
+ 'text_position_embeddings':
322
+ self.text_position_embeddings.pos if hasattr(
323
+ self.text_position_embeddings, 'pos') else None
324
+ }] if guide_scale is not None and guide_scale > 1 else {
325
+ 'cond':
326
+ null_ctx,
327
+ 'mask':
328
+ cont_mask,
329
+ 'text_position_embeddings':
330
+ self.text_position_embeddings.pos if hasattr(
331
+ self.text_position_embeddings, 'pos') else None
332
+ },
333
+ steps=sample_steps,
334
+ show_progress=True,
335
+ seed=seed,
336
+ guide_scale=guide_scale,
337
+ guide_rescale=guide_rescale,
338
+ return_intermediate=None,
339
+ **kwargs)
340
+ self.dynamic_unload(self.diffusion_model,
341
+ 'diffusion_model',
342
+ skip_loaded=False)
343
+
344
+ # Decode to Pixel Space
345
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
346
+ samples = unpack_tensor_into_imagelist(latent, x_shapes)
347
+ x_samples = self.decode_first_stage(samples)
348
+ self.dynamic_unload(self.first_stage_model,
349
+ 'first_stage_model',
350
+ skip_loaded=False)
351
+
352
+ imgs = [
353
+ torch.clamp((x_i + 1.0) / 2.0 + self.decoder_bias / 255,
354
+ min=0.0,
355
+ max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
356
+ for x_i in x_samples
357
+ ]
358
+ imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
359
+ return imgs
360
+
361
+ def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
362
+ if self.use_text_pos_embeddings and not torch.sum(
363
+ self.text_position_embeddings.pos) > 0:
364
+ identifier_cont, _ = getattr(get_model(self.cond_stage_model),
365
+ 'encode')(self.text_identifiers,
366
+ return_mask=True)
367
+ self.text_position_embeddings.load_state_dict(
368
+ {'pos': identifier_cont[:, 0, :]})
369
+
370
+ cont_, cont_mask_ = [], []
371
+ for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
372
+ if isinstance(pp, list):
373
+ cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
374
+ cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
375
+ else:
376
+ raise NotImplementedError
377
+
378
+ return cont_, cont_mask_
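
A rough usage sketch of the ACEInference wrapper above (hedged: the config file path, its keys, and the example image are placeholders, and scepter's Config loader may take different arguments in your version):

from PIL import Image
from scepter.modules.utils.config import Config  # assumed import path for scepter's Config

from infer import ACEInference

cfg = Config(load=True, cfg_file='config/ace_demo.yaml')  # hypothetical config file
ace = ACEInference()
ace.init_from_cfg(cfg)

# image/mask are lists because the model can condition on several frames at once.
outputs = ace(image=[Image.open('input.png').convert('RGB')],
              mask=[None],
              prompt='{image} change the sky to a golden sunset',
              task=[''],
              sampler='ddim',
              sample_steps=20,
              guide_scale=4.5,
              seed=2024)
outputs[0].save('edited.png')
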
modules/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import data, model, solver
modules/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import dataset
modules/data/dataset/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .dataset import ACEDemoDataset
modules/data/dataset/dataset.py ADDED
@@ -0,0 +1,252 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import io
5
+ import math
6
+ import os
7
+ import sys
8
+ from collections import defaultdict
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torchvision.transforms as T
13
+ from PIL import Image
14
+ from torchvision.transforms.functional import InterpolationMode
15
+
16
+ from scepter.modules.data.dataset.base_dataset import BaseDataset
17
+ from scepter.modules.data.dataset.registry import DATASETS
18
+ from scepter.modules.transform.io import pillow_convert
19
+ from scepter.modules.utils.config import dict_to_yaml
20
+ from scepter.modules.utils.file_system import FS
21
+
22
+ Image.MAX_IMAGE_PIXELS = None
23
+
24
+ @DATASETS.register_class()
25
+ class ACEDemoDataset(BaseDataset):
26
+ para_dict = {
27
+ 'MS_DATASET_NAME': {
28
+ 'value': '',
29
+ 'description': 'Modelscope dataset name.'
30
+ },
31
+ 'MS_DATASET_NAMESPACE': {
32
+ 'value': '',
33
+ 'description': 'Modelscope dataset namespace.'
34
+ },
35
+ 'MS_DATASET_SUBNAME': {
36
+ 'value': '',
37
+ 'description': 'Modelscope dataset subname.'
38
+ },
39
+ 'MS_DATASET_SPLIT': {
40
+ 'value': '',
41
+ 'description':
42
+ 'Modelscope dataset split set name, default is train.'
43
+ },
44
+ 'MS_REMAP_KEYS': {
45
+ 'value':
46
+ None,
47
+ 'description':
48
+ 'Modelscope dataset header of list file, the default is Target:FILE; '
49
+ 'If your file does not use this header, please set this field, which is a mapping dict. '
50
+ "For example, { 'Image:FILE': 'Target:FILE' } will replace the filed Image:FILE to Target:FILE"
51
+ },
52
+ 'MS_REMAP_PATH': {
53
+ 'value':
54
+ None,
55
+ 'description':
56
+ 'When a modelscope dataset name is set, the dataset is loaded from modelscope;'
57
+ ' default is None. If you want to use the data list from modelscope but load the files from '
58
+ 'the local device, use this field to set the root path of your images. '
59
+ },
60
+ 'TRIGGER_WORDS': {
61
+ 'value':
62
+ '',
63
+ 'description':
64
+ 'The words used to describe the common features of your data, especially when you customize a '
65
+ 'tuner. Using these words in your prompt helps you get what you want.'
66
+ },
67
+ 'HIGHLIGHT_KEYWORDS': {
68
+ 'value':
69
+ '',
70
+ 'description':
71
+ 'The keywords you want to highlight in the prompt, which will be replaced by <HIGHLIGHT_KEYWORDS>.'
72
+ },
73
+ 'KEYWORDS_SIGN': {
74
+ 'value':
75
+ '',
76
+ 'description':
77
+ 'The keywords sign you want to add, which is like <{HIGHLIGHT_KEYWORDS}{KEYWORDS_SIGN}>'
78
+ },
79
+ }
80
+
81
+ def __init__(self, cfg, logger=None):
82
+ super().__init__(cfg=cfg, logger=logger)
83
+ from modelscope import MsDataset
84
+ from modelscope.utils.constant import DownloadMode
85
+ ms_dataset_name = cfg.get('MS_DATASET_NAME', None)
86
+ ms_dataset_namespace = cfg.get('MS_DATASET_NAMESPACE', None)
87
+ ms_dataset_subname = cfg.get('MS_DATASET_SUBNAME', None)
88
+ ms_dataset_split = cfg.get('MS_DATASET_SPLIT', 'train')
89
+ ms_remap_keys = cfg.get('MS_REMAP_KEYS', None)
90
+ ms_remap_path = cfg.get('MS_REMAP_PATH', None)
91
+
92
+ self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
93
+ self.max_aspect_ratio = cfg.get('MAX_ASPECT_RATIO', 4)
94
+ self.d = cfg.get('DOWNSAMPLE_RATIO', 16)
95
+ self.replace_style = cfg.get('REPLACE_STYLE', False)
96
+ self.trigger_words = cfg.get('TRIGGER_WORDS', '')
97
+ self.replace_keywords = cfg.get('HIGHLIGHT_KEYWORDS', '')
98
+ self.keywords_sign = cfg.get('KEYWORDS_SIGN', '')
99
+ self.add_indicator = cfg.get('ADD_INDICATOR', False)
100
+ # Use modelscope dataset
101
+ if not ms_dataset_name:
102
+ raise ValueError(
103
+ 'You must set MS_DATASET_NAME to a modelscope dataset or a local dataset organized '
104
+ 'as a modelscope dataset.')
105
+ if FS.exists(ms_dataset_name):
106
+ ms_dataset_name = FS.get_dir_to_local_dir(ms_dataset_name)
107
+ self.ms_dataset_name = ms_dataset_name
108
+ # ms_remap_path = ms_dataset_name
109
+ try:
110
+ self.data = MsDataset.load(str(ms_dataset_name),
111
+ namespace=ms_dataset_namespace,
112
+ subset_name=ms_dataset_subname,
113
+ split=ms_dataset_split)
114
+ except Exception:
115
+ self.logger.info(
116
+ "Load Modelscope dataset failed, retry with download_mode='force_redownload'."
117
+ )
118
+ try:
119
+ self.data = MsDataset.load(
120
+ str(ms_dataset_name),
121
+ namespace=ms_dataset_namespace,
122
+ subset_name=ms_dataset_subname,
123
+ split=ms_dataset_split,
124
+ download_mode=DownloadMode.FORCE_REDOWNLOAD)
125
+ except Exception as sec_e:
126
+ raise ValueError(f'Load Modelscope dataset failed {sec_e}.')
127
+ if ms_remap_keys:
128
+ self.data = self.data.remap_columns(ms_remap_keys.get_dict())
129
+
130
+ if ms_remap_path:
131
+
132
+ def map_func(example):
133
+ return {
134
+ k: os.path.join(ms_remap_path, v)
135
+ if k.endswith(':FILE') else v
136
+ for k, v in example.items()
137
+ }
138
+
139
+ self.data = self.data.ds_instance.map(map_func)
140
+
141
+ self.transforms = T.Compose([
142
+ T.ToTensor(),
143
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
144
+ ])
145
+
146
+ def __len__(self):
147
+ if self.mode == 'train':
148
+ return sys.maxsize
149
+ else:
150
+ return len(self.data)
151
+
152
+ def _get(self, index: int):
153
+ current_data = self.data[index % len(self.data)]
154
+
155
+ tar_image_path = current_data.get('Target:FILE', '')
156
+ src_image_path = current_data.get('Source:FILE', '')
157
+
158
+ style = current_data.get('Style', '')
159
+ prompt = current_data.get('Prompt', current_data.get('prompt', ''))
160
+ if self.replace_style and not style == '':
161
+ prompt = prompt.replace(style, f'<{self.keywords_sign}>')
162
+
163
+ elif not self.replace_keywords.strip() == '':
164
+ prompt = prompt.replace(
165
+ self.replace_keywords,
166
+ '<' + self.replace_keywords + f'{self.keywords_sign}>')
167
+
168
+ if not self.trigger_words == '':
169
+ prompt = self.trigger_words.strip() + ' ' + prompt
170
+
171
+ src_image = self.load_image(self.ms_dataset_name,
172
+ src_image_path,
173
+ cvt_type='RGB')
174
+ tar_image = self.load_image(self.ms_dataset_name,
175
+ tar_image_path,
176
+ cvt_type='RGB')
177
+ src_image = self.image_preprocess(src_image)
178
+ tar_image = self.image_preprocess(tar_image)
179
+
180
+ tar_image = self.transforms(tar_image)
181
+ src_image = self.transforms(src_image)
182
+ src_mask = torch.ones_like(src_image[[0]])
183
+ tar_mask = torch.ones_like(tar_image[[0]])
184
+ if self.add_indicator:
185
+ if '{image}' not in prompt:
186
+ prompt = '{image}, ' + prompt
187
+
188
+ return {
189
+ 'edit_image': [src_image],
190
+ 'edit_image_mask': [src_mask],
191
+ 'image': tar_image,
192
+ 'image_mask': tar_mask,
193
+ 'prompt': [prompt],
194
+ }
195
+
196
+ def load_image(self, prefix, img_path, cvt_type=None):
197
+ if img_path is None or img_path == '':
198
+ return None
199
+ img_path = os.path.join(prefix, img_path)
200
+ with FS.get_object(img_path) as image_bytes:
201
+ image = Image.open(io.BytesIO(image_bytes))
202
+ if cvt_type is not None:
203
+ image = pillow_convert(image, cvt_type)
204
+ return image
205
+
206
+ def image_preprocess(self,
207
+ img,
208
+ size=None,
209
+ interpolation=InterpolationMode.BILINEAR):
210
+ H, W = img.height, img.width
211
+ if H / W > self.max_aspect_ratio:
212
+ img = T.CenterCrop((self.max_aspect_ratio * W, W))(img)
213
+ elif W / H > self.max_aspect_ratio:
214
+ img = T.CenterCrop((H, self.max_aspect_ratio * H))(img)
215
+
216
+ if size is None:
217
+ # resize image for max_seq_len, while keep the aspect ratio
218
+ H, W = img.height, img.width
219
+ scale = min(
220
+ 1.0,
221
+ math.sqrt(self.max_seq_len / ((H / self.d) * (W / self.d))))
222
+ rH = int(
223
+ H * scale) // self.d * self.d # ensure divisible by self.d
224
+ rW = int(W * scale) // self.d * self.d
225
+ else:
226
+ rH, rW = size
227
+ img = T.Resize((rH, rW), interpolation=interpolation,
228
+ antialias=True)(img)
229
+ return np.array(img, dtype=np.uint8)
230
+
231
+ @staticmethod
232
+ def get_config_template():
233
+ return dict_to_yaml('DATASET',
234
+ __class__.__name__,
235
+ ACEDemoDataset.para_dict,
236
+ set_name=True)
237
+
238
+ @staticmethod
239
+ def collate_fn(batch):
240
+ collect = defaultdict(list)
241
+ for sample in batch:
242
+ for k, v in sample.items():
243
+ collect[k].append(v)
244
+
245
+ new_batch = dict()
246
+ for k, v in collect.items():
247
+ if all([i is None for i in v]):
248
+ new_batch[k] = None
249
+ else:
250
+ new_batch[k] = v
251
+
252
+ return new_batch
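
ACEDemoDataset.collate_fn above does not stack tensors; it only regroups per-sample dicts into per-key lists and collapses keys whose values are all None. A standalone illustration of that behaviour with toy values (not real samples):

from collections import defaultdict

def collate_like_ace(batch):
    # Mirrors ACEDemoDataset.collate_fn: group values by key, keep variable-sized
    # entries as Python lists, and collapse all-None keys to None.
    collect = defaultdict(list)
    for sample in batch:
        for k, v in sample.items():
            collect[k].append(v)
    return {k: (None if all(i is None for i in v) else v) for k, v in collect.items()}

batch = [{'prompt': ['a cat'], 'image_mask': None},
         {'prompt': ['a dog'], 'image_mask': None}]
print(collate_like_ace(batch))
# {'prompt': [['a cat'], ['a dog']], 'image_mask': None}
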
modules/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import backbone, embedder, diffusion, network
modules/model/backbone/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from .ace import DiTACE
modules/model/backbone/ace.py ADDED
@@ -0,0 +1,373 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import re
4
+ from collections import OrderedDict
5
+ from functools import partial
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange
10
+ from torch.nn.utils.rnn import pad_sequence
11
+ from torch.utils.checkpoint import checkpoint_sequential
12
+
13
+ from scepter.modules.model.base_model import BaseModel
14
+ from scepter.modules.model.registry import BACKBONES
15
+ from scepter.modules.utils.config import dict_to_yaml
16
+ from scepter.modules.utils.file_system import FS
17
+
18
+ from .layers import (
19
+ Mlp,
20
+ TimestepEmbedder,
21
+ PatchEmbed,
22
+ DiTACEBlock,
23
+ T2IFinalLayer
24
+ )
25
+ from .pos_embed import rope_params
26
+
27
+
28
+ @BACKBONES.register_class()
29
+ class DiTACE(BaseModel):
30
+
31
+ para_dict = {
32
+ 'PATCH_SIZE': {
33
+ 'value': 2,
34
+ 'description': ''
35
+ },
36
+ 'IN_CHANNELS': {
37
+ 'value': 4,
38
+ 'description': ''
39
+ },
40
+ 'HIDDEN_SIZE': {
41
+ 'value': 1152,
42
+ 'description': ''
43
+ },
44
+ 'DEPTH': {
45
+ 'value': 28,
46
+ 'description': ''
47
+ },
48
+ 'NUM_HEADS': {
49
+ 'value': 16,
50
+ 'description': ''
51
+ },
52
+ 'MLP_RATIO': {
53
+ 'value': 4.0,
54
+ 'description': ''
55
+ },
56
+ 'PRED_SIGMA': {
57
+ 'value': True,
58
+ 'description': ''
59
+ },
60
+ 'DROP_PATH': {
61
+ 'value': 0.,
62
+ 'description': ''
63
+ },
64
+ 'WINDOW_SIZE': {
65
+ 'value': 0,
66
+ 'description': ''
67
+ },
68
+ 'WINDOW_BLOCK_INDEXES': {
69
+ 'value': None,
70
+ 'description': ''
71
+ },
72
+ 'Y_CHANNELS': {
73
+ 'value': 4096,
74
+ 'description': ''
75
+ },
76
+ 'ATTENTION_BACKEND': {
77
+ 'value': None,
78
+ 'description': ''
79
+ },
80
+ 'QK_NORM': {
81
+ 'value': True,
82
+ 'description': 'Whether to use RMSNorm for query and key.',
83
+ },
84
+ }
85
+ para_dict.update(BaseModel.para_dict)
86
+
87
+ def __init__(self, cfg, logger):
88
+ super().__init__(cfg, logger=logger)
89
+ self.window_block_indexes = cfg.get('WINDOW_BLOCK_INDEXES', None)
90
+ if self.window_block_indexes is None:
91
+ self.window_block_indexes = []
92
+ self.pred_sigma = cfg.get('PRED_SIGMA', True)
93
+ self.in_channels = cfg.get('IN_CHANNELS', 4)
94
+ self.out_channels = self.in_channels * 2 if self.pred_sigma else self.in_channels
95
+ self.patch_size = cfg.get('PATCH_SIZE', 2)
96
+ self.num_heads = cfg.get('NUM_HEADS', 16)
97
+ self.hidden_size = cfg.get('HIDDEN_SIZE', 1152)
98
+ self.y_channels = cfg.get('Y_CHANNELS', 4096)
99
+ self.drop_path = cfg.get('DROP_PATH', 0.)
100
+ self.depth = cfg.get('DEPTH', 28)
101
+ self.mlp_ratio = cfg.get('MLP_RATIO', 4.0)
102
+ self.use_grad_checkpoint = cfg.get('USE_GRAD_CHECKPOINT', False)
103
+ self.attention_backend = cfg.get('ATTENTION_BACKEND', None)
+ self.window_size = cfg.get('WINDOW_SIZE', 0)  # read WINDOW_SIZE so blocks listed in WINDOW_BLOCK_INDEXES do not hit an undefined attribute
104
+ self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
105
+ self.qk_norm = cfg.get('QK_NORM', False)
106
+ self.ignore_keys = cfg.get('IGNORE_KEYS', [])
107
+ assert (self.hidden_size % self.num_heads
108
+ ) == 0 and (self.hidden_size // self.num_heads) % 2 == 0
109
+ d = self.hidden_size // self.num_heads
110
+ self.freqs = torch.cat(
111
+ [
112
+ rope_params(self.max_seq_len, d - 4 * (d // 6)), # T (~1/3)
113
+ rope_params(self.max_seq_len, 2 * (d // 6)), # H (~1/3)
114
+ rope_params(self.max_seq_len, 2 * (d // 6)) # W (~1/3)
115
+ ],
116
+ dim=1)
117
+
118
+ # init embedder
119
+ self.x_embedder = PatchEmbed(self.patch_size,
120
+ self.in_channels + 1,
121
+ self.hidden_size,
122
+ bias=True,
123
+ flatten=False)
124
+ self.t_embedder = TimestepEmbedder(self.hidden_size)
125
+ self.y_embedder = Mlp(in_features=self.y_channels,
126
+ hidden_features=self.hidden_size,
127
+ out_features=self.hidden_size,
128
+ act_layer=lambda: nn.GELU(approximate='tanh'),
129
+ drop=0)
130
+ self.t_block = nn.Sequential(
131
+ nn.SiLU(),
132
+ nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True))
133
+ # init blocks
134
+ drop_path = [
135
+ x.item() for x in torch.linspace(0, self.drop_path, self.depth)
136
+ ]
137
+ self.blocks = nn.ModuleList([
138
+ DiTACEBlock(self.hidden_size,
139
+ self.num_heads,
140
+ mlp_ratio=self.mlp_ratio,
141
+ drop_path=drop_path[i],
142
+ window_size=self.window_size
143
+ if i in self.window_block_indexes else 0,
144
+ backend=self.attention_backend,
145
+ use_condition=True,
146
+ qk_norm=self.qk_norm) for i in range(self.depth)
147
+ ])
148
+ self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size,
149
+ self.out_channels)
150
+ self.initialize_weights()
151
+
152
+ def load_pretrained_model(self, pretrained_model):
153
+ if pretrained_model:
154
+ with FS.get_from(pretrained_model, wait_finish=True) as local_path:
155
+ model = torch.load(local_path, map_location='cpu')
156
+ if 'state_dict' in model:
157
+ model = model['state_dict']
158
+ new_ckpt = OrderedDict()
159
+ for k, v in model.items():
160
+ if self.ignore_keys is not None:
161
+ if (isinstance(self.ignore_keys, str) and re.match(self.ignore_keys, k)) or \
162
+ (isinstance(self.ignore_keys, list) and k in self.ignore_keys):
163
+ continue
164
+ k = k.replace('.cross_attn.q_linear.', '.cross_attn.q.')
165
+ k = k.replace('.cross_attn.proj.',
166
+ '.cross_attn.o.').replace(
167
+ '.attn.proj.', '.attn.o.')
168
+ if '.cross_attn.kv_linear.' in k:
169
+ k_p, v_p = torch.split(v, v.shape[0] // 2)
170
+ new_ckpt[k.replace('.cross_attn.kv_linear.',
171
+ '.cross_attn.k.')] = k_p
172
+ new_ckpt[k.replace('.cross_attn.kv_linear.',
173
+ '.cross_attn.v.')] = v_p
174
+ elif '.attn.qkv.' in k:
175
+ q_p, k_p, v_p = torch.split(v, v.shape[0] // 3)
176
+ new_ckpt[k.replace('.attn.qkv.', '.attn.q.')] = q_p
177
+ new_ckpt[k.replace('.attn.qkv.', '.attn.k.')] = k_p
178
+ new_ckpt[k.replace('.attn.qkv.', '.attn.v.')] = v_p
179
+ elif 'y_embedder.y_proj.' in k:
180
+ new_ckpt[k.replace('y_embedder.y_proj.',
181
+ 'y_embedder.')] = v
182
+ elif k in ('x_embedder.proj.weight'):
183
+ model_p = self.state_dict()[k]
184
+ if v.shape != model_p.shape:
185
+ model_p.zero_()
186
+ model_p[:, :4, :, :].copy_(v)
187
+ new_ckpt[k] = torch.nn.parameter.Parameter(model_p)
188
+ else:
189
+ new_ckpt[k] = v
190
+ elif k in ('x_embedder.proj.bias'):
191
+ new_ckpt[k] = v
192
+ else:
193
+ new_ckpt[k] = v
194
+ missing, unexpected = self.load_state_dict(new_ckpt,
195
+ strict=False)
196
+ print(
197
+ f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
198
+ )
199
+ if len(missing) > 0:
200
+ print(f'Missing Keys:\n {missing}')
201
+ if len(unexpected) > 0:
202
+ print(f'\nUnexpected Keys:\n {unexpected}')
203
+
204
+ def forward(self,
205
+ x,
206
+ t=None,
207
+ cond=dict(),
208
+ mask=None,
209
+ text_position_embeddings=None,
210
+ gc_seg=-1,
211
+ **kwargs):
212
+ if self.freqs.device != x.device:
213
+ self.freqs = self.freqs.to(x.device)
214
+ if isinstance(cond, dict):
215
+ context = cond.get('crossattn', None)
216
+ else:
217
+ context = cond
218
+ if text_position_embeddings is not None:
219
+ # by default, use the text_position_embeddings stored in the state_dict;
220
+ # if the state_dict does not include this key, fall back to the arg: text_position_embeddings
221
+ proj_position_embeddings = self.y_embedder(
222
+ text_position_embeddings)
223
+ else:
224
+ proj_position_embeddings = None
225
+
226
+ ctx_batch, txt_lens = [], []
227
+ if mask is not None and isinstance(mask, list):
228
+ for ctx, ctx_mask in zip(context, mask):
229
+ for frame_id, one_ctx in enumerate(zip(ctx, ctx_mask)):
230
+ u, m = one_ctx
231
+ t_len = m.flatten().sum() # l
232
+ u = u[:t_len]
233
+ u = self.y_embedder(u)
234
+ if frame_id == 0:
235
+ u = u + proj_position_embeddings[
236
+ len(ctx) -
237
+ 1] if proj_position_embeddings is not None else u
238
+ else:
239
+ u = u + proj_position_embeddings[
240
+ frame_id -
241
+ 1] if proj_position_embeddings is not None else u
242
+ ctx_batch.append(u)
243
+ txt_lens.append(t_len)
244
+ else:
245
+ raise TypeError
246
+ y = torch.cat(ctx_batch, dim=0)
247
+ txt_lens = torch.LongTensor(txt_lens).to(x.device, non_blocking=True)
248
+
249
+ batch_frames = []
250
+ for u, shape, m in zip(x, cond['x_shapes'], cond['x_mask']):
251
+ u = u[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
252
+ m = torch.ones_like(u[[0], :, :]) if m is None else m.squeeze(0)
253
+ batch_frames.append([torch.cat([u, m], dim=0).unsqueeze(0)])
254
+ if 'edit' in cond:
255
+ for i, (edit, edit_mask) in enumerate(
256
+ zip(cond['edit'], cond['edit_mask'])):
257
+ if edit is None:
258
+ continue
259
+ for u, m in zip(edit, edit_mask):
260
+ u = u.squeeze(0)
261
+ m = torch.ones_like(
262
+ u[[0], :, :]) if m is None else m.squeeze(0)
263
+ batch_frames[i].append(
264
+ torch.cat([u, m], dim=0).unsqueeze(0))
265
+
266
+ patch_batch, shape_batch, self_x_len, cross_x_len = [], [], [], []
267
+ for frames in batch_frames:
268
+ patches, patch_shapes = [], []
269
+ self_x_len.append(0)
270
+ for frame_id, u in enumerate(frames):
271
+ u = self.x_embedder(u)
272
+ h, w = u.size(2), u.size(3)
273
+ u = rearrange(u, '1 c h w -> (h w) c')
274
+ if frame_id == 0:
275
+ u = u + proj_position_embeddings[
276
+ len(frames) -
277
+ 1] if proj_position_embeddings is not None else u
278
+ else:
279
+ u = u + proj_position_embeddings[
280
+ frame_id -
281
+ 1] if proj_position_embeddings is not None else u
282
+ patches.append(u)
283
+ patch_shapes.append([h, w])
284
+ cross_x_len.append(h * w) # b*s, 1
285
+ self_x_len[-1] += h * w # b, 1
286
+ # u = torch.cat(patches, dim=0)
287
+ patch_batch.extend(patches)
288
+ shape_batch.append(
289
+ torch.LongTensor(patch_shapes).to(x.device, non_blocking=True))
290
+ # repeat t to align with x
291
+ t = torch.cat([t[i].repeat(l) for i, l in enumerate(self_x_len)])
292
+ self_x_len, cross_x_len = (torch.LongTensor(self_x_len).to(
293
+ x.device, non_blocking=True), torch.LongTensor(cross_x_len).to(
294
+ x.device, non_blocking=True))
295
+ # x = pad_sequence(tuple(patch_batch), batch_first=True) # b, s*max(cl), c
296
+ x = torch.cat(patch_batch, dim=0)
297
+ x_shapes = pad_sequence(tuple(shape_batch),
298
+ batch_first=True) # b, max(len(frames)), 2
299
+ t = self.t_embedder(t) # (N, D)
300
+ t0 = self.t_block(t)
301
+ # y = self.y_embedder(context)
302
+
303
+ kwargs = dict(y=y,
304
+ t=t0,
305
+ x_shapes=x_shapes,
306
+ self_x_len=self_x_len,
307
+ cross_x_len=cross_x_len,
308
+ freqs=self.freqs,
309
+ txt_lens=txt_lens)
310
+ if self.use_grad_checkpoint and gc_seg >= 0:
311
+ x = checkpoint_sequential(
312
+ functions=[partial(block, **kwargs) for block in self.blocks],
313
+ segments=gc_seg if gc_seg > 0 else len(self.blocks),
314
+ input=x,
315
+ use_reentrant=False)
316
+ else:
317
+ for block in self.blocks:
318
+ x = block(x, **kwargs)
319
+ x = self.final_layer(x, t) # b*s*n, d
320
+ outs, cur_length = [], 0
321
+ p = self.patch_size
322
+ for seq_length, shape in zip(self_x_len, shape_batch):
323
+ x_i = x[cur_length:cur_length + seq_length]
324
+ h, w = shape[0].tolist()
325
+ u = x_i[:h * w].view(h, w, p, p, -1)
326
+ u = rearrange(u, 'h w p q c -> (h p w q) c'
327
+ ) # dump into sequence for following tensor ops
328
+ cur_length = cur_length + seq_length
329
+ outs.append(u)
330
+ x = pad_sequence(tuple(outs), batch_first=True).permute(0, 2, 1)
331
+ if self.pred_sigma:
332
+ return x.chunk(2, dim=1)[0]
333
+ else:
334
+ return x
335
+
336
+ def initialize_weights(self):
337
+ # Initialize transformer layers:
338
+ def _basic_init(module):
339
+ if isinstance(module, nn.Linear):
340
+ torch.nn.init.xavier_uniform_(module.weight)
341
+ if module.bias is not None:
342
+ nn.init.constant_(module.bias, 0)
343
+
344
+ self.apply(_basic_init)
345
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
346
+ w = self.x_embedder.proj.weight.data
347
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
348
+ # Initialize timestep embedding MLP:
349
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
350
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
351
+ nn.init.normal_(self.t_block[1].weight, std=0.02)
352
+ # Initialize caption embedding MLP:
353
+ if hasattr(self, 'y_embedder'):
354
+ nn.init.normal_(self.y_embedder.fc1.weight, std=0.02)
355
+ nn.init.normal_(self.y_embedder.fc2.weight, std=0.02)
356
+ # Zero-out adaLN modulation layers
357
+ for block in self.blocks:
358
+ nn.init.constant_(block.cross_attn.o.weight, 0)
359
+ nn.init.constant_(block.cross_attn.o.bias, 0)
360
+ # Zero-out output layers:
361
+ nn.init.constant_(self.final_layer.linear.weight, 0)
362
+ nn.init.constant_(self.final_layer.linear.bias, 0)
363
+
364
+ @property
365
+ def dtype(self):
366
+ return next(self.parameters()).dtype
367
+
368
+ @staticmethod
369
+ def get_config_template():
370
+ return dict_to_yaml('BACKBONE',
371
+ __class__.__name__,
372
+ DiTACE.para_dict,
373
+ set_name=True)
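
With the defaults above (HIDDEN_SIZE=1152, NUM_HEADS=16), the per-head dimension is 72 and the rotary frequencies are split roughly evenly across the frame, height and width axes; a quick check of that arithmetic:

hidden_size, num_heads = 1152, 16
d = hidden_size // num_heads        # 72; must be even per the assert in __init__
t_dim = d - 4 * (d // 6)            # 24 rotary dims for the frame/temporal axis
h_dim = 2 * (d // 6)                # 24 rotary dims for height
w_dim = 2 * (d // 6)                # 24 rotary dims for width
assert t_dim + h_dim + w_dim == d
print(d, t_dim, h_dim, w_dim)       # 72 24 24 24
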
modules/model/backbone/layers.py ADDED
@@ -0,0 +1,386 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import math
4
+ import warnings
5
+ import torch
6
+ import torch.nn as nn
7
+ from .pos_embed import rope_apply_multires as rope_apply
8
+
9
+ try:
10
+ from flash_attn import (flash_attn_varlen_func)
11
+ FLASHATTN_IS_AVAILABLE = True
12
+ except ImportError as e:
13
+ FLASHATTN_IS_AVAILABLE = False
14
+ flash_attn_varlen_func = None
15
+ warnings.warn(f'{e}')
16
+
17
+ __all__ = [
18
+ "drop_path",
19
+ "modulate",
20
+ "PatchEmbed",
21
+ "DropPath",
22
+ "RMSNorm",
23
+ "Mlp",
24
+ "TimestepEmbedder",
25
+ "DiTEditBlock",
26
+ "MultiHeadAttentionDiTEdit",
27
+ "T2IFinalLayer",
28
+ ]
29
+
30
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
31
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
32
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
33
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
34
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
35
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
36
+ 'survival rate' as the argument.
37
+ """
38
+ if drop_prob == 0. or not training:
39
+ return x
40
+ keep_prob = 1 - drop_prob
41
+ shape = (x.shape[0], ) + (1, ) * (
42
+ x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
43
+ random_tensor = keep_prob + torch.rand(
44
+ shape, dtype=x.dtype, device=x.device)
45
+ random_tensor.floor_() # binarize
46
+ output = x.div(keep_prob) * random_tensor
47
+ return output
48
+
49
+
50
+ def modulate(x, shift, scale, unsqueeze=False):
51
+ if unsqueeze:
52
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
53
+ else:
54
+ return x * (1 + scale) + shift
55
+
56
+
57
+ class PatchEmbed(nn.Module):
58
+ """ 2D Image to Patch Embedding
59
+ """
60
+ def __init__(
61
+ self,
62
+ patch_size=16,
63
+ in_chans=3,
64
+ embed_dim=768,
65
+ norm_layer=None,
66
+ flatten=True,
67
+ bias=True,
68
+ ):
69
+ super().__init__()
70
+ self.flatten = flatten
71
+ self.proj = nn.Conv2d(in_chans,
72
+ embed_dim,
73
+ kernel_size=patch_size,
74
+ stride=patch_size,
75
+ bias=bias)
76
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
77
+
78
+ def forward(self, x):
79
+ x = self.proj(x)
80
+ if self.flatten:
81
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
82
+ x = self.norm(x)
83
+ return x
84
+
85
+
86
+ class DropPath(nn.Module):
87
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
88
+ """
89
+ def __init__(self, drop_prob=None):
90
+ super(DropPath, self).__init__()
91
+ self.drop_prob = drop_prob
92
+
93
+ def forward(self, x):
94
+ return drop_path(x, self.drop_prob, self.training)
95
+
96
+
97
+ class RMSNorm(nn.Module):
98
+ def __init__(self, dim, eps=1e-6):
99
+ super().__init__()
100
+ self.dim = dim
101
+ self.eps = eps
102
+ self.weight = nn.Parameter(torch.ones(dim))
103
+
104
+ def forward(self, x):
105
+ return self._norm(x.float()).type_as(x) * self.weight
106
+
107
+ def _norm(self, x):
108
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
109
+
110
+
111
+ class Mlp(nn.Module):
112
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
113
+ """
114
+ def __init__(self,
115
+ in_features,
116
+ hidden_features=None,
117
+ out_features=None,
118
+ act_layer=nn.GELU,
119
+ drop=0.):
120
+ super().__init__()
121
+ out_features = out_features or in_features
122
+ hidden_features = hidden_features or in_features
123
+ self.fc1 = nn.Linear(in_features, hidden_features)
124
+ self.act = act_layer()
125
+ self.fc2 = nn.Linear(hidden_features, out_features)
126
+ self.drop = nn.Dropout(drop)
127
+
128
+ def forward(self, x):
129
+ x = self.fc1(x)
130
+ x = self.act(x)
131
+ x = self.drop(x)
132
+ x = self.fc2(x)
133
+ x = self.drop(x)
134
+ return x
135
+
136
+
137
+ class TimestepEmbedder(nn.Module):
138
+ """
139
+ Embeds scalar timesteps into vector representations.
140
+ """
141
+ def __init__(self, hidden_size, frequency_embedding_size=256):
142
+ super().__init__()
143
+ self.mlp = nn.Sequential(
144
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
145
+ nn.SiLU(),
146
+ nn.Linear(hidden_size, hidden_size, bias=True),
147
+ )
148
+ self.frequency_embedding_size = frequency_embedding_size
149
+
150
+ @staticmethod
151
+ def timestep_embedding(t, dim, max_period=10000):
152
+ """
153
+ Create sinusoidal timestep embeddings.
154
+ :param t: a 1-D Tensor of N indices, one per batch element.
155
+ These may be fractional.
156
+ :param dim: the dimension of the output.
157
+ :param max_period: controls the minimum frequency of the embeddings.
158
+ :return: an (N, D) Tensor of positional embeddings.
159
+ """
160
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
161
+ half = dim // 2
162
+ freqs = torch.exp(
163
+ -math.log(max_period) *
164
+ torch.arange(start=0, end=half, dtype=torch.float32) /
165
+ half).to(device=t.device)
166
+ args = t[:, None].float() * freqs[None]
167
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
168
+ if dim % 2:
169
+ embedding = torch.cat(
170
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
171
+ return embedding
172
+
173
+ def forward(self, t):
174
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
175
+ t_emb = self.mlp(t_freq)
176
+ return t_emb
177
+
178
+
179
+ class DiTACEBlock(nn.Module):
180
+ def __init__(self,
181
+ hidden_size,
182
+ num_heads,
183
+ mlp_ratio=4.0,
184
+ drop_path=0.,
185
+ window_size=0,
186
+ backend=None,
187
+ use_condition=True,
188
+ qk_norm=False,
189
+ **block_kwargs):
190
+ super().__init__()
191
+ self.hidden_size = hidden_size
192
+ self.use_condition = use_condition
193
+ self.norm1 = nn.LayerNorm(hidden_size,
194
+ elementwise_affine=False,
195
+ eps=1e-6)
196
+ self.attn = MultiHeadAttention(hidden_size,
197
+ num_heads=num_heads,
198
+ qkv_bias=True,
199
+ backend=backend,
200
+ qk_norm=qk_norm,
201
+ **block_kwargs)
202
+ if self.use_condition:
203
+ self.cross_attn = MultiHeadAttention(
204
+ hidden_size,
205
+ context_dim=hidden_size,
206
+ num_heads=num_heads,
207
+ qkv_bias=True,
208
+ backend=backend,
209
+ qk_norm=qk_norm,
210
+ **block_kwargs)
211
+ self.norm2 = nn.LayerNorm(hidden_size,
212
+ elementwise_affine=False,
213
+ eps=1e-6)
214
+ # to be compatible with lower version pytorch
215
+ approx_gelu = lambda: nn.GELU(approximate='tanh')
216
+ self.mlp = Mlp(in_features=hidden_size,
217
+ hidden_features=int(hidden_size * mlp_ratio),
218
+ act_layer=approx_gelu,
219
+ drop=0)
220
+ self.drop_path = DropPath(
221
+ drop_path) if drop_path > 0. else nn.Identity()
222
+ self.window_size = window_size
223
+ self.scale_shift_table = nn.Parameter(
224
+ torch.randn(6, hidden_size) / hidden_size**0.5)
225
+
226
+ def forward(self, x, y, t, **kwargs):
227
+ B = x.size(0)
228
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
229
+ self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
230
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
231
+ shift_msa.squeeze(1), scale_msa.squeeze(1), gate_msa.squeeze(1),
232
+ shift_mlp.squeeze(1), scale_mlp.squeeze(1), gate_mlp.squeeze(1))
233
+ x = x + self.drop_path(gate_msa * self.attn(
234
+ modulate(self.norm1(x), shift_msa, scale_msa, unsqueeze=False), **
235
+ kwargs))
236
+ if self.use_condition:
237
+ x = x + self.cross_attn(x, context=y, **kwargs)
238
+
239
+ x = x + self.drop_path(gate_mlp * self.mlp(
240
+ modulate(self.norm2(x), shift_mlp, scale_mlp, unsqueeze=False)))
241
+ return x
242
+
243
+
244
+ class MultiHeadAttention(nn.Module):
245
+ def __init__(self,
246
+ dim,
247
+ context_dim=None,
248
+ num_heads=None,
249
+ head_dim=None,
250
+ attn_drop=0.0,
251
+ qkv_bias=False,
252
+ dropout=0.0,
253
+ backend=None,
254
+ qk_norm=False,
255
+ eps=1e-6,
256
+ **block_kwargs):
257
+ super().__init__()
258
+ # consider head_dim first, then num_heads
259
+ num_heads = dim // head_dim if head_dim else num_heads
260
+ head_dim = dim // num_heads
261
+ assert num_heads * head_dim == dim
262
+ context_dim = context_dim or dim
263
+ self.dim = dim
264
+ self.context_dim = context_dim
265
+ self.num_heads = num_heads
266
+ self.head_dim = head_dim
267
+ self.scale = math.pow(head_dim, -0.25)
268
+ # layers
269
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
270
+ self.k = nn.Linear(context_dim, dim, bias=qkv_bias)
271
+ self.v = nn.Linear(context_dim, dim, bias=qkv_bias)
272
+ self.o = nn.Linear(dim, dim)
273
+ self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
274
+ self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
275
+
276
+ self.dropout = nn.Dropout(dropout)
277
+ self.attention_op = None
278
+ self.attn_drop = nn.Dropout(attn_drop)
279
+ self.backend = backend
280
+ assert self.backend in ('flash_attn', 'xformer_attn', 'pytorch_attn',
281
+ None)
282
+ if FLASHATTN_IS_AVAILABLE and self.backend in ('flash_attn', None):
283
+ self.backend = 'flash_attn'
284
+ self.softmax_scale = block_kwargs.get('softmax_scale', None)
285
+ self.causal = block_kwargs.get('causal', False)
286
+ self.window_size = block_kwargs.get('window_size', (-1, -1))
287
+ self.deterministic = block_kwargs.get('deterministic', False)
288
+ else:
289
+ raise NotImplementedError
290
+
291
+ def flash_attn(self, x, context=None, **kwargs):
292
+ '''
293
+ The implementation will be very slow when mask is not None,
294
+ because we need to rearrange the x/context features according to the mask.
295
+ Args:
296
+ x:
297
+ context:
298
+ mask:
299
+ **kwargs:
300
+ Returns: x
301
+ '''
302
+ dtype = kwargs.get('dtype', torch.float16)
303
+
304
+ def half(x):
305
+ return x if x.dtype in [torch.float16, torch.bfloat16
306
+ ] else x.to(dtype)
307
+
308
+ x_shapes = kwargs['x_shapes']
309
+ freqs = kwargs['freqs']
310
+ self_x_len = kwargs['self_x_len']
311
+ cross_x_len = kwargs['cross_x_len']
312
+ txt_lens = kwargs['txt_lens']
313
+ n, d = self.num_heads, self.head_dim
314
+
315
+ if context is None:
316
+ # self-attn
317
+ q = self.norm_q(self.q(x)).view(-1, n, d)
318
+ k = self.norm_k(self.k(x)).view(-1, n, d)  # apply the key norm, not norm_q
319
+ v = self.v(x).view(-1, n, d)
320
+ q = rope_apply(q, self_x_len, x_shapes, freqs, pad=False)
321
+ k = rope_apply(k, self_x_len, x_shapes, freqs, pad=False)
322
+ q_lens = k_lens = self_x_len
323
+ else:
324
+ # cross-attn
325
+ q = self.norm_q(self.q(x)).view(-1, n, d)
326
+ k = self.norm_k(self.k(context)).view(-1, n, d)  # key norm for cross-attention keys
327
+ v = self.v(context).view(-1, n, d)
328
+ q_lens = cross_x_len
329
+ k_lens = txt_lens
330
+
331
+ cu_seqlens_q = torch.cat([q_lens.new_zeros([1]),
332
+ q_lens]).cumsum(0, dtype=torch.int32)
333
+ cu_seqlens_k = torch.cat([k_lens.new_zeros([1]),
334
+ k_lens]).cumsum(0, dtype=torch.int32)
335
+ max_seqlen_q = q_lens.max()
336
+ max_seqlen_k = k_lens.max()
337
+
338
+ out_dtype = q.dtype
339
+ q, k, v = half(q), half(k), half(v)
340
+ x = flash_attn_varlen_func(q,
341
+ k,
342
+ v,
343
+ cu_seqlens_q=cu_seqlens_q,
344
+ cu_seqlens_k=cu_seqlens_k,
345
+ max_seqlen_q=max_seqlen_q,
346
+ max_seqlen_k=max_seqlen_k,
347
+ dropout_p=self.attn_drop.p,
348
+ softmax_scale=self.softmax_scale,
349
+ causal=self.causal,
350
+ window_size=self.window_size,
351
+ deterministic=self.deterministic)
352
+
353
+ x = x.type(out_dtype)
354
+ x = x.reshape(-1, n * d)
355
+ x = self.o(x)
356
+ x = self.dropout(x)
357
+ return x
358
+
359
+ def forward(self, x, context=None, **kwargs):
360
+ x = getattr(self, self.backend)(x, context=context, **kwargs)
361
+ return x
362
+
363
+
364
+ class T2IFinalLayer(nn.Module):
365
+ """
366
+ The final layer of PixArt.
367
+ """
368
+ def __init__(self, hidden_size, patch_size, out_channels):
369
+ super().__init__()
370
+ self.norm_final = nn.LayerNorm(hidden_size,
371
+ elementwise_affine=False,
372
+ eps=1e-6)
373
+ self.linear = nn.Linear(hidden_size,
374
+ patch_size * patch_size * out_channels,
375
+ bias=True)
376
+ self.scale_shift_table = nn.Parameter(
377
+ torch.randn(2, hidden_size) / hidden_size**0.5)
378
+ self.out_channels = out_channels
379
+
380
+ def forward(self, x, t):
381
+ shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2,
382
+ dim=1)
383
+ shift, scale = shift.squeeze(1), scale.squeeze(1)
384
+ x = modulate(self.norm_final(x), shift, scale)
385
+ x = self.linear(x)
386
+ return x
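The two blocks above follow the PixArt-style adaLN-single recipe: a learned `scale_shift_table` is added to the timestep embedding and split into shift/scale/gate vectors that modulate the normalized activations. A minimal, self-contained sketch of that modulation path, assuming `modulate(x, shift, scale)` is the usual `x * (1 + scale) + shift` from DiT/PixArt (the helper itself is defined elsewhere in the backbone):

# Sketch of the adaLN-single modulation used above (assumption:
# modulate(x, shift, scale) == x * (1 + scale) + shift, as in DiT/PixArt).
import torch

hidden_size, B, L = 64, 2, 16
scale_shift_table = torch.randn(6, hidden_size) / hidden_size**0.5
t = torch.randn(B, 6 * hidden_size)      # timestep embedding, six chunks
x = torch.randn(B, L, hidden_size)

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)

# per-sample shift/scale broadcast over the sequence length
x_mod = x * (1 + scale_msa) + shift_msa
print(x_mod.shape)                        # torch.Size([2, 16, 64])

Dividing the table by `hidden_size**0.5` keeps the initial shift/scale/gate values small, so each block starts close to an identity mapping.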
modules/model/backbone/pos_embed.py ADDED
@@ -0,0 +1,85 @@
1
+ import numpy as np
2
+ from einops import rearrange
3
+
4
+ import torch
5
+ import torch.cuda.amp as amp
6
+ import torch.nn.functional as F
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+ def frame_pad(x, seq_len, shapes):
10
+ max_h, max_w = np.max(shapes, 0)
11
+ frames = []
12
+ cur_len = 0
13
+ for h, w in shapes:
14
+ frame_len = h * w
15
+ frames.append(
16
+ F.pad(
17
+ x[cur_len:cur_len + frame_len].view(h, w, -1),
18
+ (0, 0, 0, max_w - w, 0, max_h - h)) # .view(max_h * max_w, -1)
19
+ )
20
+ cur_len += frame_len
21
+ if cur_len >= seq_len:
22
+ break
23
+ return torch.stack(frames)
24
+
25
+
26
+ def frame_unpad(x, shapes):
27
+ max_h, max_w = np.max(shapes, 0)
28
+ x = rearrange(x, '(b h w) n c -> b h w n c', h=max_h, w=max_w)
29
+ frames = []
30
+ for i, (h, w) in enumerate(shapes):
31
+ if i >= len(x):
32
+ break
33
+ frames.append(rearrange(x[i, :h, :w], 'h w n c -> (h w) n c'))
34
+ return torch.concat(frames)
35
+
36
+
37
+ @amp.autocast(enabled=False)
38
+ def rope_apply_multires(x, x_lens, x_shapes, freqs, pad=True):
39
+ """
40
+ x: [B*L, N, C].
41
+ x_lens: [B].
42
+ x_shapes: [B, F, 2].
43
+ freqs: [M, C // 2].
44
+ """
45
+ n, c = x.size(1), x.size(2) // 2
46
+ # split freqs
47
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
48
+ # loop over samples
49
+ output = []
50
+ st = 0
51
+ for i, (seq_len,
52
+ shapes) in enumerate(zip(x_lens.tolist(), x_shapes.tolist())):
53
+ x_i = frame_pad(x[st:st + seq_len], seq_len, shapes) # f, h, w, c
54
+ f, h, w = x_i.shape[:3]
55
+ pad_seq_len = f * h * w
56
+ # precompute multipliers
57
+ x_i = torch.view_as_complex(
58
+ x_i.to(torch.float64).reshape(pad_seq_len, n, -1, 2))
59
+ freqs_i = torch.cat([
60
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
61
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
62
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
63
+ ],
64
+ dim=-1).reshape(pad_seq_len, 1, -1)
65
+ # apply rotary embedding
66
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2).type_as(x)
67
+ x_i = frame_unpad(x_i, shapes)
68
+ # append to collection
69
+ output.append(x_i)
70
+ st += seq_len
71
+ return pad_sequence(output) if pad else torch.concat(output)
72
+
73
+
74
+ @amp.autocast(enabled=False)
75
+ def rope_params(max_seq_len, dim, theta=10000):
76
+ """
77
+ Precompute the frequency tensor for complex exponentials.
78
+ """
79
+ assert dim % 2 == 0
80
+ freqs = torch.outer(
81
+ torch.arange(max_seq_len),
82
+ 1.0 / torch.pow(theta,
83
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
84
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
85
+ return freqs
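`rope_params` precomputes complex rotations for every position up to `max_seq_len`, and `rope_apply_multires` splits the channel dimension into temporal/height/width groups before rotating each padded frame. A quick standalone check of those shapes (the function body is duplicated here so the snippet runs on its own):

# Shape check for the RoPE precomputation above: rope_params returns complex
# rotations of shape [max_seq_len, dim // 2], which rope_apply_multires then
# splits across the (frame, height, width) axes.
import torch

def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    return torch.polar(torch.ones_like(freqs), freqs)

head_dim = 64
freqs = rope_params(1024, head_dim)
print(freqs.shape, freqs.dtype)      # torch.Size([1024, 32]) torch.complex128
c = head_dim // 2
# same split as in rope_apply_multires: temporal, height, width channels
parts = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
print([p.shape[1] for p in parts])   # [12, 10, 10]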
modules/model/diffusion/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+
+ from .diffusions import ACEDiffusion
+ from .samplers import DDIMSampler
+ from .schedules import LinearScheduler
modules/model/diffusion/diffusions.py ADDED
@@ -0,0 +1,206 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import math
4
+ import os
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ from tqdm import trange
9
+
10
+ from scepter.modules.model.registry import (DIFFUSION_SAMPLERS, DIFFUSIONS,
11
+ NOISE_SCHEDULERS)
12
+ from scepter.modules.utils.config import Config, dict_to_yaml
13
+ from scepter.modules.utils.distribute import we
14
+ from scepter.modules.utils.file_system import FS
15
+
16
+
17
+ @DIFFUSIONS.register_class()
18
+ class ACEDiffusion(object):
19
+ para_dict = {
20
+ 'NOISE_SCHEDULER': {},
21
+ 'SAMPLER_SCHEDULER': {},
22
+ 'MIN_SNR_GAMMA': {
23
+ 'value': None,
24
+ 'description': 'The minimum SNR gamma value for the loss function.'
25
+ },
26
+ 'PREDICTION_TYPE': {
27
+ 'value': 'eps',
28
+ 'description':
29
+ 'The type of prediction to use for the loss function.'
30
+ }
31
+ }
32
+
33
+ def __init__(self, cfg, logger=None):
34
+ super(ACEDiffusion, self).__init__()
35
+ self.logger = logger
36
+ self.cfg = cfg
37
+ self.init_params()
38
+
39
+ def init_params(self):
40
+ self.min_snr_gamma = self.cfg.get('MIN_SNR_GAMMA', None)
41
+ self.prediction_type = self.cfg.get('PREDICTION_TYPE', 'eps')
42
+ self.noise_scheduler = NOISE_SCHEDULERS.build(self.cfg.NOISE_SCHEDULER,
43
+ logger=self.logger)
44
+ self.sampler_scheduler = NOISE_SCHEDULERS.build(self.cfg.get(
45
+ 'SAMPLER_SCHEDULER', self.cfg.NOISE_SCHEDULER),
46
+ logger=self.logger)
47
+ self.num_timesteps = self.noise_scheduler.num_timesteps
48
+ if self.cfg.have('WORK_DIR') and we.rank == 0:
49
+ schedule_visualization = os.path.join(self.cfg.WORK_DIR,
50
+ 'noise_schedule.png')
51
+ with FS.put_to(schedule_visualization) as local_path:
52
+ self.noise_scheduler.plot_noise_sampling_map(local_path)
53
+ schedule_visualization = os.path.join(self.cfg.WORK_DIR,
54
+ 'sampler_schedule.png')
55
+ with FS.put_to(schedule_visualization) as local_path:
56
+ self.sampler_scheduler.plot_noise_sampling_map(local_path)
57
+
58
+ def sample(self,
59
+ noise,
60
+ model,
61
+ model_kwargs={},
62
+ steps=20,
63
+ sampler=None,
64
+ use_dynamic_cfg=False,
65
+ guide_scale=None,
66
+ guide_rescale=None,
67
+ show_progress=False,
68
+ return_intermediate=None,
69
+ intermediate_callback=None,
70
+ **kwargs):
71
+ assert isinstance(steps, (int, torch.LongTensor))
72
+ assert return_intermediate in (None, 'x0', 'xt')
73
+ assert isinstance(sampler, (str, dict, Config))
74
+ intermediates = []
75
+
76
+ def callback_fn(x_t, t, sigma=None, alpha=None):
77
+ timestamp = t
78
+ t = t.repeat(len(x_t)).round().long().to(x_t.device)
79
+ sigma = sigma.repeat(len(x_t), *([1] * (len(sigma.shape) - 1)))
80
+ alpha = alpha.repeat(len(x_t), *([1] * (len(alpha.shape) - 1)))
81
+
82
+ if guide_scale is None or guide_scale == 1.0:
83
+ out = model(x=x_t, t=t, **model_kwargs)
84
+ else:
85
+ if use_dynamic_cfg:
86
+ guidance_scale = 1 + guide_scale * (
87
+ (1 - math.cos(math.pi * (
88
+ (steps - timestamp.item()) / steps)**5.0)) / 2)
89
+ else:
90
+ guidance_scale = guide_scale
91
+ y_out = model(x=x_t, t=t, **model_kwargs[0])
92
+ u_out = model(x=x_t, t=t, **model_kwargs[1])
93
+ out = u_out + guidance_scale * (y_out - u_out)
94
+ if guide_rescale is not None and guide_rescale > 0.0:
95
+ ratio = (
96
+ y_out.flatten(1).std(dim=1) /
97
+ (out.flatten(1).std(dim=1) + 1e-12)).view((-1, ) + (1, ) *
98
+ (y_out.ndim - 1))
99
+ out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
100
+
101
+ if self.prediction_type == 'x0':
102
+ x0 = out
103
+ elif self.prediction_type == 'eps':
104
+ x0 = (x_t - sigma * out) / alpha
105
+ elif self.prediction_type == 'v':
106
+ x0 = alpha * x_t - sigma * out
107
+ else:
108
+ raise NotImplementedError(
109
+ f'prediction_type {self.prediction_type} not implemented')
110
+
111
+ return x0
112
+
113
+ sampler_ins = self.get_sampler(sampler)
114
+
115
+ # this is ignored for schnell
116
+ sampler_output = sampler_ins.preprare_sampler(
117
+ noise,
118
+ steps=steps,
119
+ prediction_type=self.prediction_type,
120
+ scheduler_ins=self.sampler_scheduler,
121
+ callback_fn=callback_fn)
122
+
123
+ # set the description on the trange instance, not on the tqdm class
+ progress_bar = trange(steps, disable=not show_progress)
+ for _ in progress_bar:
+ progress_bar.desc = sampler_output.msg
125
+ sampler_output = sampler_ins.step(sampler_output)
126
+ if return_intermediate == 'x0':  # match the assert above ('x0' / 'xt')
127
+ intermediates.append(sampler_output.x_0)
128
+ elif return_intermediate == 'xt':
129
+ intermediates.append(sampler_output.x_t)
130
+ if intermediate_callback is not None:
131
+ intermediate_callback(intermediates[-1])
132
+ return (sampler_output.x_0, intermediates
133
+ ) if return_intermediate is not None else sampler_output.x_0
134
+
135
+ def loss(self,
136
+ x_0,
137
+ model,
138
+ model_kwargs={},
139
+ reduction='mean',
140
+ noise=None,
141
+ **kwargs):
142
+ # use noise scheduler to add noise
143
+ if noise is None:
144
+ noise = torch.randn_like(x_0)
145
+ schedule_output = self.noise_scheduler.add_noise(x_0, noise, **kwargs)
146
+ x_t, t, sigma, alpha = schedule_output.x_t, schedule_output.t, schedule_output.sigma, schedule_output.alpha
147
+ out = model(x=x_t, t=t, **model_kwargs)
148
+
149
+ # mse loss
150
+ target = {
151
+ 'eps': noise,
152
+ 'x0': x_0,
153
+ 'v': alpha * noise - sigma * x_0
154
+ }[self.prediction_type]
155
+
156
+ loss = (out - target).pow(2)
157
+ if reduction == 'mean':
158
+ loss = loss.flatten(1).mean(dim=1)
159
+
160
+ if self.min_snr_gamma is not None:
161
+ alphas = self.noise_scheduler.alphas.to(x_0.device)[t]
162
+ sigmas = self.noise_scheduler.sigmas.pow(2).to(x_0.device)[t]
163
+ snrs = (alphas / sigmas).clamp(min=1e-20)
164
+ min_snrs = snrs.clamp(max=self.min_snr_gamma)
165
+ weights = min_snrs / snrs
166
+ else:
167
+ weights = 1
168
+
169
+ loss = loss * weights
170
+ return loss
171
+
172
+ def get_sampler(self, sampler):
173
+ if isinstance(sampler, str):
174
+ if sampler not in DIFFUSION_SAMPLERS.class_map:
175
+ if self.logger is not None:
176
+ self.logger.info(
177
+ f'{sampler} not in the defined samplers list {DIFFUSION_SAMPLERS.class_map.keys()}'
178
+ )
179
+ else:
180
+ print(
181
+ f'{sampler} not in the defined samplers list {DIFFUSION_SAMPLERS.class_map.keys()}'
182
+ )
183
+ return None
184
+ sampler_cfg = Config(cfg_dict={'NAME': sampler}, load=False)
185
+ sampler_ins = DIFFUSION_SAMPLERS.build(sampler_cfg,
186
+ logger=self.logger)
187
+ elif isinstance(sampler, (Config, dict, OrderedDict)):
188
+ if isinstance(sampler, (dict, OrderedDict)):
189
+ sampler = Config(
190
+ cfg_dict={k.upper(): v
191
+ for k, v in dict(sampler).items()},
192
+ load=False)
193
+ sampler_ins = DIFFUSION_SAMPLERS.build(sampler, logger=self.logger)
194
+ else:
195
+ raise NotImplementedError
196
+ return sampler_ins
197
+
198
+ def __repr__(self) -> str:
199
+ return f'{self.__class__.__name__}' + ' ' + super().__repr__()
200
+
201
+ @staticmethod
202
+ def get_config_template():
203
+ return dict_to_yaml('DIFFUSIONS',
204
+ __class__.__name__,
205
+ ACEDiffusion.para_dict,
206
+ set_name=True)
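`callback_fn` above combines the conditional and unconditional model outputs with classifier-free guidance and an optional std-ratio rescale that counteracts the variance blow-up of large guidance scales. The same arithmetic as a standalone helper, useful for sanity-checking shapes (the name `cfg_combine` is ours, not part of the module):

# Standalone restatement of the guidance math in callback_fn.
import torch

def cfg_combine(y_out, u_out, guide_scale=4.5, guide_rescale=0.5):
    out = u_out + guide_scale * (y_out - u_out)
    if guide_rescale is not None and guide_rescale > 0.0:
        ratio = (y_out.flatten(1).std(dim=1) /
                 (out.flatten(1).std(dim=1) + 1e-12)).view(
                     (-1,) + (1,) * (y_out.ndim - 1))
        out = out * (guide_rescale * ratio + (1 - guide_rescale))
    return out

y = torch.randn(2, 4, 8, 8)   # conditional prediction
u = torch.randn(2, 4, 8, 8)   # unconditional prediction
print(cfg_combine(y, u).shape)   # torch.Size([2, 4, 8, 8])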
modules/model/diffusion/samplers.py ADDED
@@ -0,0 +1,69 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import torch
4
+
5
+ from scepter.modules.model.registry import DIFFUSION_SAMPLERS
6
+ from scepter.modules.model.diffusion.samplers import BaseDiffusionSampler
8
+
9
+ def _i(tensor, t, x):
10
+ """
11
+ Index tensor using t and format the output according to x.
12
+ """
13
+ shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
14
+ if isinstance(t, torch.Tensor):
15
+ t = t.to(tensor.device)
16
+ return tensor[t].view(shape).to(x.device)
17
+
18
+
19
+ @DIFFUSION_SAMPLERS.register_class('ddim')
20
+ class DDIMSampler(BaseDiffusionSampler):
21
+ def init_params(self):
22
+ super().init_params()
23
+ self.eta = self.cfg.get('ETA', 0.)
24
+ self.discretization_type = self.cfg.get('DISCRETIZATION_TYPE',
25
+ 'trailing')
26
+
27
+ def preprare_sampler(self,
28
+ noise,
29
+ steps=20,
30
+ scheduler_ins=None,
31
+ prediction_type='',
32
+ sigmas=None,
33
+ betas=None,
34
+ alphas=None,
35
+ callback_fn=None,
36
+ **kwargs):
37
+ output = super().preprare_sampler(noise, steps, scheduler_ins,
38
+ prediction_type, sigmas, betas,
39
+ alphas, callback_fn, **kwargs)
40
+ sigmas = output.sigmas
41
+ sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
42
+ sigmas_vp = (sigmas**2 / (1 + sigmas**2))**0.5
43
+ sigmas_vp[sigmas == float('inf')] = 1.
44
+ output.add_custom_field('sigmas_vp', sigmas_vp)
45
+ return output
46
+
47
+ def step(self, sampler_output):
48
+ x_t = sampler_output.x_t
49
+ step = sampler_output.step
50
+ t = sampler_output.ts[step]
51
+ sigmas_vp = sampler_output.sigmas_vp.to(x_t.device)
52
+ alpha_init = _i(sampler_output.alphas_init, step, x_t[:1])
53
+ sigma_init = _i(sampler_output.sigmas_init, step, x_t[:1])
54
+
55
+ x = sampler_output.callback_fn(x_t, t, sigma_init, alpha_init)
56
+ noise_factor = self.eta * (sigmas_vp[step + 1]**2 /
57
+ sigmas_vp[step]**2 *
58
+ (1 - (1 - sigmas_vp[step]**2) /
59
+ (1 - sigmas_vp[step + 1]**2)))
60
+ d = (x_t - (1 - sigmas_vp[step]**2)**0.5 * x) / sigmas_vp[step]
61
+ x = (1 - sigmas_vp[step + 1] ** 2) ** 0.5 * x + \
62
+ (sigmas_vp[step + 1] ** 2 - noise_factor ** 2) ** 0.5 * d
63
+ sampler_output.x_0 = x
64
+ if sigmas_vp[step + 1] > 0:
65
+ x += noise_factor * torch.randn_like(x)
66
+ sampler_output.x_t = x
67
+ sampler_output.step += 1
68
+ sampler_output.msg = f'step {step}'
69
+ return sampler_output
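`preprare_sampler` converts the scheduler sigmas into variance-preserving sigmas, sigma_vp = (sigma^2 / (1 + sigma^2))^0.5, with infinite sigma mapped to 1; the DDIM step then uses (1 - sigma_vp^2)^0.5 as the matching alpha. A small numeric check of that mapping:

# Check of the sigma -> variance-preserving sigma mapping used above, plus the
# identity alpha_vp**2 + sigma_vp**2 == 1 that the DDIM step relies on.
import torch

sigmas = torch.tensor([0.1, 1.0, 10.0, float('inf')])
sigmas_vp = (sigmas**2 / (1 + sigmas**2))**0.5
sigmas_vp[sigmas == float('inf')] = 1.
alphas_vp = (1 - sigmas_vp**2)**0.5
print(sigmas_vp)                       # tensor([0.0995, 0.7071, 0.9950, 1.0000])
print(alphas_vp**2 + sigmas_vp**2)     # all ones (up to float error)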
modules/model/diffusion/schedules.py ADDED
@@ -0,0 +1,30 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import torch
4
+
5
+ from scepter.modules.model.registry import NOISE_SCHEDULERS
6
+ from scepter.modules.model.diffusion.schedules import BaseNoiseScheduler
7
+
8
+
9
+ @NOISE_SCHEDULERS.register_class()
10
+ class LinearScheduler(BaseNoiseScheduler):
11
+ para_dict = {}
12
+
13
+ def init_params(self):
14
+ super().init_params()
15
+ self.beta_min = self.cfg.get('BETA_MIN', 0.00085)
16
+ self.beta_max = self.cfg.get('BETA_MAX', 0.012)
17
+
18
+ def betas_to_sigmas(self, betas):
19
+ return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
20
+
21
+ def get_schedule(self):
22
+ betas = torch.linspace(self.beta_min,
23
+ self.beta_max,
24
+ self.num_timesteps,
25
+ dtype=torch.float32)
26
+ sigmas = self.betas_to_sigmas(betas)
27
+ self._sigmas = sigmas
28
+ self._betas = betas
29
+ self._alphas = torch.sqrt(1 - sigmas**2)
30
+ self._timesteps = torch.arange(len(sigmas), dtype=torch.float32)
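`LinearScheduler` defines sigma_t = sqrt(1 - cumprod(1 - beta)), i.e. the usual sqrt(1 - alpha_bar_t), and alpha_t = sqrt(1 - sigma_t^2), so x_t = alpha_t * x_0 + sigma_t * noise stays variance preserving. A quick standalone check with the default BETA_MIN/BETA_MAX values from above:

# Reproduce the schedule outside the scepter registry and verify it is
# variance preserving.
import torch

num_timesteps = 1000
betas = torch.linspace(0.00085, 0.012, num_timesteps, dtype=torch.float32)
sigmas = torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
alphas = torch.sqrt(1 - sigmas**2)
print(sigmas[0].item(), sigmas[-1].item())   # ~0.029 ... ~0.999
print(torch.allclose(alphas**2 + sigmas**2, torch.ones_like(sigmas)))  # True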
modules/model/embedder/__init__.py ADDED
@@ -0,0 +1 @@
+ from .embedder import ACETextEmbedder
modules/model/embedder/embedder.py ADDED
@@ -0,0 +1,184 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import warnings
4
+ from contextlib import nullcontext
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.utils.dlpack
9
+ from scepter.modules.model.embedder.base_embedder import BaseEmbedder
10
+ from scepter.modules.model.registry import EMBEDDERS
11
+ from scepter.modules.model.tokenizer.tokenizer_component import (
12
+ basic_clean, canonicalize, heavy_clean, whitespace_clean)
13
+ from scepter.modules.utils.config import dict_to_yaml
14
+ from scepter.modules.utils.distribute import we
15
+ from scepter.modules.utils.file_system import FS
16
+
17
+ try:
18
+ from transformers import AutoTokenizer, T5EncoderModel
19
+ except Exception as e:
20
+ warnings.warn(
21
+ f'Import transformers error, please deal with this problem: {e}')
22
+
23
+
24
+ @EMBEDDERS.register_class()
25
+ class ACETextEmbedder(BaseEmbedder):
26
+ """
27
+ Uses the OpenCLIP transformer encoder for text
28
+ """
29
+ """
30
+ Uses the OpenCLIP transformer encoder for text
31
+ """
32
+ para_dict = {
33
+ 'PRETRAINED_MODEL': {
34
+ 'value':
35
+ 'google/umt5-small',
36
+ 'description':
37
+ 'Pretrained Model for umt5, modelcard path or local path.'
38
+ },
39
+ 'TOKENIZER_PATH': {
40
+ 'value': 'google/umt5-small',
41
+ 'description':
42
+ 'Tokenizer Path for umt5, modelcard path or local path.'
43
+ },
44
+ 'FREEZE': {
45
+ 'value': True,
46
+ 'description': ''
47
+ },
48
+ 'USE_GRAD': {
49
+ 'value': False,
50
+ 'description': 'Compute grad or not.'
51
+ },
52
+ 'CLEAN': {
53
+ 'value':
54
+ 'whitespace',
55
+ 'description':
56
+ 'Set the clean strategy for the tokenizer; used when TOKENIZER_PATH is not None.'
57
+ },
58
+ 'LAYER': {
59
+ 'value': 'last',
60
+ 'description': ''
61
+ },
62
+ 'LEGACY': {
63
+ 'value':
64
+ True,
65
+ 'description':
66
+ 'Whether to use the legacy returned feature or not; default True.'
67
+ }
68
+ }
69
+
70
+ def __init__(self, cfg, logger=None):
71
+ super().__init__(cfg, logger=logger)
72
+ pretrained_path = cfg.get('PRETRAINED_MODEL', None)
73
+ self.t5_dtype = cfg.get('T5_DTYPE', 'float32')
74
+ assert pretrained_path
75
+ with FS.get_dir_to_local_dir(pretrained_path,
76
+ wait_finish=True) as local_path:
77
+ self.model = T5EncoderModel.from_pretrained(
78
+ local_path,
79
+ torch_dtype=getattr(
80
+ torch,
81
+ 'float' if self.t5_dtype == 'float32' else self.t5_dtype))
82
+ tokenizer_path = cfg.get('TOKENIZER_PATH', None)
83
+ self.length = cfg.get('LENGTH', 77)
84
+
85
+ self.use_grad = cfg.get('USE_GRAD', False)
86
+ self.clean = cfg.get('CLEAN', 'whitespace')
87
+ self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
88
+ if tokenizer_path:
89
+ self.tokenize_kargs = {'return_tensors': 'pt'}
90
+ with FS.get_dir_to_local_dir(tokenizer_path,
91
+ wait_finish=True) as local_path:
92
+ if self.added_identifier is not None and isinstance(
93
+ self.added_identifier, list):
94
+ self.tokenizer = AutoTokenizer.from_pretrained(local_path)
95
+ else:
96
+ self.tokenizer = AutoTokenizer.from_pretrained(local_path)
97
+ if self.length is not None:
98
+ self.tokenize_kargs.update({
99
+ 'padding': 'max_length',
100
+ 'truncation': True,
101
+ 'max_length': self.length
102
+ })
103
+ self.eos_token = self.tokenizer(
104
+ self.tokenizer.eos_token)['input_ids'][0]
105
+ else:
106
+ self.tokenizer = None
107
+ self.tokenize_kargs = {}
108
+
109
111
+
112
+ def freeze(self):
113
+ self.model = self.model.eval()
114
+ for param in self.parameters():
115
+ param.requires_grad = False
116
+
117
+ # encode && encode_text
118
+ def forward(self, tokens, return_mask=False, use_mask=True):
119
+ # tokenization
120
+ embedding_context = nullcontext if self.use_grad else torch.no_grad
121
+ with embedding_context():
122
+ if use_mask:
123
+ x = self.model(tokens.input_ids.to(we.device_id),
124
+ tokens.attention_mask.to(we.device_id))
125
+ else:
126
+ x = self.model(tokens.input_ids.to(we.device_id))
127
+ x = x.last_hidden_state
128
+
129
+ if return_mask:
130
+ return x.detach() + 0.0, tokens.attention_mask.to(we.device_id)
131
+ else:
132
+ return x.detach() + 0.0, None
133
+
134
+ def _clean(self, text):
135
+ if self.clean == 'whitespace':
136
+ text = whitespace_clean(basic_clean(text))
137
+ elif self.clean == 'lower':
138
+ text = whitespace_clean(basic_clean(text)).lower()
139
+ elif self.clean == 'canonicalize':
140
+ text = canonicalize(basic_clean(text))
141
+ elif self.clean == 'heavy':
142
+ text = heavy_clean(basic_clean(text))
143
+ return text
144
+
145
+ def encode(self, text, return_mask=False, use_mask=True):
146
+ if isinstance(text, str):
147
+ text = [text]
148
+ if self.clean:
149
+ text = [self._clean(u) for u in text]
150
+ assert self.tokenizer is not None
151
+ cont, mask = [], []
152
+ with torch.autocast(device_type='cuda',
153
+ enabled=self.t5_dtype in ('float16', 'bfloat16'),
154
+ dtype=getattr(torch, self.t5_dtype)):
155
+ for tt in text:
156
+ tokens = self.tokenizer([tt], **self.tokenize_kargs)
157
+ one_cont, one_mask = self(tokens,
158
+ return_mask=return_mask,
159
+ use_mask=use_mask)
160
+ cont.append(one_cont)
161
+ mask.append(one_mask)
162
+ if return_mask:
163
+ return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
164
+ else:
165
+ return torch.cat(cont, dim=0)
166
+
167
+ def encode_list(self, text_list, return_mask=True):
168
+ cont_list = []
169
+ mask_list = []
170
+ for pp in text_list:
171
+ cont, cont_mask = self.encode(pp, return_mask=return_mask)
172
+ cont_list.append(cont)
173
+ mask_list.append(cont_mask)
174
+ if return_mask:
175
+ return cont_list, mask_list
176
+ else:
177
+ return cont_list
178
+
179
+ @staticmethod
180
+ def get_config_template():
181
+ return dict_to_yaml('MODELS',
182
+ __class__.__name__,
183
+ ACETextEmbedder.para_dict,
184
+ set_name=True)
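Stripped of the scepter config and file-system plumbing, `ACETextEmbedder` is a (U)MT5 encoder plus tokenizer that returns `last_hidden_state` and the attention mask. A rough equivalent with transformers directly, using the `google/umt5-small` default from `para_dict` above (the actual checkpoint is whatever PRETRAINED_MODEL points to):

# Minimal sketch of what ACETextEmbedder.encode does, skipping the scepter
# config/file-system machinery.
import torch
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained('google/umt5-small')
model = T5EncoderModel.from_pretrained('google/umt5-small').eval()

tokens = tokenizer(['a cat sitting on a chair'],
                   padding='max_length', truncation=True,
                   max_length=77, return_tensors='pt')
with torch.no_grad():
    out = model(tokens.input_ids, attention_mask=tokens.attention_mask)
cont = out.last_hidden_state        # [1, 77, hidden]
mask = tokens.attention_mask        # [1, 77]
print(cont.shape, mask.shape)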
modules/model/network/__init__.py ADDED
@@ -0,0 +1 @@
+ from .ldm_ace import LdmACE
modules/model/network/ldm_ace.py ADDED
@@ -0,0 +1,353 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import copy
4
+ import random
5
+ from contextlib import nullcontext
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from scepter.modules.model.network.ldm import LatentDiffusion
12
+ from scepter.modules.model.registry import MODELS
13
+ from scepter.modules.utils.config import dict_to_yaml
14
+ from scepter.modules.utils.distribute import we
15
+
16
+ from ..utils.basic_utils import (
17
+ check_list_of_list,
18
+ pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
19
+ to_device,
20
+ unpack_tensor_into_imagelist
21
+ )
22
+
23
+
24
+ class TextEmbedding(nn.Module):
25
+ def __init__(self, embedding_shape):
26
+ super().__init__()
27
+ self.pos = nn.Parameter(data=torch.zeros(embedding_shape))
28
+
29
+
30
+ @MODELS.register_class()
31
+ class LdmACE(LatentDiffusion):
32
+ para_dict = LatentDiffusion.para_dict
33
+ para_dict['DECODER_BIAS'] = {'value': 0, 'description': ''}
34
+
35
+ def __init__(self, cfg, logger=None):
36
+ super().__init__(cfg, logger=logger)
37
+ self.interpolate_func = lambda x: (F.interpolate(
38
+ x.unsqueeze(0),
39
+ scale_factor=1 / self.size_factor,
40
+ mode='nearest-exact') if x is not None else None)
41
+
42
+ self.text_indentifers = cfg.get('TEXT_IDENTIFIER', [])
43
+ self.use_text_pos_embeddings = cfg.get('USE_TEXT_POS_EMBEDDINGS',
44
+ False)
45
+ if self.use_text_pos_embeddings:
46
+ self.text_position_embeddings = TextEmbedding(
47
+ (10, 4096)).eval().requires_grad_(False)
48
+ else:
49
+ self.text_position_embeddings = None
50
+
51
+ self.logger.info(self.model)
52
+
53
+ @torch.no_grad()
54
+ def encode_first_stage(self, x, **kwargs):
55
+ return [
56
+ self.scale_factor *
57
+ self.first_stage_model._encode(i.unsqueeze(0).to(torch.float16))
58
+ for i in x
59
+ ]
60
+
61
+ @torch.no_grad()
62
+ def decode_first_stage(self, z):
63
+ return [
64
+ self.first_stage_model._decode(1. / self.scale_factor *
65
+ i.to(torch.float16)) for i in z
66
+ ]
67
+
68
+ def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
69
+ if self.use_text_pos_embeddings and not torch.sum(
70
+ self.text_position_embeddings.pos) > 0:
71
+ identifier_cont, identifier_cont_mask = getattr(
72
+ self.cond_stage_model, 'encode')(self.text_indentifers,
73
+ return_mask=True)
74
+ self.text_position_embeddings.load_state_dict(
75
+ {'pos': identifier_cont[:, 0, :]})
76
+ cont_, cont_mask_ = [], []
77
+ for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
78
+ if isinstance(pp, list):
79
+ cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
80
+ cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
81
+ else:
82
+ raise NotImplementedError
83
+
84
+ return cont_, cont_mask_
85
+
86
+ def limit_batch_data(self, batch_data_list, log_num):
87
+ if log_num and log_num > 0:
88
+ batch_data_list_limited = []
89
+ for sub_data in batch_data_list:
90
+ if sub_data is not None:
91
+ sub_data = sub_data[:log_num]
92
+ batch_data_list_limited.append(sub_data)
93
+ return batch_data_list_limited
94
+ else:
95
+ return batch_data_list
96
+
97
+ def forward_train(self,
98
+ edit_image=[],
99
+ edit_image_mask=[],
100
+ image=None,
101
+ image_mask=None,
102
+ noise=None,
103
+ prompt=[],
104
+ **kwargs):
105
+ '''
106
+ Args:
107
+ edit_image: list of list of edit_image
108
+ edit_image_mask: list of list of edit_image_mask
109
+ image: target image
110
+ image_mask: target image mask
111
+ noise: defaults to None; generated automatically if not provided
112
+ prompt: list of list of text
113
+ **kwargs:
114
+ Returns:
115
+ '''
116
+ assert check_list_of_list(prompt) and check_list_of_list(
117
+ edit_image) and check_list_of_list(edit_image_mask)
118
+ assert len(edit_image) == len(edit_image_mask) == len(prompt)
119
+ assert self.cond_stage_model is not None
120
+ gc_seg = kwargs.pop('gc_seg', [])
121
+ gc_seg = int(gc_seg[0]) if len(gc_seg) > 0 else 0
122
+ context = {}
123
+
124
+ # process image
125
+ image = to_device(image)
126
+ x_start = self.encode_first_stage(image, **kwargs)
127
+ x_start, x_shapes = pack_imagelist_into_tensor(x_start) # B, C, L
128
+ n, _, _ = x_start.shape
129
+ t = torch.randint(0, self.num_timesteps, (n, ),
130
+ device=x_start.device).long()
131
+ context['x_shapes'] = x_shapes
132
+
133
+ # process image mask
134
+ image_mask = to_device(image_mask, strict=False)
135
+ context['x_mask'] = [self.interpolate_func(i) for i in image_mask
136
+ ] if image_mask is not None else [None] * n
137
+
138
+ # process text
139
+ # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
140
+ prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
141
+ try:
142
+ cont, cont_mask = getattr(self.cond_stage_model,
143
+ 'encode_list')(prompt_, return_mask=True)
144
+ except Exception as e:
145
+ print(e, prompt_)
146
+ cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
147
+ cont_mask)
148
+ context['crossattn'] = cont
149
+
150
+ # process edit image & edit image mask
151
+ edit_image = [to_device(i, strict=False) for i in edit_image]
152
+ edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
153
+ e_img, e_mask = [], []
154
+ for u, m in zip(edit_image, edit_image_mask):
155
+ if m is None:
156
+ m = [None] * len(u) if u is not None else [None]
157
+ e_img.append(
158
+ self.encode_first_stage(u, **kwargs) if u is not None else u)
159
+ e_mask.append([
160
+ self.interpolate_func(i) if i is not None else None for i in m
161
+ ])
162
+ context['edit'], context['edit_mask'] = e_img, e_mask
163
+
164
+ # process loss
165
+ loss = self.diffusion.loss(
166
+ x_0=x_start,
167
+ t=t,
168
+ noise=noise,
169
+ model=self.model,
170
+ model_kwargs={
171
+ 'cond':
172
+ context,
173
+ 'mask':
174
+ cont_mask,
175
+ 'gc_seg':
176
+ gc_seg,
177
+ 'text_position_embeddings':
178
+ self.text_position_embeddings.pos if hasattr(
179
+ self.text_position_embeddings, 'pos') else None
180
+ },
181
+ **kwargs)
182
+ loss = loss.mean()
183
+ ret = {'loss': loss, 'probe_data': {'prompt': prompt}}
184
+ return ret
185
+
186
+ @torch.no_grad()
187
+ def forward_test(self,
188
+ edit_image=[],
189
+ edit_image_mask=[],
190
+ image=None,
191
+ image_mask=None,
192
+ prompt=[],
193
+ n_prompt=[],
194
+ sampler='ddim',
195
+ sample_steps=20,
196
+ guide_scale=4.5,
197
+ guide_rescale=0.5,
198
+ log_num=-1,
199
+ seed=2024,
200
+ **kwargs):
201
+
202
+ assert check_list_of_list(prompt) and check_list_of_list(
203
+ edit_image) and check_list_of_list(edit_image_mask)
204
+ assert len(edit_image) == len(edit_image_mask) == len(prompt)
205
+ assert self.cond_stage_model is not None
206
+ # gc_seg is unused
207
+ kwargs.pop('gc_seg', -1)
208
+ # prepare data
209
+ context, null_context = {}, {}
210
+
211
+ prompt, n_prompt, image, image_mask, edit_image, edit_image_mask = self.limit_batch_data(
212
+ [prompt, n_prompt, image, image_mask, edit_image, edit_image_mask],
213
+ log_num)
214
+ g = torch.Generator(device=we.device_id)
215
+ seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
216
+ g.manual_seed(seed)
217
+ n_prompt = copy.deepcopy(prompt)
218
+ # only modify the last prompt to be zero
219
+ for nn_p_id, nn_p in enumerate(n_prompt):
220
+ if isinstance(nn_p, str):
221
+ n_prompt[nn_p_id] = ['']
222
+ elif isinstance(nn_p, list):
223
+ n_prompt[nn_p_id][-1] = ''
224
+ else:
225
+ raise NotImplementedError
226
+ # process image
227
+ image = to_device(image)
228
+ x = self.encode_first_stage(image, **kwargs)
229
+ noise = [
230
+ torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
231
+ for i in x
232
+ ]
233
+ noise, x_shapes = pack_imagelist_into_tensor(noise)
234
+ context['x_shapes'] = null_context['x_shapes'] = x_shapes
235
+
236
+ # process image mask
237
+ image_mask = to_device(image_mask, strict=False)
238
+ cond_mask = [self.interpolate_func(i) for i in image_mask
239
+ ] if image_mask is not None else [None] * len(image)
240
+ context['x_mask'] = null_context['x_mask'] = cond_mask
241
+ # process text
242
+ # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
243
+ prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
244
+ cont, cont_mask = getattr(self.cond_stage_model,
245
+ 'encode_list')(prompt_, return_mask=True)
246
+ cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
247
+ cont_mask)
248
+ null_cont, null_cont_mask = getattr(self.cond_stage_model,
249
+ 'encode_list')(n_prompt,
250
+ return_mask=True)
251
+ null_cont, null_cont_mask = self.cond_stage_embeddings(
252
+ prompt, edit_image, null_cont, null_cont_mask)
253
+ context['crossattn'] = cont
254
+ null_context['crossattn'] = null_cont
255
+
256
+ # process edit image & edit image mask
257
+ edit_image = [to_device(i, strict=False) for i in edit_image]
258
+ edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
259
+ e_img, e_mask = [], []
260
+ for u, m in zip(edit_image, edit_image_mask):
261
+ if u is None:
262
+ continue
263
+ if m is None:
264
+ m = [None] * len(u)
265
+ e_img.append(self.encode_first_stage(u, **kwargs))
266
+ e_mask.append([self.interpolate_func(i) for i in m])
267
+ null_context['edit'] = context['edit'] = e_img
268
+ null_context['edit_mask'] = context['edit_mask'] = e_mask
269
+
270
+ # process sample
271
+ model = self.model_ema if self.use_ema and self.eval_ema else self.model
272
+ embedding_context = model.no_sync if isinstance(model, torch.distributed.fsdp.FullyShardedDataParallel) \
273
+ else nullcontext
274
+ with embedding_context():
275
+ samples = self.diffusion.sample(
276
+ sampler=sampler,
277
+ noise=noise,
278
+ model=model,
279
+ model_kwargs=[{
280
+ 'cond':
281
+ context,
282
+ 'mask':
283
+ cont_mask,
284
+ 'text_position_embeddings':
285
+ self.text_position_embeddings.pos if hasattr(
286
+ self.text_position_embeddings, 'pos') else None
287
+ }, {
288
+ 'cond':
289
+ null_context,
290
+ 'mask':
291
+ null_cont_mask,
292
+ 'text_position_embeddings':
293
+ self.text_position_embeddings.pos if hasattr(
294
+ self.text_position_embeddings, 'pos') else None
295
+ }] if guide_scale is not None and guide_scale > 1 else {
296
+ 'cond':
297
+ context,
298
+ 'mask':
299
+ cont_mask,
300
+ 'text_position_embeddings':
301
+ self.text_position_embeddings.pos if hasattr(
302
+ self.text_position_embeddings, 'pos') else None
303
+ },
304
+ steps=sample_steps,
305
+ guide_scale=guide_scale,
306
+ guide_rescale=guide_rescale,
307
+ show_progress=True,
308
+ **kwargs)
309
+
310
+ samples = unpack_tensor_into_imagelist(samples, x_shapes)
311
+ x_samples = self.decode_first_stage(samples)
312
+ outputs = list()
313
+ for i in range(len(prompt)):
314
+ rec_img = torch.clamp(
315
+ (x_samples[i] + 1.0) / 2.0 + self.decoder_bias / 255,
316
+ min=0.0,
317
+ max=1.0)
318
+ rec_img = rec_img.squeeze(0)
319
+ edit_imgs, edit_img_masks = [], []
320
+ if edit_image is not None and edit_image[i] is not None:
321
+ if edit_image_mask[i] is None:
322
+ edit_image_mask[i] = [None] * len(edit_image[i])
323
+ for edit_img, edit_mask in zip(edit_image[i],
324
+ edit_image_mask[i]):
325
+ edit_img = torch.clamp((edit_img + 1.0) / 2.0,
326
+ min=0.0,
327
+ max=1.0)
328
+ edit_imgs.append(edit_img.squeeze(0))
329
+ if edit_mask is None:
330
+ edit_mask = torch.ones_like(edit_img[[0], :, :])
331
+ edit_img_masks.append(edit_mask)
332
+ one_tup = {
333
+ 'reconstruct_image': rec_img,
334
+ 'instruction': prompt[i],
335
+ 'edit_image': edit_imgs if len(edit_imgs) > 0 else None,
336
+ 'edit_mask': edit_img_masks if len(edit_imgs) > 0 else None
337
+ }
338
+ if image is not None:
339
+ if image_mask is None:
340
+ image_mask = [None] * len(image)
341
+ ori_img = torch.clamp((image[i] + 1.0) / 2.0, min=0.0, max=1.0)
342
+ one_tup['target_image'] = ori_img.squeeze(0)
343
+ one_tup['target_mask'] = image_mask[i] if image_mask[
344
+ i] is not None else torch.ones_like(ori_img[[0], :, :])
345
+ outputs.append(one_tup)
346
+ return outputs
347
+
348
+ @staticmethod
349
+ def get_config_template():
350
+ return dict_to_yaml('MODEL',
351
+ __class__.__name__,
352
+ LdmACE.para_dict,
353
+ set_name=True)
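`forward_test` passes either a single kwargs dict or a [conditional, unconditional] pair to `diffusion.sample`, depending on `guide_scale`; `ACEDiffusion.sample` then indexes `model_kwargs[0]` / `model_kwargs[1]` in the guided branch. A small sketch of that layout (the helper name `build_model_kwargs` is ours, not part of the module):

def build_model_kwargs(context, null_context, cont_mask, null_cont_mask,
                       pos_emb=None, guide_scale=4.5):
    # keys mirror the `context` dicts assembled in forward_test above
    cond = {'cond': context, 'mask': cont_mask,
            'text_position_embeddings': pos_emb}
    if guide_scale is not None and guide_scale > 1:
        uncond = {'cond': null_context, 'mask': null_cont_mask,
                  'text_position_embeddings': pos_emb}
        return [cond, uncond]   # model_kwargs[0] = conditional, [1] = unconditional
    return cond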
modules/model/utils/basic_utils.py ADDED
@@ -0,0 +1,104 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from inspect import isfunction
4
+
5
+ import torch
6
+ from torch.nn.utils.rnn import pad_sequence
7
+
8
+ from scepter.modules.utils.distribute import we
9
+
10
+
11
+ def exists(x):
12
+ return x is not None
13
+
14
+
15
+ def default(val, d):
16
+ if exists(val):
17
+ return val
18
+ return d() if isfunction(d) else d
19
+
20
+
21
+ def disabled_train(self, mode=True):
22
+ """Overwrite model.train with this function to make sure train/eval mode
23
+ does not change anymore."""
24
+ return self
25
+
26
+
27
+ def transfer_size(para_num):
28
+ if para_num > 1000 * 1000 * 1000 * 1000:
29
+ bill = para_num / (1000 * 1000 * 1000 * 1000)
30
+ return '{:.2f}T'.format(bill)
31
+ elif para_num > 1000 * 1000 * 1000:
32
+ gyte = para_num / (1000 * 1000 * 1000)
33
+ return '{:.2f}B'.format(gyte)
34
+ elif para_num > (1000 * 1000):
35
+ meta = para_num / (1000 * 1000)
36
+ return '{:.2f}M'.format(meta)
37
+ elif para_num > 1000:
38
+ kelo = para_num / 1000
39
+ return '{:.2f}K'.format(kelo)
40
+ else:
41
+ return para_num
42
+
43
+
44
+ def count_params(model):
45
+ total_params = sum(p.numel() for p in model.parameters())
46
+ return transfer_size(total_params)
47
+
48
+
49
+ def expand_dims_like(x, y):
50
+ while x.dim() != y.dim():
51
+ x = x.unsqueeze(-1)
52
+ return x
53
+
54
+
55
+ def unpack_tensor_into_imagelist(image_tensor, shapes):
56
+ image_list = []
57
+ for img, shape in zip(image_tensor, shapes):
58
+ h, w = shape[0], shape[1]
59
+ image_list.append(img[:, :h * w].view(1, -1, h, w))
60
+
61
+ return image_list
62
+
63
+
64
+ def find_example(tensor_list, image_list):
65
+ for i in tensor_list:
66
+ if isinstance(i, torch.Tensor):
67
+ return torch.zeros_like(i)
68
+ for i in image_list:
69
+ if isinstance(i, torch.Tensor):
70
+ _, c, h, w = i.size()
71
+ return torch.zeros_like(i.view(c, h * w).transpose(1, 0))
72
+ return None
73
+
74
+
75
+ def pack_imagelist_into_tensor_v2(image_list):
76
+ # allow None
77
+ example = None
78
+ image_tensor, shapes = [], []
79
+ for img in image_list:
80
+ if img is None:
81
+ example = find_example(image_tensor,
82
+ image_list) if example is None else example
83
+ image_tensor.append(example)
84
+ shapes.append(None)
85
+ continue
86
+ _, c, h, w = img.size()
87
+ image_tensor.append(img.view(c, h * w).transpose(1, 0)) # h*w, c
88
+ shapes.append((h, w))
89
+
90
+ image_tensor = pad_sequence(image_tensor,
91
+ batch_first=True).permute(0, 2, 1) # b, c, l
92
+ return image_tensor, shapes
93
+
94
+
95
+ def to_device(inputs, strict=True):
96
+ if inputs is None:
97
+ return None
98
+ if strict:
99
+ assert all(isinstance(i, torch.Tensor) for i in inputs)
100
+ return [i.to(we.device_id) if i is not None else None for i in inputs]
101
+
102
+
103
+ def check_list_of_list(ll):
104
+ return isinstance(ll, list) and all(isinstance(i, list) for i in ll)
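`pack_imagelist_into_tensor_v2` flattens variable-sized latents to (h*w, c) rows and pads them into a single (b, c, L) tensor; `unpack_tensor_into_imagelist` reverses this from the stored shapes. A round-trip sketch (the packing is re-implemented inline so the snippet is self-contained):

# Round-trip check for the pack/unpack helpers above.
import torch
from torch.nn.utils.rnn import pad_sequence

def pack(image_list):
    image_tensor, shapes = [], []
    for img in image_list:                       # each img: (1, c, h, w)
        _, c, h, w = img.size()
        image_tensor.append(img.view(c, h * w).transpose(1, 0))
        shapes.append((h, w))
    return pad_sequence(image_tensor, batch_first=True).permute(0, 2, 1), shapes

imgs = [torch.randn(1, 4, 8, 8), torch.randn(1, 4, 4, 8)]
packed, shapes = pack(imgs)
print(packed.shape, shapes)                      # torch.Size([2, 4, 64]) [(8, 8), (4, 8)]
# unpack exactly as unpack_tensor_into_imagelist does
restored = [p[:, :h * w].view(1, -1, h, w) for p, (h, w) in zip(packed, shapes)]
print(torch.allclose(restored[0], imgs[0]))      # True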
modules/solver/__init__.py ADDED
@@ -0,0 +1 @@
+ from .ace_solver import ACESolverV1
modules/solver/ace_solver.py ADDED
@@ -0,0 +1,146 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import numpy as np
4
+ import torch
5
+ from tqdm import tqdm
6
+
7
+ from scepter.modules.utils.data import transfer_data_to_cuda
8
+ from scepter.modules.utils.distribute import we
9
+ from scepter.modules.utils.probe import ProbeData
10
+ from scepter.modules.solver.registry import SOLVERS
11
+ from scepter.modules.solver.diffusion_solver import LatentDiffusionSolver
12
+
13
+
14
+
15
+ @SOLVERS.register_class()
16
+ class ACESolverV1(LatentDiffusionSolver):
17
+ def __init__(self, cfg, logger=None):
18
+ super().__init__(cfg, logger=logger)
19
+ self.log_train_num = cfg.get('LOG_TRAIN_NUM', -1)
20
+
21
+ def save_results(self, results):
22
+ log_data, log_label = [], []
23
+ for result in results:
24
+ ret_images, ret_labels = [], []
25
+ edit_image = result.get('edit_image', None)
26
+ edit_mask = result.get('edit_mask', None)
27
+ if edit_image is not None:
28
+ for i, edit_img in enumerate(result['edit_image']):
29
+ if edit_img is None:
30
+ continue
31
+ ret_images.append(
32
+ (edit_img.permute(1, 2, 0).cpu().numpy() * 255).astype(
33
+ np.uint8))
34
+ ret_labels.append(f'edit_image{i}; ')
35
+ if edit_mask is not None:
36
+ ret_images.append(
37
+ (edit_mask[i].permute(1, 2, 0).cpu().numpy() *
38
+ 255).astype(np.uint8))
39
+ ret_labels.append(f'edit_mask{i}; ')
40
+
41
+ target_image = result.get('target_image', None)
42
+ target_mask = result.get('target_mask', None)
43
+ if target_image is not None:
44
+ ret_images.append(
45
+ (target_image.permute(1, 2, 0).cpu().numpy() * 255).astype(
46
+ np.uint8))
47
+ ret_labels.append('target_image; ')
48
+ if target_mask is not None:
49
+ ret_images.append(
50
+ (target_mask.permute(1, 2, 0).cpu().numpy() *
51
+ 255).astype(np.uint8))
52
+ ret_labels.append('target_mask; ')
53
+
54
+ reconstruct_image = result.get('reconstruct_image', None)
55
+ if reconstruct_image is not None:
56
+ ret_images.append(
57
+ (reconstruct_image.permute(1, 2, 0).cpu().numpy() *
58
+ 255).astype(np.uint8))
59
+ ret_labels.append(f"{result['instruction']}")
60
+ log_data.append(ret_images)
61
+ log_label.append(ret_labels)
62
+ return log_data, log_label
63
+
64
+ @torch.no_grad()
65
+ def run_eval(self):
66
+ self.eval_mode()
67
+ self.before_all_iter(self.hooks_dict[self._mode])
68
+ all_results = []
69
+ for batch_idx, batch_data in tqdm(
70
+ enumerate(self.datas[self._mode].dataloader)):
71
+ self.before_iter(self.hooks_dict[self._mode])
72
+ if self.sample_args:
73
+ batch_data.update(self.sample_args.get_lowercase_dict())
74
+ with torch.autocast(device_type='cuda',
75
+ enabled=self.use_amp,
76
+ dtype=self.dtype):
77
+ results = self.run_step_eval(transfer_data_to_cuda(batch_data),
78
+ batch_idx,
79
+ step=self.total_iter,
80
+ rank=we.rank)
81
+ all_results.extend(results)
82
+ self.after_iter(self.hooks_dict[self._mode])
83
+ log_data, log_label = self.save_results(all_results)
84
+ self.register_probe({'eval_label': log_label})
85
+ self.register_probe({
86
+ 'eval_image':
87
+ ProbeData(log_data,
88
+ is_image=True,
89
+ build_html=True,
90
+ build_label=log_label)
91
+ })
92
+ self.after_all_iter(self.hooks_dict[self._mode])
93
+
94
+ @torch.no_grad()
95
+ def run_test(self):
96
+ self.test_mode()
97
+ self.before_all_iter(self.hooks_dict[self._mode])
98
+ all_results = []
99
+ for batch_idx, batch_data in tqdm(
100
+ enumerate(self.datas[self._mode].dataloader)):
101
+ self.before_iter(self.hooks_dict[self._mode])
102
+ if self.sample_args:
103
+ batch_data.update(self.sample_args.get_lowercase_dict())
104
+ with torch.autocast(device_type='cuda',
105
+ enabled=self.use_amp,
106
+ dtype=self.dtype):
107
+ results = self.run_step_eval(transfer_data_to_cuda(batch_data),
108
+ batch_idx,
109
+ step=self.total_iter,
110
+ rank=we.rank)
111
+ all_results.extend(results)
112
+ self.after_iter(self.hooks_dict[self._mode])
113
+ log_data, log_label = self.save_results(all_results)
114
+ self.register_probe({'test_label': log_label})
115
+ self.register_probe({
116
+ 'test_image':
117
+ ProbeData(log_data,
118
+ is_image=True,
119
+ build_html=True,
120
+ build_label=log_label)
121
+ })
122
+
123
+ self.after_all_iter(self.hooks_dict[self._mode])
124
+
125
+ @property
126
+ def probe_data(self):
127
+ if not we.debug and self.mode == 'train':
128
+ batch_data = transfer_data_to_cuda(
129
+ self.current_batch_data[self.mode])
130
+ self.eval_mode()
131
+ with torch.autocast(device_type='cuda',
132
+ enabled=self.use_amp,
133
+ dtype=self.dtype):
134
+ batch_data['log_num'] = self.log_train_num
135
+ results = self.run_step_eval(batch_data)
136
+ self.train_mode()
137
+ log_data, log_label = self.save_results(results)
138
+ self.register_probe({
139
+ 'train_image':
140
+ ProbeData(log_data,
141
+ is_image=True,
142
+ build_html=True,
143
+ build_label=log_label)
144
+ })
145
+ self.register_probe({'train_label': log_label})
146
+ return super(LatentDiffusionSolver, self).probe_data
requirements.txt CHANGED
@@ -1 +1,2 @@
- huggingface_hub==0.25.2
+ huggingface_hub==0.25.2
+ scepter>=1.2.0
utils.py ADDED
@@ -0,0 +1,95 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import torch
4
+ import torchvision.transforms as T
5
+ from PIL import Image
6
+ from torchvision.transforms.functional import InterpolationMode
7
+
8
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
9
+ IMAGENET_STD = (0.229, 0.224, 0.225)
10
+
11
+
12
+ def build_transform(input_size):
13
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
14
+ transform = T.Compose([
15
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
16
+ T.Resize((input_size, input_size),
17
+ interpolation=InterpolationMode.BICUBIC),
18
+ T.ToTensor(),
19
+ T.Normalize(mean=MEAN, std=STD)
20
+ ])
21
+ return transform
22
+
23
+
24
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
25
+ image_size):
26
+ best_ratio_diff = float('inf')
27
+ best_ratio = (1, 1)
28
+ area = width * height
29
+ for ratio in target_ratios:
30
+ target_aspect_ratio = ratio[0] / ratio[1]
31
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
32
+ if ratio_diff < best_ratio_diff:
33
+ best_ratio_diff = ratio_diff
34
+ best_ratio = ratio
35
+ elif ratio_diff == best_ratio_diff:
36
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
37
+ best_ratio = ratio
38
+ return best_ratio
39
+
40
+
41
+ def dynamic_preprocess(image,
42
+ min_num=1,
43
+ max_num=12,
44
+ image_size=448,
45
+ use_thumbnail=False):
46
+ orig_width, orig_height = image.size
47
+ aspect_ratio = orig_width / orig_height
48
+
49
+ # enumerate candidate tiling grids (i, j) with min_num <= i * j <= max_num
50
+ target_ratios = set((i, j) for n in range(min_num, max_num + 1)
51
+ for i in range(1, n + 1) for j in range(1, n + 1)
52
+ if i * j <= max_num and i * j >= min_num)
53
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
54
+
55
+ # find the closest aspect ratio to the target
56
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
57
+ target_ratios, orig_width,
58
+ orig_height, image_size)
59
+
60
+ # calculate the target width and height
61
+ target_width = image_size * target_aspect_ratio[0]
62
+ target_height = image_size * target_aspect_ratio[1]
63
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
64
+
65
+ # resize the image
66
+ resized_img = image.resize((target_width, target_height))
67
+ processed_images = []
68
+ for i in range(blocks):
69
+ box = ((i % (target_width // image_size)) * image_size,
70
+ (i // (target_width // image_size)) * image_size,
71
+ ((i % (target_width // image_size)) + 1) * image_size,
72
+ ((i // (target_width // image_size)) + 1) * image_size)
73
+ # split the image
74
+ split_img = resized_img.crop(box)
75
+ processed_images.append(split_img)
76
+ assert len(processed_images) == blocks
77
+ if use_thumbnail and len(processed_images) != 1:
78
+ thumbnail_img = image.resize((image_size, image_size))
79
+ processed_images.append(thumbnail_img)
80
+ return processed_images
81
+
82
+
83
+ def load_image(image_file, input_size=448, max_num=12):
84
+ if isinstance(image_file, str):
85
+ image = Image.open(image_file).convert('RGB')
86
+ else:
87
+ image = image_file
88
+ transform = build_transform(input_size=input_size)
89
+ images = dynamic_preprocess(image,
90
+ image_size=input_size,
91
+ use_thumbnail=True,
92
+ max_num=max_num)
93
+ pixel_values = [transform(image) for image in images]
94
+ pixel_values = torch.stack(pixel_values)
95
+ return pixel_values
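`load_image` ties the helpers together: `dynamic_preprocess` tiles the input into up to `max_num` 448x448 crops plus an optional thumbnail, and `build_transform` applies ImageNet normalization (the same tiling recipe used by InternVL-style vision-language models). A usage sketch, assuming the function is imported from this utils.py:

# Build tiled, normalized pixel values from a stand-in PIL image.
from PIL import Image
from utils import load_image

img = Image.new('RGB', (1024, 640), color=(128, 128, 128))   # stand-in image
pixel_values = load_image(img, input_size=448, max_num=12)
print(pixel_values.shape)   # e.g. torch.Size([7, 3, 448, 448]): 6 tiles + thumbnail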