Spaces:
Running
on
Zero
Running
on
Zero
update file
Browse files- __init__.py +1 -0
- app.py +1197 -62
- config/chatbot_ui.yaml +25 -0
- config/models/ace_0.6b_512.yaml +127 -0
- example.py +339 -0
- infer.py +378 -0
- modules/__init__.py +1 -0
- modules/data/__init__.py +1 -0
- modules/data/dataset/__init__.py +1 -0
- modules/data/dataset/dataset.py +252 -0
- modules/model/__init__.py +1 -0
- modules/model/backbone/__init__.py +3 -0
- modules/model/backbone/ace.py +373 -0
- modules/model/backbone/layers.py +386 -0
- modules/model/backbone/pos_embed.py +85 -0
- modules/model/diffusion/__init__.py +6 -0
- modules/model/diffusion/diffusions.py +206 -0
- modules/model/diffusion/samplers.py +69 -0
- modules/model/diffusion/schedules.py +30 -0
- modules/model/embedder/__init__.py +1 -0
- modules/model/embedder/embedder.py +184 -0
- modules/model/network/__init__.py +1 -0
- modules/model/network/ldm_ace.py +353 -0
- modules/model/utils/basic_utils.py +104 -0
- modules/solver/__init__.py +1 -0
- modules/solver/ace_solver.py +146 -0
- requirements.txt +2 -1
- utils.py +95 -0
__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import modules
|
app.py
CHANGED
@@ -1,64 +1,1199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
demo.launch()
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import argparse
|
4 |
+
import base64
|
5 |
+
import copy
|
6 |
+
import glob
|
7 |
+
import io
|
8 |
+
import os
|
9 |
+
import random
|
10 |
+
import re
|
11 |
+
import string
|
12 |
+
import threading
|
13 |
+
|
14 |
+
import cv2
|
15 |
import gradio as gr
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
import transformers
|
19 |
+
from diffusers import CogVideoXImageToVideoPipeline
|
20 |
+
from diffusers.utils import export_to_video
|
21 |
+
from gradio_imageslider import ImageSlider
|
22 |
+
from PIL import Image
|
23 |
+
from transformers import AutoModel, AutoTokenizer
|
24 |
+
|
25 |
+
from scepter.modules.utils.config import Config
|
26 |
+
from scepter.modules.utils.directory import get_md5
|
27 |
+
from scepter.modules.utils.file_system import FS
|
28 |
+
from scepter.studio.utils.env import init_env
|
29 |
+
|
30 |
+
from .infer import ACEInference
|
31 |
+
from .example import get_examples
|
32 |
+
from .utils import load_image
|
33 |
+
|
34 |
+
|
35 |
+
refresh_sty = '\U0001f504' # 🔄
|
36 |
+
clear_sty = '\U0001f5d1' # 🗑️
|
37 |
+
upload_sty = '\U0001f5bc' # 🖼️
|
38 |
+
sync_sty = '\U0001f4be' # 💾
|
39 |
+
chat_sty = '\U0001F4AC' # 💬
|
40 |
+
video_sty = '\U0001f3a5' # 🎥
|
41 |
+
|
42 |
+
lock = threading.Lock()
|
43 |
+
|
44 |
+
|
45 |
+
class ChatBotUI(object):
|
46 |
+
def __init__(self,
|
47 |
+
cfg_general_file,
|
48 |
+
root_work_dir='./'):
|
49 |
+
|
50 |
+
cfg = Config(cfg_file=cfg_general_file)
|
51 |
+
cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
|
52 |
+
if not FS.exists(cfg.WORK_DIR):
|
53 |
+
FS.make_dir(cfg.WORK_DIR)
|
54 |
+
cfg = init_env(cfg)
|
55 |
+
self.cache_dir = cfg.WORK_DIR
|
56 |
+
self.chatbot_examples = get_examples(self.cache_dir)
|
57 |
+
self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR
|
58 |
+
self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,
|
59 |
+
'*.yaml'))
|
60 |
+
self.model_choices = dict()
|
61 |
+
for i in self.model_yamls:
|
62 |
+
model_name = '.'.join(i.split('/')[-1].split('.')[:-1])
|
63 |
+
self.model_choices[model_name] = i
|
64 |
+
print('Models: ', self.model_choices)
|
65 |
+
|
66 |
+
self.model_name = cfg.MODEL.EDIT_MODEL.DEFAULT
|
67 |
+
assert self.model_name in self.model_choices
|
68 |
+
model_cfg = Config(load=True,
|
69 |
+
cfg_file=self.model_choices[self.model_name])
|
70 |
+
self.pipe = ACEInference()
|
71 |
+
self.pipe.init_from_cfg(model_cfg)
|
72 |
+
self.retry_msg = ''
|
73 |
+
self.max_msgs = 20
|
74 |
+
|
75 |
+
self.enable_i2v = cfg.get('ENABLE_I2V', False)
|
76 |
+
if self.enable_i2v:
|
77 |
+
self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR
|
78 |
+
self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME
|
79 |
+
if self.i2v_model_name == 'CogVideoX-5b-I2V':
|
80 |
+
with FS.get_dir_to_local_dir(self.i2v_model_dir) as local_dir:
|
81 |
+
self.i2v_pipe = CogVideoXImageToVideoPipeline.from_pretrained(
|
82 |
+
local_dir, torch_dtype=torch.bfloat16).cuda()
|
83 |
+
else:
|
84 |
+
raise NotImplementedError
|
85 |
+
|
86 |
+
with FS.get_dir_to_local_dir(
|
87 |
+
cfg.MODEL.CAPTIONER.MODEL_DIR) as local_dir:
|
88 |
+
self.captioner = AutoModel.from_pretrained(
|
89 |
+
local_dir,
|
90 |
+
torch_dtype=torch.bfloat16,
|
91 |
+
low_cpu_mem_usage=True,
|
92 |
+
use_flash_attn=True,
|
93 |
+
trust_remote_code=True).eval().cuda()
|
94 |
+
self.llm_tokenizer = AutoTokenizer.from_pretrained(
|
95 |
+
local_dir, trust_remote_code=True, use_fast=False)
|
96 |
+
self.llm_generation_config = dict(max_new_tokens=1024,
|
97 |
+
do_sample=True)
|
98 |
+
self.llm_prompt = cfg.LLM.PROMPT
|
99 |
+
self.llm_max_num = 2
|
100 |
+
|
101 |
+
with FS.get_dir_to_local_dir(
|
102 |
+
cfg.MODEL.ENHANCER.MODEL_DIR) as local_dir:
|
103 |
+
self.enhancer = transformers.pipeline(
|
104 |
+
'text-generation',
|
105 |
+
model=local_dir,
|
106 |
+
model_kwargs={'torch_dtype': torch.bfloat16},
|
107 |
+
device_map='auto',
|
108 |
+
)
|
109 |
+
|
110 |
+
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
111 |
+
|
112 |
+
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
113 |
+
There are a few rules to follow:
|
114 |
+
|
115 |
+
You will only ever output a single video description per user request.
|
116 |
+
|
117 |
+
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
|
118 |
+
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
|
119 |
+
|
120 |
+
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
|
121 |
+
"""
|
122 |
+
self.enhance_ctx = [
|
123 |
+
{
|
124 |
+
'role': 'system',
|
125 |
+
'content': sys_prompt
|
126 |
+
},
|
127 |
+
{
|
128 |
+
'role':
|
129 |
+
'user',
|
130 |
+
'content':
|
131 |
+
'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
|
132 |
+
},
|
133 |
+
{
|
134 |
+
'role':
|
135 |
+
'assistant',
|
136 |
+
'content':
|
137 |
+
"A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
|
138 |
+
},
|
139 |
+
{
|
140 |
+
'role':
|
141 |
+
'user',
|
142 |
+
'content':
|
143 |
+
'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
|
144 |
+
},
|
145 |
+
{
|
146 |
+
'role':
|
147 |
+
'assistant',
|
148 |
+
'content':
|
149 |
+
"A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
|
150 |
+
},
|
151 |
+
{
|
152 |
+
'role':
|
153 |
+
'user',
|
154 |
+
'content':
|
155 |
+
'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
|
156 |
+
},
|
157 |
+
{
|
158 |
+
'role':
|
159 |
+
'assistant',
|
160 |
+
'content':
|
161 |
+
'A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.',
|
162 |
+
},
|
163 |
+
]
|
164 |
+
|
165 |
+
def create_ui(self):
|
166 |
+
css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'
|
167 |
+
with gr.Blocks(css=css,
|
168 |
+
title='Chatbot',
|
169 |
+
head='Chatbot',
|
170 |
+
analytics_enabled=False):
|
171 |
+
self.history = gr.State(value=[])
|
172 |
+
self.images = gr.State(value={})
|
173 |
+
self.history_result = gr.State(value={})
|
174 |
+
with gr.Group():
|
175 |
+
with gr.Row(equal_height=True):
|
176 |
+
with gr.Column(visible=True) as self.chat_page:
|
177 |
+
self.chatbot = gr.Chatbot(
|
178 |
+
height=600,
|
179 |
+
value=[],
|
180 |
+
bubble_full_width=False,
|
181 |
+
show_copy_button=True,
|
182 |
+
container=False,
|
183 |
+
placeholder='<strong>Chat Box</strong>')
|
184 |
+
with gr.Row():
|
185 |
+
self.clear_btn = gr.Button(clear_sty +
|
186 |
+
' Clear Chat',
|
187 |
+
size='sm')
|
188 |
+
|
189 |
+
with gr.Column(visible=False) as self.editor_page:
|
190 |
+
with gr.Tabs():
|
191 |
+
with gr.Tab(id='ImageUploader',
|
192 |
+
label='Image Uploader',
|
193 |
+
visible=True) as self.upload_tab:
|
194 |
+
self.image_uploader = gr.Image(
|
195 |
+
height=550,
|
196 |
+
interactive=True,
|
197 |
+
type='pil',
|
198 |
+
image_mode='RGB',
|
199 |
+
sources='upload',
|
200 |
+
elem_id='image_uploader',
|
201 |
+
format='png')
|
202 |
+
with gr.Row():
|
203 |
+
self.sub_btn_1 = gr.Button(
|
204 |
+
value='Submit',
|
205 |
+
elem_id='upload_submit')
|
206 |
+
self.ext_btn_1 = gr.Button(value='Exit')
|
207 |
+
|
208 |
+
with gr.Tab(id='ImageEditor',
|
209 |
+
label='Image Editor',
|
210 |
+
visible=False) as self.edit_tab:
|
211 |
+
self.mask_type = gr.Dropdown(
|
212 |
+
label='Mask Type',
|
213 |
+
choices=[
|
214 |
+
'Background', 'Composite',
|
215 |
+
'Outpainting'
|
216 |
+
],
|
217 |
+
value='Background')
|
218 |
+
self.mask_type_info = gr.HTML(
|
219 |
+
value=
|
220 |
+
"<div style='background-color: white; padding-left: 15px; color: grey;'>Background mode will not erase the visual content in the mask area</div>"
|
221 |
+
)
|
222 |
+
with gr.Accordion(
|
223 |
+
label='Outpainting Setting',
|
224 |
+
open=True,
|
225 |
+
visible=False) as self.outpaint_tab:
|
226 |
+
with gr.Row(variant='panel'):
|
227 |
+
self.top_ext = gr.Slider(
|
228 |
+
show_label=True,
|
229 |
+
label='Top Extend Ratio',
|
230 |
+
minimum=0.0,
|
231 |
+
maximum=2.0,
|
232 |
+
step=0.1,
|
233 |
+
value=0.25)
|
234 |
+
self.bottom_ext = gr.Slider(
|
235 |
+
show_label=True,
|
236 |
+
label='Bottom Extend Ratio',
|
237 |
+
minimum=0.0,
|
238 |
+
maximum=2.0,
|
239 |
+
step=0.1,
|
240 |
+
value=0.25)
|
241 |
+
with gr.Row(variant='panel'):
|
242 |
+
self.left_ext = gr.Slider(
|
243 |
+
show_label=True,
|
244 |
+
label='Left Extend Ratio',
|
245 |
+
minimum=0.0,
|
246 |
+
maximum=2.0,
|
247 |
+
step=0.1,
|
248 |
+
value=0.25)
|
249 |
+
self.right_ext = gr.Slider(
|
250 |
+
show_label=True,
|
251 |
+
label='Right Extend Ratio',
|
252 |
+
minimum=0.0,
|
253 |
+
maximum=2.0,
|
254 |
+
step=0.1,
|
255 |
+
value=0.25)
|
256 |
+
with gr.Row(variant='panel'):
|
257 |
+
self.img_pad_btn = gr.Button(
|
258 |
+
value='Pad Image')
|
259 |
+
|
260 |
+
self.image_editor = gr.ImageMask(
|
261 |
+
value=None,
|
262 |
+
sources=[],
|
263 |
+
layers=False,
|
264 |
+
label='Edit Image',
|
265 |
+
elem_id='image_editor',
|
266 |
+
format='png')
|
267 |
+
with gr.Row():
|
268 |
+
self.sub_btn_2 = gr.Button(
|
269 |
+
value='Submit', elem_id='edit_submit')
|
270 |
+
self.ext_btn_2 = gr.Button(value='Exit')
|
271 |
+
|
272 |
+
with gr.Tab(id='ImageViewer',
|
273 |
+
label='Image Viewer',
|
274 |
+
visible=False) as self.image_view_tab:
|
275 |
+
self.image_viewer = ImageSlider(
|
276 |
+
label='Image',
|
277 |
+
type='pil',
|
278 |
+
show_download_button=True,
|
279 |
+
elem_id='image_viewer')
|
280 |
+
|
281 |
+
self.ext_btn_3 = gr.Button(value='Exit')
|
282 |
+
|
283 |
+
with gr.Tab(id='VideoViewer',
|
284 |
+
label='Video Viewer',
|
285 |
+
visible=False) as self.video_view_tab:
|
286 |
+
self.video_viewer = gr.Video(
|
287 |
+
label='Video',
|
288 |
+
interactive=False,
|
289 |
+
sources=[],
|
290 |
+
format='mp4',
|
291 |
+
show_download_button=True,
|
292 |
+
elem_id='video_viewer',
|
293 |
+
loop=True,
|
294 |
+
autoplay=True)
|
295 |
+
|
296 |
+
self.ext_btn_4 = gr.Button(value='Exit')
|
297 |
+
|
298 |
+
with gr.Accordion(label='Setting', open=False):
|
299 |
+
with gr.Row():
|
300 |
+
self.model_name_dd = gr.Dropdown(
|
301 |
+
choices=self.model_choices,
|
302 |
+
value=self.model_name,
|
303 |
+
label='Model Version')
|
304 |
+
|
305 |
+
with gr.Row():
|
306 |
+
self.negative_prompt = gr.Textbox(
|
307 |
+
value='',
|
308 |
+
placeholder=
|
309 |
+
'Negative prompt used for Classifier-Free Guidance',
|
310 |
+
label='Negative Prompt',
|
311 |
+
container=False)
|
312 |
+
|
313 |
+
with gr.Row():
|
314 |
+
with gr.Column(scale=8, min_width=500):
|
315 |
+
with gr.Row():
|
316 |
+
self.step = gr.Slider(minimum=1,
|
317 |
+
maximum=1000,
|
318 |
+
value=20,
|
319 |
+
label='Sample Step')
|
320 |
+
self.cfg_scale = gr.Slider(
|
321 |
+
minimum=1.0,
|
322 |
+
maximum=20.0,
|
323 |
+
value=4.5,
|
324 |
+
label='Guidance Scale')
|
325 |
+
self.rescale = gr.Slider(minimum=0.0,
|
326 |
+
maximum=1.0,
|
327 |
+
value=0.5,
|
328 |
+
label='Rescale')
|
329 |
+
self.seed = gr.Slider(minimum=-1,
|
330 |
+
maximum=10000000,
|
331 |
+
value=-1,
|
332 |
+
label='Seed')
|
333 |
+
self.output_height = gr.Slider(
|
334 |
+
minimum=256,
|
335 |
+
maximum=1024,
|
336 |
+
value=512,
|
337 |
+
label='Output Height')
|
338 |
+
self.output_width = gr.Slider(
|
339 |
+
minimum=256,
|
340 |
+
maximum=1024,
|
341 |
+
value=512,
|
342 |
+
label='Output Width')
|
343 |
+
with gr.Column(scale=1, min_width=50):
|
344 |
+
self.use_history = gr.Checkbox(value=False,
|
345 |
+
label='Use History')
|
346 |
+
self.video_auto = gr.Checkbox(
|
347 |
+
value=False,
|
348 |
+
label='Auto Gen Video',
|
349 |
+
visible=self.enable_i2v)
|
350 |
+
|
351 |
+
with gr.Row(variant='panel',
|
352 |
+
equal_height=True,
|
353 |
+
visible=self.enable_i2v):
|
354 |
+
self.video_fps = gr.Slider(minimum=1,
|
355 |
+
maximum=16,
|
356 |
+
value=8,
|
357 |
+
label='Video FPS',
|
358 |
+
visible=True)
|
359 |
+
self.video_frames = gr.Slider(minimum=8,
|
360 |
+
maximum=49,
|
361 |
+
value=49,
|
362 |
+
label='Video Frame Num',
|
363 |
+
visible=True)
|
364 |
+
self.video_step = gr.Slider(minimum=1,
|
365 |
+
maximum=1000,
|
366 |
+
value=50,
|
367 |
+
label='Video Sample Step',
|
368 |
+
visible=True)
|
369 |
+
self.video_cfg_scale = gr.Slider(
|
370 |
+
minimum=1.0,
|
371 |
+
maximum=20.0,
|
372 |
+
value=6.0,
|
373 |
+
label='Video Guidance Scale',
|
374 |
+
visible=True)
|
375 |
+
self.video_seed = gr.Slider(minimum=-1,
|
376 |
+
maximum=10000000,
|
377 |
+
value=-1,
|
378 |
+
label='Video Seed',
|
379 |
+
visible=True)
|
380 |
+
|
381 |
+
with gr.Row(variant='panel',
|
382 |
+
equal_height=True,
|
383 |
+
show_progress=False):
|
384 |
+
with gr.Column(scale=1, min_width=100):
|
385 |
+
self.upload_btn = gr.Button(value=upload_sty +
|
386 |
+
' Upload',
|
387 |
+
variant='secondary')
|
388 |
+
with gr.Column(scale=5, min_width=500):
|
389 |
+
self.text = gr.Textbox(
|
390 |
+
placeholder='Input "@" find history of image',
|
391 |
+
label='Instruction',
|
392 |
+
container=False)
|
393 |
+
with gr.Column(scale=1, min_width=100):
|
394 |
+
self.chat_btn = gr.Button(value=chat_sty + ' Chat',
|
395 |
+
variant='primary')
|
396 |
+
with gr.Column(scale=1, min_width=100):
|
397 |
+
self.retry_btn = gr.Button(value=refresh_sty +
|
398 |
+
' Retry',
|
399 |
+
variant='secondary')
|
400 |
+
with gr.Column(scale=(1 if self.enable_i2v else 0),
|
401 |
+
min_width=0):
|
402 |
+
self.video_gen_btn = gr.Button(value=video_sty +
|
403 |
+
' Gen Video',
|
404 |
+
variant='secondary',
|
405 |
+
visible=self.enable_i2v)
|
406 |
+
with gr.Column(scale=(1 if self.enable_i2v else 0),
|
407 |
+
min_width=0):
|
408 |
+
self.extend_prompt = gr.Checkbox(
|
409 |
+
value=True,
|
410 |
+
label='Extend Prompt',
|
411 |
+
visible=self.enable_i2v)
|
412 |
+
|
413 |
+
with gr.Row():
|
414 |
+
self.gallery = gr.Gallery(visible=False,
|
415 |
+
label='History',
|
416 |
+
columns=10,
|
417 |
+
allow_preview=False,
|
418 |
+
interactive=False)
|
419 |
+
|
420 |
+
self.eg = gr.Column(visible=True)
|
421 |
+
|
422 |
+
def set_callbacks(self, *args, **kwargs):
|
423 |
+
|
424 |
+
########################################
|
425 |
+
def change_model(model_name):
|
426 |
+
if model_name not in self.model_choices:
|
427 |
+
gr.Info('The provided model name is not a valid choice!')
|
428 |
+
return model_name, gr.update(), gr.update()
|
429 |
+
|
430 |
+
if model_name != self.model_name:
|
431 |
+
lock.acquire()
|
432 |
+
del self.pipe
|
433 |
+
torch.cuda.empty_cache()
|
434 |
+
model_cfg = Config(load=True,
|
435 |
+
cfg_file=self.model_choices[model_name])
|
436 |
+
self.pipe = ACEInference()
|
437 |
+
self.pipe.init_from_cfg(model_cfg)
|
438 |
+
self.model_name = model_name
|
439 |
+
lock.release()
|
440 |
+
|
441 |
+
return model_name, gr.update(), gr.update()
|
442 |
+
|
443 |
+
self.model_name_dd.change(
|
444 |
+
change_model,
|
445 |
+
inputs=[self.model_name_dd],
|
446 |
+
outputs=[self.model_name_dd, self.chatbot, self.text])
|
447 |
+
|
448 |
+
########################################
|
449 |
+
def generate_gallery(text, images):
|
450 |
+
if text.endswith(' '):
|
451 |
+
return gr.update(), gr.update(visible=False)
|
452 |
+
elif text.endswith('@'):
|
453 |
+
gallery_info = []
|
454 |
+
for image_id, image_meta in images.items():
|
455 |
+
thumbnail_path = image_meta['thumbnail']
|
456 |
+
gallery_info.append((thumbnail_path, image_id))
|
457 |
+
return gr.update(), gr.update(visible=True, value=gallery_info)
|
458 |
+
else:
|
459 |
+
gallery_info = []
|
460 |
+
match = re.search('@([^@ ]+)$', text)
|
461 |
+
if match:
|
462 |
+
prefix = match.group(1)
|
463 |
+
for image_id, image_meta in images.items():
|
464 |
+
if not image_id.startswith(prefix):
|
465 |
+
continue
|
466 |
+
thumbnail_path = image_meta['thumbnail']
|
467 |
+
gallery_info.append((thumbnail_path, image_id))
|
468 |
+
|
469 |
+
if len(gallery_info) > 0:
|
470 |
+
return gr.update(), gr.update(visible=True,
|
471 |
+
value=gallery_info)
|
472 |
+
else:
|
473 |
+
return gr.update(), gr.update(visible=False)
|
474 |
+
else:
|
475 |
+
return gr.update(), gr.update(visible=False)
|
476 |
+
|
477 |
+
self.text.input(generate_gallery,
|
478 |
+
inputs=[self.text, self.images],
|
479 |
+
outputs=[self.text, self.gallery],
|
480 |
+
show_progress='hidden')
|
481 |
+
|
482 |
+
########################################
|
483 |
+
def select_image(text, evt: gr.SelectData):
|
484 |
+
image_id = evt.value['caption']
|
485 |
+
text = '@'.join(text.split('@')[:-1]) + f'@{image_id} '
|
486 |
+
return gr.update(value=text), gr.update(visible=False, value=None)
|
487 |
+
|
488 |
+
self.gallery.select(select_image,
|
489 |
+
inputs=self.text,
|
490 |
+
outputs=[self.text, self.gallery])
|
491 |
+
|
492 |
+
########################################
|
493 |
+
def generate_video(message,
|
494 |
+
extend_prompt,
|
495 |
+
history,
|
496 |
+
images,
|
497 |
+
num_steps,
|
498 |
+
num_frames,
|
499 |
+
cfg_scale,
|
500 |
+
fps,
|
501 |
+
seed,
|
502 |
+
progress=gr.Progress(track_tqdm=True)):
|
503 |
+
generator = torch.Generator(device='cuda').manual_seed(seed)
|
504 |
+
img_ids = re.findall('@(.*?)[ ,;.?$]', message)
|
505 |
+
if len(img_ids) == 0:
|
506 |
+
history.append((
|
507 |
+
message,
|
508 |
+
'Sorry, no images were found in the prompt to be used as the first frame of the video.'
|
509 |
+
))
|
510 |
+
while len(history) >= self.max_msgs:
|
511 |
+
history.pop(0)
|
512 |
+
return history, self.get_history(
|
513 |
+
history), gr.update(), gr.update(visible=False)
|
514 |
+
|
515 |
+
img_id = img_ids[0]
|
516 |
+
prompt = re.sub(f'@{img_id}\s+', '', message)
|
517 |
+
|
518 |
+
if extend_prompt:
|
519 |
+
messages = copy.deepcopy(self.enhance_ctx)
|
520 |
+
messages.append({
|
521 |
+
'role':
|
522 |
+
'user',
|
523 |
+
'content':
|
524 |
+
f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"',
|
525 |
+
})
|
526 |
+
lock.acquire()
|
527 |
+
outputs = self.enhancer(
|
528 |
+
messages,
|
529 |
+
max_new_tokens=200,
|
530 |
+
)
|
531 |
+
|
532 |
+
prompt = outputs[0]['generated_text'][-1]['content']
|
533 |
+
print(prompt)
|
534 |
+
lock.release()
|
535 |
+
|
536 |
+
img_meta = images[img_id]
|
537 |
+
img_path = img_meta['image']
|
538 |
+
image = Image.open(img_path).convert('RGB')
|
539 |
+
|
540 |
+
lock.acquire()
|
541 |
+
video = self.i2v_pipe(
|
542 |
+
prompt=prompt,
|
543 |
+
image=image,
|
544 |
+
num_videos_per_prompt=1,
|
545 |
+
num_inference_steps=num_steps,
|
546 |
+
num_frames=num_frames,
|
547 |
+
guidance_scale=cfg_scale,
|
548 |
+
generator=generator,
|
549 |
+
).frames[0]
|
550 |
+
lock.release()
|
551 |
+
|
552 |
+
out_video_path = export_to_video(video, fps=fps)
|
553 |
+
history.append((
|
554 |
+
f"Based on first frame @{img_id} and description '{prompt}', generate a video",
|
555 |
+
'This is generated video:'))
|
556 |
+
history.append((None, out_video_path))
|
557 |
+
while len(history) >= self.max_msgs:
|
558 |
+
history.pop(0)
|
559 |
+
|
560 |
+
return history, self.get_history(history), gr.update(
|
561 |
+
value=''), gr.update(visible=False)
|
562 |
+
|
563 |
+
self.video_gen_btn.click(
|
564 |
+
generate_video,
|
565 |
+
inputs=[
|
566 |
+
self.text, self.extend_prompt, self.history, self.images,
|
567 |
+
self.video_step, self.video_frames, self.video_cfg_scale,
|
568 |
+
self.video_fps, self.video_seed
|
569 |
+
],
|
570 |
+
outputs=[self.history, self.chatbot, self.text, self.gallery])
|
571 |
+
|
572 |
+
########################################
|
573 |
+
def run_chat(message,
|
574 |
+
extend_prompt,
|
575 |
+
history,
|
576 |
+
images,
|
577 |
+
use_history,
|
578 |
+
history_result,
|
579 |
+
negative_prompt,
|
580 |
+
cfg_scale,
|
581 |
+
rescale,
|
582 |
+
step,
|
583 |
+
seed,
|
584 |
+
output_h,
|
585 |
+
output_w,
|
586 |
+
video_auto,
|
587 |
+
video_steps,
|
588 |
+
video_frames,
|
589 |
+
video_cfg_scale,
|
590 |
+
video_fps,
|
591 |
+
video_seed,
|
592 |
+
progress=gr.Progress(track_tqdm=True)):
|
593 |
+
self.retry_msg = message
|
594 |
+
gen_id = get_md5(message)[:12]
|
595 |
+
save_path = os.path.join(self.cache_dir, f'{gen_id}.png')
|
596 |
+
|
597 |
+
img_ids = re.findall('@(.*?)[ ,;.?$]', message)
|
598 |
+
history_io = None
|
599 |
+
new_message = message
|
600 |
+
|
601 |
+
if len(img_ids) > 0:
|
602 |
+
edit_image, edit_image_mask, edit_task = [], [], []
|
603 |
+
for i, img_id in enumerate(img_ids):
|
604 |
+
if img_id not in images:
|
605 |
+
gr.Info(
|
606 |
+
f'The input image ID {img_id} is not exist... Skip loading image.'
|
607 |
+
)
|
608 |
+
continue
|
609 |
+
placeholder = '{image}' if i == 0 else '{' + f'image{i}' + '}'
|
610 |
+
new_message = re.sub(f'@{img_id}', placeholder,
|
611 |
+
new_message)
|
612 |
+
img_meta = images[img_id]
|
613 |
+
img_path = img_meta['image']
|
614 |
+
img_mask = img_meta['mask']
|
615 |
+
img_mask_type = img_meta['mask_type']
|
616 |
+
if img_mask_type is not None and img_mask_type == 'Composite':
|
617 |
+
task = 'inpainting'
|
618 |
+
else:
|
619 |
+
task = ''
|
620 |
+
edit_image.append(Image.open(img_path).convert('RGB'))
|
621 |
+
edit_image_mask.append(
|
622 |
+
Image.open(img_mask).
|
623 |
+
convert('L') if img_mask is not None else None)
|
624 |
+
edit_task.append(task)
|
625 |
+
|
626 |
+
if use_history and (img_id in history_result):
|
627 |
+
history_io = history_result[img_id]
|
628 |
+
|
629 |
+
buffered = io.BytesIO()
|
630 |
+
edit_image[0].save(buffered, format='PNG')
|
631 |
+
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
632 |
+
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
633 |
+
pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{img_ids[0]} is:\n {img_str}'
|
634 |
+
else:
|
635 |
+
pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
|
636 |
+
edit_image = None
|
637 |
+
edit_image_mask = None
|
638 |
+
edit_task = ''
|
639 |
+
|
640 |
+
print(new_message)
|
641 |
+
imgs = self.pipe(
|
642 |
+
input_image=edit_image,
|
643 |
+
input_mask=edit_image_mask,
|
644 |
+
task=edit_task,
|
645 |
+
prompt=[new_message] *
|
646 |
+
len(edit_image) if edit_image is not None else [new_message],
|
647 |
+
negative_prompt=[negative_prompt] * len(edit_image)
|
648 |
+
if edit_image is not None else [negative_prompt],
|
649 |
+
history_io=history_io,
|
650 |
+
output_height=output_h,
|
651 |
+
output_width=output_w,
|
652 |
+
sampler='ddim',
|
653 |
+
sample_steps=step,
|
654 |
+
guide_scale=cfg_scale,
|
655 |
+
guide_rescale=rescale,
|
656 |
+
seed=seed,
|
657 |
+
)
|
658 |
+
|
659 |
+
img = imgs[0]
|
660 |
+
img.save(save_path, format='PNG')
|
661 |
+
|
662 |
+
if history_io:
|
663 |
+
history_io_new = copy.deepcopy(history_io)
|
664 |
+
history_io_new['image'] += edit_image[:1]
|
665 |
+
history_io_new['mask'] += edit_image_mask[:1]
|
666 |
+
history_io_new['task'] += edit_task[:1]
|
667 |
+
history_io_new['prompt'] += [new_message]
|
668 |
+
history_io_new['image'] = history_io_new['image'][-5:]
|
669 |
+
history_io_new['mask'] = history_io_new['mask'][-5:]
|
670 |
+
history_io_new['task'] = history_io_new['task'][-5:]
|
671 |
+
history_io_new['prompt'] = history_io_new['prompt'][-5:]
|
672 |
+
history_result[gen_id] = history_io_new
|
673 |
+
elif edit_image is not None and len(edit_image) > 0:
|
674 |
+
history_io_new = {
|
675 |
+
'image': edit_image[:1],
|
676 |
+
'mask': edit_image_mask[:1],
|
677 |
+
'task': edit_task[:1],
|
678 |
+
'prompt': [new_message]
|
679 |
+
}
|
680 |
+
history_result[gen_id] = history_io_new
|
681 |
+
|
682 |
+
w, h = img.size
|
683 |
+
if w > h:
|
684 |
+
tb_w = 128
|
685 |
+
tb_h = int(h * tb_w / w)
|
686 |
+
else:
|
687 |
+
tb_h = 128
|
688 |
+
tb_w = int(w * tb_h / h)
|
689 |
+
|
690 |
+
thumbnail_path = os.path.join(self.cache_dir,
|
691 |
+
f'{gen_id}_thumbnail.jpg')
|
692 |
+
thumbnail = img.resize((tb_w, tb_h))
|
693 |
+
thumbnail.save(thumbnail_path, format='JPEG')
|
694 |
+
|
695 |
+
images[gen_id] = {
|
696 |
+
'image': save_path,
|
697 |
+
'mask': None,
|
698 |
+
'mask_type': None,
|
699 |
+
'thumbnail': thumbnail_path
|
700 |
+
}
|
701 |
+
|
702 |
+
buffered = io.BytesIO()
|
703 |
+
img.convert('RGB').save(buffered, format='PNG')
|
704 |
+
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
705 |
+
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
706 |
+
|
707 |
+
history.append(
|
708 |
+
(message,
|
709 |
+
f'{pre_info} The generated image @{gen_id} is:\n {img_str}'))
|
710 |
+
|
711 |
+
if video_auto:
|
712 |
+
if video_seed is None or video_seed == -1:
|
713 |
+
video_seed = random.randint(0, 10000000)
|
714 |
+
|
715 |
+
lock.acquire()
|
716 |
+
generator = torch.Generator(
|
717 |
+
device='cuda').manual_seed(video_seed)
|
718 |
+
pixel_values = load_image(img.convert('RGB'),
|
719 |
+
max_num=self.llm_max_num).to(
|
720 |
+
torch.bfloat16).cuda()
|
721 |
+
prompt = self.captioner.chat(self.llm_tokenizer, pixel_values,
|
722 |
+
self.llm_prompt,
|
723 |
+
self.llm_generation_config)
|
724 |
+
print(prompt)
|
725 |
+
lock.release()
|
726 |
+
|
727 |
+
if extend_prompt:
|
728 |
+
messages = copy.deepcopy(self.enhance_ctx)
|
729 |
+
messages.append({
|
730 |
+
'role':
|
731 |
+
'user',
|
732 |
+
'content':
|
733 |
+
f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"',
|
734 |
+
})
|
735 |
+
lock.acquire()
|
736 |
+
outputs = self.enhancer(
|
737 |
+
messages,
|
738 |
+
max_new_tokens=200,
|
739 |
+
)
|
740 |
+
prompt = outputs[0]['generated_text'][-1]['content']
|
741 |
+
print(prompt)
|
742 |
+
lock.release()
|
743 |
+
|
744 |
+
lock.acquire()
|
745 |
+
video = self.i2v_pipe(
|
746 |
+
prompt=prompt,
|
747 |
+
image=img,
|
748 |
+
num_videos_per_prompt=1,
|
749 |
+
num_inference_steps=video_steps,
|
750 |
+
num_frames=video_frames,
|
751 |
+
guidance_scale=video_cfg_scale,
|
752 |
+
generator=generator,
|
753 |
+
).frames[0]
|
754 |
+
lock.release()
|
755 |
+
|
756 |
+
out_video_path = export_to_video(video, fps=video_fps)
|
757 |
+
history.append((
|
758 |
+
f"Based on first frame @{gen_id} and description '{prompt}', generate a video",
|
759 |
+
'This is generated video:'))
|
760 |
+
history.append((None, out_video_path))
|
761 |
+
|
762 |
+
while len(history) >= self.max_msgs:
|
763 |
+
history.pop(0)
|
764 |
+
|
765 |
+
return history, images, history_result, self.get_history(
|
766 |
+
history), gr.update(value=''), gr.update(visible=False)
|
767 |
+
|
768 |
+
chat_inputs = [
|
769 |
+
self.extend_prompt, self.history, self.images, self.use_history,
|
770 |
+
self.history_result, self.negative_prompt, self.cfg_scale,
|
771 |
+
self.rescale, self.step, self.seed, self.output_height,
|
772 |
+
self.output_width, self.video_auto, self.video_step,
|
773 |
+
self.video_frames, self.video_cfg_scale, self.video_fps,
|
774 |
+
self.video_seed
|
775 |
+
]
|
776 |
+
|
777 |
+
chat_outputs = [
|
778 |
+
self.history, self.images, self.history_result, self.chatbot,
|
779 |
+
self.text, self.gallery
|
780 |
+
]
|
781 |
+
|
782 |
+
self.chat_btn.click(run_chat,
|
783 |
+
inputs=[self.text] + chat_inputs,
|
784 |
+
outputs=chat_outputs)
|
785 |
+
|
786 |
+
self.text.submit(run_chat,
|
787 |
+
inputs=[self.text] + chat_inputs,
|
788 |
+
outputs=chat_outputs)
|
789 |
+
|
790 |
+
########################################
|
791 |
+
def retry_chat(*args):
|
792 |
+
return run_chat(self.retry_msg, *args)
|
793 |
+
|
794 |
+
self.retry_btn.click(retry_chat,
|
795 |
+
inputs=chat_inputs,
|
796 |
+
outputs=chat_outputs)
|
797 |
+
|
798 |
+
########################################
|
799 |
+
def run_example(task, img, img_mask, ref1, prompt, seed):
|
800 |
+
edit_image, edit_image_mask, edit_task = [], [], []
|
801 |
+
if img is not None:
|
802 |
+
w, h = img.size
|
803 |
+
if w > 2048:
|
804 |
+
ratio = w / 2048.
|
805 |
+
w = 2048
|
806 |
+
h = int(h / ratio)
|
807 |
+
if h > 2048:
|
808 |
+
ratio = h / 2048.
|
809 |
+
h = 2048
|
810 |
+
w = int(w / ratio)
|
811 |
+
img = img.resize((w, h))
|
812 |
+
edit_image.append(img)
|
813 |
+
edit_image_mask.append(
|
814 |
+
img_mask if img_mask is not None else None)
|
815 |
+
edit_task.append(task)
|
816 |
+
if ref1 is not None:
|
817 |
+
edit_image.append(ref1)
|
818 |
+
edit_image_mask.append(None)
|
819 |
+
edit_task.append('')
|
820 |
+
|
821 |
+
buffered = io.BytesIO()
|
822 |
+
img.save(buffered, format='PNG')
|
823 |
+
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
824 |
+
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
825 |
+
pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}'
|
826 |
+
else:
|
827 |
+
pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
|
828 |
+
edit_image = None
|
829 |
+
edit_image_mask = None
|
830 |
+
edit_task = ''
|
831 |
+
|
832 |
+
img_num = len(edit_image) if edit_image is not None else 1
|
833 |
+
imgs = self.pipe(
|
834 |
+
input_image=edit_image,
|
835 |
+
input_mask=edit_image_mask,
|
836 |
+
task=edit_task,
|
837 |
+
prompt=[prompt] * img_num,
|
838 |
+
negative_prompt=[''] * img_num,
|
839 |
+
seed=seed,
|
840 |
+
)
|
841 |
+
|
842 |
+
img = imgs[0]
|
843 |
+
buffered = io.BytesIO()
|
844 |
+
img.convert('RGB').save(buffered, format='PNG')
|
845 |
+
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
846 |
+
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
847 |
+
history = [(prompt,
|
848 |
+
f'{pre_info} The generated image is:\n {img_str}')]
|
849 |
+
return self.get_history(history), gr.update(value=''), gr.update(
|
850 |
+
visible=False)
|
851 |
+
|
852 |
+
with self.eg:
|
853 |
+
self.example_task = gr.Text(label='Task Name',
|
854 |
+
value='',
|
855 |
+
visible=False)
|
856 |
+
self.example_image = gr.Image(label='Edit Image',
|
857 |
+
type='pil',
|
858 |
+
image_mode='RGB',
|
859 |
+
visible=False)
|
860 |
+
self.example_mask = gr.Image(label='Edit Image Mask',
|
861 |
+
type='pil',
|
862 |
+
image_mode='L',
|
863 |
+
visible=False)
|
864 |
+
self.example_ref_im1 = gr.Image(label='Ref Image',
|
865 |
+
type='pil',
|
866 |
+
image_mode='RGB',
|
867 |
+
visible=False)
|
868 |
+
|
869 |
+
self.examples = gr.Examples(
|
870 |
+
fn=run_example,
|
871 |
+
examples=self.chatbot_examples,
|
872 |
+
inputs=[
|
873 |
+
self.example_task, self.example_image, self.example_mask,
|
874 |
+
self.example_ref_im1, self.text, self.seed
|
875 |
+
],
|
876 |
+
outputs=[self.chatbot, self.text, self.gallery],
|
877 |
+
run_on_click=True)
|
878 |
+
|
879 |
+
########################################
|
880 |
+
def upload_image():
|
881 |
+
return (gr.update(visible=True,
|
882 |
+
scale=1), gr.update(visible=True, scale=1),
|
883 |
+
gr.update(visible=True), gr.update(visible=False),
|
884 |
+
gr.update(visible=False), gr.update(visible=False))
|
885 |
+
|
886 |
+
self.upload_btn.click(upload_image,
|
887 |
+
inputs=[],
|
888 |
+
outputs=[
|
889 |
+
self.chat_page, self.editor_page,
|
890 |
+
self.upload_tab, self.edit_tab,
|
891 |
+
self.image_view_tab, self.video_view_tab
|
892 |
+
])
|
893 |
+
|
894 |
+
########################################
|
895 |
+
def edit_image(evt: gr.SelectData):
|
896 |
+
if isinstance(evt.value, str):
|
897 |
+
img_b64s = re.findall(
|
898 |
+
'<img src="data:image/png;base64,(.*?)" style="pointer-events: none;">',
|
899 |
+
evt.value)
|
900 |
+
imgs = [
|
901 |
+
Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))
|
902 |
+
for i in img_b64s
|
903 |
+
]
|
904 |
+
if len(imgs) > 0:
|
905 |
+
if len(imgs) == 2:
|
906 |
+
view_img = copy.deepcopy(imgs)
|
907 |
+
edit_img = copy.deepcopy(imgs[-1])
|
908 |
+
else:
|
909 |
+
view_img = [
|
910 |
+
copy.deepcopy(imgs[-1]),
|
911 |
+
copy.deepcopy(imgs[-1])
|
912 |
+
]
|
913 |
+
edit_img = copy.deepcopy(imgs[-1])
|
914 |
+
|
915 |
+
return (gr.update(visible=True,
|
916 |
+
scale=1), gr.update(visible=True,
|
917 |
+
scale=1),
|
918 |
+
gr.update(visible=False), gr.update(visible=True),
|
919 |
+
gr.update(visible=True), gr.update(visible=False),
|
920 |
+
gr.update(value=edit_img),
|
921 |
+
gr.update(value=view_img), gr.update(value=None))
|
922 |
+
else:
|
923 |
+
return (gr.update(), gr.update(), gr.update(), gr.update(),
|
924 |
+
gr.update(), gr.update(), gr.update(), gr.update(),
|
925 |
+
gr.update())
|
926 |
+
elif isinstance(evt.value, dict) and evt.value.get(
|
927 |
+
'component', '') == 'video':
|
928 |
+
value = evt.value['value']['video']['path']
|
929 |
+
return (gr.update(visible=True,
|
930 |
+
scale=1), gr.update(visible=True, scale=1),
|
931 |
+
gr.update(visible=False), gr.update(visible=False),
|
932 |
+
gr.update(visible=False), gr.update(visible=True),
|
933 |
+
gr.update(), gr.update(), gr.update(value=value))
|
934 |
+
else:
|
935 |
+
return (gr.update(), gr.update(), gr.update(), gr.update(),
|
936 |
+
gr.update(), gr.update(), gr.update(), gr.update(),
|
937 |
+
gr.update())
|
938 |
+
|
939 |
+
self.chatbot.select(edit_image,
|
940 |
+
outputs=[
|
941 |
+
self.chat_page, self.editor_page,
|
942 |
+
self.upload_tab, self.edit_tab,
|
943 |
+
self.image_view_tab, self.video_view_tab,
|
944 |
+
self.image_editor, self.image_viewer,
|
945 |
+
self.video_viewer
|
946 |
+
])
|
947 |
+
|
948 |
+
self.image_viewer.change(lambda x: x,
|
949 |
+
inputs=self.image_viewer,
|
950 |
+
outputs=self.image_viewer)
|
951 |
+
|
952 |
+
########################################
|
953 |
+
def submit_upload_image(image, history, images):
|
954 |
+
history, images = self.add_uploaded_image_to_history(
|
955 |
+
image, history, images)
|
956 |
+
return gr.update(visible=False), gr.update(
|
957 |
+
visible=True), gr.update(
|
958 |
+
value=self.get_history(history)), history, images
|
959 |
+
|
960 |
+
self.sub_btn_1.click(
|
961 |
+
submit_upload_image,
|
962 |
+
inputs=[self.image_uploader, self.history, self.images],
|
963 |
+
outputs=[
|
964 |
+
self.editor_page, self.chat_page, self.chatbot, self.history,
|
965 |
+
self.images
|
966 |
+
])
|
967 |
+
|
968 |
+
########################################
|
969 |
+
def submit_edit_image(imagemask, mask_type, history, images):
|
970 |
+
history, images = self.add_edited_image_to_history(
|
971 |
+
imagemask, mask_type, history, images)
|
972 |
+
return gr.update(visible=False), gr.update(
|
973 |
+
visible=True), gr.update(
|
974 |
+
value=self.get_history(history)), history, images
|
975 |
+
|
976 |
+
self.sub_btn_2.click(submit_edit_image,
|
977 |
+
inputs=[
|
978 |
+
self.image_editor, self.mask_type,
|
979 |
+
self.history, self.images
|
980 |
+
],
|
981 |
+
outputs=[
|
982 |
+
self.editor_page, self.chat_page,
|
983 |
+
self.chatbot, self.history, self.images
|
984 |
+
])
|
985 |
+
|
986 |
+
########################################
|
987 |
+
def exit_edit():
|
988 |
+
return gr.update(visible=False), gr.update(visible=True, scale=3)
|
989 |
+
|
990 |
+
self.ext_btn_1.click(exit_edit,
|
991 |
+
outputs=[self.editor_page, self.chat_page])
|
992 |
+
self.ext_btn_2.click(exit_edit,
|
993 |
+
outputs=[self.editor_page, self.chat_page])
|
994 |
+
self.ext_btn_3.click(exit_edit,
|
995 |
+
outputs=[self.editor_page, self.chat_page])
|
996 |
+
self.ext_btn_4.click(exit_edit,
|
997 |
+
outputs=[self.editor_page, self.chat_page])
|
998 |
+
|
999 |
+
########################################
|
1000 |
+
def update_mask_type_info(mask_type):
|
1001 |
+
if mask_type == 'Background':
|
1002 |
+
info = 'Background mode will not erase the visual content in the mask area'
|
1003 |
+
visible = False
|
1004 |
+
elif mask_type == 'Composite':
|
1005 |
+
info = 'Composite mode will erase the visual content in the mask area'
|
1006 |
+
visible = False
|
1007 |
+
elif mask_type == 'Outpainting':
|
1008 |
+
info = 'Outpaint mode is used for preparing input image for outpainting task'
|
1009 |
+
visible = True
|
1010 |
+
return (gr.update(
|
1011 |
+
visible=True,
|
1012 |
+
value=
|
1013 |
+
f"<div style='background-color: white; padding-left: 15px; color: grey;'>{info}</div>"
|
1014 |
+
), gr.update(visible=visible))
|
1015 |
+
|
1016 |
+
self.mask_type.change(update_mask_type_info,
|
1017 |
+
inputs=self.mask_type,
|
1018 |
+
outputs=[self.mask_type_info, self.outpaint_tab])
|
1019 |
+
|
1020 |
+
########################################
|
1021 |
+
def extend_image(top_ratio, bottom_ratio, left_ratio, right_ratio,
|
1022 |
+
image):
|
1023 |
+
img = cv2.cvtColor(image['background'], cv2.COLOR_RGBA2RGB)
|
1024 |
+
h, w = img.shape[:2]
|
1025 |
+
new_h = int(h * (top_ratio + bottom_ratio + 1))
|
1026 |
+
new_w = int(w * (left_ratio + right_ratio + 1))
|
1027 |
+
start_h = int(h * top_ratio)
|
1028 |
+
start_w = int(w * left_ratio)
|
1029 |
+
new_img = np.zeros((new_h, new_w, 3), dtype=np.uint8)
|
1030 |
+
new_mask = np.ones((new_h, new_w, 1), dtype=np.uint8) * 255
|
1031 |
+
new_img[start_h:start_h + h, start_w:start_w + w, :] = img
|
1032 |
+
new_mask[start_h:start_h + h, start_w:start_w + w] = 0
|
1033 |
+
layer = np.concatenate([new_img, new_mask], axis=2)
|
1034 |
+
value = {
|
1035 |
+
'background': new_img,
|
1036 |
+
'composite': new_img,
|
1037 |
+
'layers': [layer]
|
1038 |
+
}
|
1039 |
+
return gr.update(value=value)
|
1040 |
+
|
1041 |
+
self.img_pad_btn.click(extend_image,
|
1042 |
+
inputs=[
|
1043 |
+
self.top_ext, self.bottom_ext,
|
1044 |
+
self.left_ext, self.right_ext,
|
1045 |
+
self.image_editor
|
1046 |
+
],
|
1047 |
+
outputs=self.image_editor)
|
1048 |
+
|
1049 |
+
########################################
|
1050 |
+
def clear_chat(history, images, history_result):
|
1051 |
+
history.clear()
|
1052 |
+
images.clear()
|
1053 |
+
history_result.clear()
|
1054 |
+
return history, images, history_result, self.get_history(history)
|
1055 |
+
|
1056 |
+
self.clear_btn.click(
|
1057 |
+
clear_chat,
|
1058 |
+
inputs=[self.history, self.images, self.history_result],
|
1059 |
+
outputs=[
|
1060 |
+
self.history, self.images, self.history_result, self.chatbot
|
1061 |
+
])
|
1062 |
+
|
1063 |
+
def get_history(self, history):
|
1064 |
+
info = []
|
1065 |
+
for item in history:
|
1066 |
+
new_item = [None, None]
|
1067 |
+
if isinstance(item[0], str) and item[0].endswith('.mp4'):
|
1068 |
+
new_item[0] = gr.Video(item[0], format='mp4')
|
1069 |
+
else:
|
1070 |
+
new_item[0] = item[0]
|
1071 |
+
if isinstance(item[1], str) and item[1].endswith('.mp4'):
|
1072 |
+
new_item[1] = gr.Video(item[1], format='mp4')
|
1073 |
+
else:
|
1074 |
+
new_item[1] = item[1]
|
1075 |
+
info.append(new_item)
|
1076 |
+
return info
|
1077 |
+
|
1078 |
+
def generate_random_string(self, length=20):
|
1079 |
+
letters_and_digits = string.ascii_letters + string.digits
|
1080 |
+
random_string = ''.join(
|
1081 |
+
random.choice(letters_and_digits) for i in range(length))
|
1082 |
+
return random_string
|
1083 |
+
|
1084 |
+
def add_edited_image_to_history(self, image, mask_type, history, images):
|
1085 |
+
if mask_type == 'Composite':
|
1086 |
+
img = Image.fromarray(image['composite'])
|
1087 |
+
else:
|
1088 |
+
img = Image.fromarray(image['background'])
|
1089 |
+
|
1090 |
+
img_id = get_md5(self.generate_random_string())[:12]
|
1091 |
+
save_path = os.path.join(self.cache_dir, f'{img_id}.png')
|
1092 |
+
img.convert('RGB').save(save_path)
|
1093 |
+
|
1094 |
+
mask = image['layers'][0][:, :, 3]
|
1095 |
+
mask = Image.fromarray(mask).convert('RGB')
|
1096 |
+
mask_path = os.path.join(self.cache_dir, f'{img_id}_mask.png')
|
1097 |
+
mask.save(mask_path)
|
1098 |
+
|
1099 |
+
w, h = img.size
|
1100 |
+
if w > h:
|
1101 |
+
tb_w = 128
|
1102 |
+
tb_h = int(h * tb_w / w)
|
1103 |
+
else:
|
1104 |
+
tb_h = 128
|
1105 |
+
tb_w = int(w * tb_h / h)
|
1106 |
+
|
1107 |
+
if mask_type == 'Background':
|
1108 |
+
comp_mask = np.array(mask, dtype=np.uint8)
|
1109 |
+
mask_alpha = (comp_mask[:, :, 0:1].astype(np.float32) *
|
1110 |
+
0.6).astype(np.uint8)
|
1111 |
+
comp_mask = np.concatenate([comp_mask, mask_alpha], axis=2)
|
1112 |
+
thumbnail = Image.alpha_composite(
|
1113 |
+
img.convert('RGBA'),
|
1114 |
+
Image.fromarray(comp_mask).convert('RGBA')).convert('RGB')
|
1115 |
+
else:
|
1116 |
+
thumbnail = img.convert('RGB')
|
1117 |
+
|
1118 |
+
thumbnail_path = os.path.join(self.cache_dir,
|
1119 |
+
f'{img_id}_thumbnail.jpg')
|
1120 |
+
thumbnail = thumbnail.resize((tb_w, tb_h))
|
1121 |
+
thumbnail.save(thumbnail_path, format='JPEG')
|
1122 |
+
|
1123 |
+
buffered = io.BytesIO()
|
1124 |
+
img.convert('RGB').save(buffered, format='PNG')
|
1125 |
+
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1126 |
+
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
1127 |
+
|
1128 |
+
buffered = io.BytesIO()
|
1129 |
+
mask.convert('RGB').save(buffered, format='PNG')
|
1130 |
+
mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1131 |
+
mask_str = f'<img src="data:image/png;base64,{mask_b64}" style="pointer-events: none;">'
|
1132 |
+
|
1133 |
+
images[img_id] = {
|
1134 |
+
'image': save_path,
|
1135 |
+
'mask': mask_path,
|
1136 |
+
'mask_type': mask_type,
|
1137 |
+
'thumbnail': thumbnail_path
|
1138 |
+
}
|
1139 |
+
history.append((
|
1140 |
+
None,
|
1141 |
+
f'This is edited image and mask:\n {img_str} {mask_str} image ID is: {img_id}'
|
1142 |
+
))
|
1143 |
+
return history, images
|
1144 |
+
|
1145 |
+
def add_uploaded_image_to_history(self, img, history, images):
|
1146 |
+
img_id = get_md5(self.generate_random_string())[:12]
|
1147 |
+
save_path = os.path.join(self.cache_dir, f'{img_id}.png')
|
1148 |
+
w, h = img.size
|
1149 |
+
if w > 2048:
|
1150 |
+
ratio = w / 2048.
|
1151 |
+
w = 2048
|
1152 |
+
h = int(h / ratio)
|
1153 |
+
if h > 2048:
|
1154 |
+
ratio = h / 2048.
|
1155 |
+
h = 2048
|
1156 |
+
w = int(w / ratio)
|
1157 |
+
img = img.resize((w, h))
|
1158 |
+
img.save(save_path)
|
1159 |
+
|
1160 |
+
w, h = img.size
|
1161 |
+
if w > h:
|
1162 |
+
tb_w = 128
|
1163 |
+
tb_h = int(h * tb_w / w)
|
1164 |
+
else:
|
1165 |
+
tb_h = 128
|
1166 |
+
tb_w = int(w * tb_h / h)
|
1167 |
+
thumbnail_path = os.path.join(self.cache_dir,
|
1168 |
+
f'{img_id}_thumbnail.jpg')
|
1169 |
+
thumbnail = img.resize((tb_w, tb_h))
|
1170 |
+
thumbnail.save(thumbnail_path, format='JPEG')
|
1171 |
+
|
1172 |
+
images[img_id] = {
|
1173 |
+
'image': save_path,
|
1174 |
+
'mask': None,
|
1175 |
+
'mask_type': None,
|
1176 |
+
'thumbnail': thumbnail_path
|
1177 |
+
}
|
1178 |
+
|
1179 |
+
buffered = io.BytesIO()
|
1180 |
+
img.convert('RGB').save(buffered, format='PNG')
|
1181 |
+
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1182 |
+
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
1183 |
+
|
1184 |
+
history.append(
|
1185 |
+
(None,
|
1186 |
+
f'This is uploaded image:\n {img_str} image ID is: {img_id}'))
|
1187 |
+
return history, images
|
1188 |
+
|
1189 |
+
|
1190 |
+
|
1191 |
+
if __name__ == '__main__':
|
1192 |
+
cfg = Config(cfg_file="config/chatbot_ui.yaml")
|
1193 |
+
|
1194 |
+
with gr.Blocks() as demo:
|
1195 |
+
chatbot = ChatBotUI(cfg)
|
1196 |
+
chatbot.create_bot_ui()
|
1197 |
+
chatbot.set_callbacks()
|
1198 |
+
|
1199 |
demo.launch()
|
config/chatbot_ui.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
WORK_DIR: ./cache/chatbot
|
2 |
+
FILE_SYSTEM:
|
3 |
+
- NAME: LocalFs
|
4 |
+
TEMP_DIR: ./cache
|
5 |
+
- NAME: ModelscopeFs
|
6 |
+
TEMP_DIR: ./cache
|
7 |
+
- NAME: HuggingfaceFs
|
8 |
+
TEMP_DIR: ./cache
|
9 |
+
#
|
10 |
+
ENABLE_I2V: False
|
11 |
+
#
|
12 |
+
MODEL:
|
13 |
+
EDIT_MODEL:
|
14 |
+
MODEL_CFG_DIR: config/models/
|
15 |
+
DEFAULT: ace_0.6b_512
|
16 |
+
I2V:
|
17 |
+
MODEL_NAME: CogVideoX-5b-I2V
|
18 |
+
MODEL_DIR: ms://ZhipuAI/CogVideoX-5b-I2V/
|
19 |
+
CAPTIONER:
|
20 |
+
MODEL_NAME: InternVL2-2B
|
21 |
+
MODEL_DIR: ms://OpenGVLab/InternVL2-2B/
|
22 |
+
PROMPT: '<image>\nThis image is the first frame of a video. Based on this image, please imagine what changes may occur in the next few seconds of the video. Please output brief description, such as "a dog running" or "a person turns to left". No more than 30 words.'
|
23 |
+
ENHANCER:
|
24 |
+
MODEL_NAME: Meta-Llama-3.1-8B-Instruct
|
25 |
+
MODEL_DIR: ms://LLM-Research/Meta-Llama-3.1-8B-Instruct/
|
config/models/ace_0.6b_512.yaml
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
NAME: ACE_0.6B_512
|
2 |
+
IS_DEFAULT: False
|
3 |
+
DEFAULT_PARAS:
|
4 |
+
PARAS:
|
5 |
+
#
|
6 |
+
INPUT:
|
7 |
+
INPUT_IMAGE:
|
8 |
+
INPUT_MASK:
|
9 |
+
TASK:
|
10 |
+
PROMPT: ""
|
11 |
+
NEGATIVE_PROMPT: ""
|
12 |
+
OUTPUT_HEIGHT: 512
|
13 |
+
OUTPUT_WIDTH: 512
|
14 |
+
SAMPLER: ddim
|
15 |
+
SAMPLE_STEPS: 20
|
16 |
+
GUIDE_SCALE: 4.5
|
17 |
+
GUIDE_RESCALE: 0.5
|
18 |
+
SEED: -1
|
19 |
+
TAR_INDEX: 0
|
20 |
+
OUTPUT:
|
21 |
+
LATENT:
|
22 |
+
IMAGES:
|
23 |
+
SEED:
|
24 |
+
MODULES_PARAS:
|
25 |
+
FIRST_STAGE_MODEL:
|
26 |
+
FUNCTION:
|
27 |
+
- NAME: encode
|
28 |
+
DTYPE: float16
|
29 |
+
INPUT: ["IMAGE"]
|
30 |
+
- NAME: decode
|
31 |
+
DTYPE: float16
|
32 |
+
INPUT: ["LATENT"]
|
33 |
+
#
|
34 |
+
DIFFUSION_MODEL:
|
35 |
+
FUNCTION:
|
36 |
+
- NAME: forward
|
37 |
+
DTYPE: float16
|
38 |
+
INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"]
|
39 |
+
#
|
40 |
+
COND_STAGE_MODEL:
|
41 |
+
FUNCTION:
|
42 |
+
- NAME: encode_list
|
43 |
+
DTYPE: bfloat16
|
44 |
+
INPUT: ["PROMPT"]
|
45 |
+
#
|
46 |
+
MODEL:
|
47 |
+
NAME: LdmACE
|
48 |
+
PRETRAINED_MODEL:
|
49 |
+
IGNORE_KEYS: [ ]
|
50 |
+
SCALE_FACTOR: 0.18215
|
51 |
+
SIZE_FACTOR: 8
|
52 |
+
DECODER_BIAS: 0.5
|
53 |
+
DEFAULT_N_PROMPT: ""
|
54 |
+
TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
|
55 |
+
USE_TEXT_POS_EMBEDDINGS: True
|
56 |
+
#
|
57 |
+
DIFFUSION:
|
58 |
+
NAME: ACEDiffusion
|
59 |
+
PREDICTION_TYPE: eps
|
60 |
+
MIN_SNR_GAMMA:
|
61 |
+
NOISE_SCHEDULER:
|
62 |
+
NAME: LinearScheduler
|
63 |
+
NUM_TIMESTEPS: 1000
|
64 |
+
BETA_MIN: 0.0001
|
65 |
+
BETA_MAX: 0.02
|
66 |
+
#
|
67 |
+
DIFFUSION_MODEL:
|
68 |
+
NAME: DiTACE
|
69 |
+
PRETRAINED_MODEL: hf://scepter-studio/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth
|
70 |
+
IGNORE_KEYS: [ ]
|
71 |
+
PATCH_SIZE: 2
|
72 |
+
IN_CHANNELS: 4
|
73 |
+
HIDDEN_SIZE: 1152
|
74 |
+
DEPTH: 28
|
75 |
+
NUM_HEADS: 16
|
76 |
+
MLP_RATIO: 4.0
|
77 |
+
PRED_SIGMA: True
|
78 |
+
DROP_PATH: 0.0
|
79 |
+
WINDOW_DIZE: 0
|
80 |
+
Y_CHANNELS: 4096
|
81 |
+
MAX_SEQ_LEN: 1024
|
82 |
+
QK_NORM: True
|
83 |
+
USE_GRAD_CHECKPOINT: True
|
84 |
+
ATTENTION_BACKEND: flash_attn
|
85 |
+
#
|
86 |
+
FIRST_STAGE_MODEL:
|
87 |
+
NAME: AutoencoderKL
|
88 |
+
EMBED_DIM: 4
|
89 |
+
PRETRAINED_MODEL: hf://scepter-studio/ACE-0.6B-512px@models/vae/vae.bin
|
90 |
+
IGNORE_KEYS: []
|
91 |
+
#
|
92 |
+
ENCODER:
|
93 |
+
NAME: Encoder
|
94 |
+
CH: 128
|
95 |
+
OUT_CH: 3
|
96 |
+
NUM_RES_BLOCKS: 2
|
97 |
+
IN_CHANNELS: 3
|
98 |
+
ATTN_RESOLUTIONS: [ ]
|
99 |
+
CH_MULT: [ 1, 2, 4, 4 ]
|
100 |
+
Z_CHANNELS: 4
|
101 |
+
DOUBLE_Z: True
|
102 |
+
DROPOUT: 0.0
|
103 |
+
RESAMP_WITH_CONV: True
|
104 |
+
#
|
105 |
+
DECODER:
|
106 |
+
NAME: Decoder
|
107 |
+
CH: 128
|
108 |
+
OUT_CH: 3
|
109 |
+
NUM_RES_BLOCKS: 2
|
110 |
+
IN_CHANNELS: 3
|
111 |
+
ATTN_RESOLUTIONS: [ ]
|
112 |
+
CH_MULT: [ 1, 2, 4, 4 ]
|
113 |
+
Z_CHANNELS: 4
|
114 |
+
DROPOUT: 0.0
|
115 |
+
RESAMP_WITH_CONV: True
|
116 |
+
GIVE_PRE_END: False
|
117 |
+
TANH_OUT: False
|
118 |
+
#
|
119 |
+
COND_STAGE_MODEL:
|
120 |
+
NAME: ACETextEmbedder
|
121 |
+
PRETRAINED_MODEL: hf://scepter-studio/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/
|
122 |
+
TOKENIZER_PATH: hf://scepter-studio/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl
|
123 |
+
LENGTH: 120
|
124 |
+
T5_DTYPE: bfloat16
|
125 |
+
ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
|
126 |
+
CLEAN: whitespace
|
127 |
+
USE_GRAD: False
|
example.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import os
|
4 |
+
|
5 |
+
from scepter.modules.utils.file_system import FS
|
6 |
+
|
7 |
+
|
8 |
+
def download_image(image, local_path=None):
|
9 |
+
if not FS.exists(local_path):
|
10 |
+
local_path = FS.get_from(image, local_path=local_path)
|
11 |
+
return local_path
|
12 |
+
|
13 |
+
|
14 |
+
def get_examples(cache_dir):
|
15 |
+
print('Downloading Examples ...')
|
16 |
+
examples = [
|
17 |
+
[
|
18 |
+
'Image Segmentation',
|
19 |
+
download_image(
|
20 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/db3ebaa81899.png?raw=true',
|
21 |
+
os.path.join(cache_dir, 'examples/db3ebaa81899.png')), None,
|
22 |
+
None, '{image} Segmentation', 6666
|
23 |
+
],
|
24 |
+
[
|
25 |
+
'Depth Estimation',
|
26 |
+
download_image(
|
27 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f1927c4692ba.png?raw=true',
|
28 |
+
os.path.join(cache_dir, 'examples/f1927c4692ba.png')), None,
|
29 |
+
None, '{image} Depth Estimation', 6666
|
30 |
+
],
|
31 |
+
[
|
32 |
+
'Pose Estimation',
|
33 |
+
download_image(
|
34 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/014e5bf3b4d1.png?raw=true',
|
35 |
+
os.path.join(cache_dir, 'examples/014e5bf3b4d1.png')), None,
|
36 |
+
None, '{image} distinguish the poses of the figures', 999999
|
37 |
+
],
|
38 |
+
[
|
39 |
+
'Scribble Extraction',
|
40 |
+
download_image(
|
41 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5f59a202f8ac.png?raw=true',
|
42 |
+
os.path.join(cache_dir, 'examples/5f59a202f8ac.png')), None,
|
43 |
+
None, 'Generate a scribble of {image}, please.', 6666
|
44 |
+
],
|
45 |
+
[
|
46 |
+
'Mosaic',
|
47 |
+
download_image(
|
48 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a2f52361eea.png?raw=true',
|
49 |
+
os.path.join(cache_dir, 'examples/3a2f52361eea.png')), None,
|
50 |
+
None, 'Adapt {image} into a mosaic representation.', 6666
|
51 |
+
],
|
52 |
+
[
|
53 |
+
'Edge map Extraction',
|
54 |
+
download_image(
|
55 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/b9d1e519d6e5.png?raw=true',
|
56 |
+
os.path.join(cache_dir, 'examples/b9d1e519d6e5.png')), None,
|
57 |
+
None, 'Get the edge-enhanced result for {image}.', 6666
|
58 |
+
],
|
59 |
+
[
|
60 |
+
'Grayscale',
|
61 |
+
download_image(
|
62 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4ebbe2ba29b.png?raw=true',
|
63 |
+
os.path.join(cache_dir, 'examples/c4ebbe2ba29b.png')), None,
|
64 |
+
None, 'transform {image} into a black and white one', 6666
|
65 |
+
],
|
66 |
+
[
|
67 |
+
'Contour Extraction',
|
68 |
+
download_image(
|
69 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/19652d0f6c4b.png?raw=true',
|
70 |
+
os.path.join(cache_dir,
|
71 |
+
'examples/19652d0f6c4b.png')), None, None,
|
72 |
+
'Would you be able to make a contour picture from {image} for me?',
|
73 |
+
6666
|
74 |
+
],
|
75 |
+
[
|
76 |
+
'Controllable Generation',
|
77 |
+
download_image(
|
78 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/249cda2844b7.png?raw=true',
|
79 |
+
os.path.join(cache_dir,
|
80 |
+
'examples/249cda2844b7.png')), None, None,
|
81 |
+
'Following the segmentation outcome in mask of {image}, develop a real-life image using the explanatory note in "a mighty cat lying on the bed”.',
|
82 |
+
6666
|
83 |
+
],
|
84 |
+
[
|
85 |
+
'Controllable Generation',
|
86 |
+
download_image(
|
87 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/411f6c4b8e6c.png?raw=true',
|
88 |
+
os.path.join(cache_dir,
|
89 |
+
'examples/411f6c4b8e6c.png')), None, None,
|
90 |
+
'use the depth map {image} and the text caption "a cut white cat" to create a corresponding graphic image',
|
91 |
+
999999
|
92 |
+
],
|
93 |
+
[
|
94 |
+
'Controllable Generation',
|
95 |
+
download_image(
|
96 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a35c96ed137a.png?raw=true',
|
97 |
+
os.path.join(cache_dir,
|
98 |
+
'examples/a35c96ed137a.png')), None, None,
|
99 |
+
'help translate this posture schema {image} into a colored image based on the context I provided "A beautiful woman Climbing the climbing wall, wearing a harness and climbing gear, skillfully maneuvering up the wall with her back to the camera, with a safety rope."',
|
100 |
+
3599999
|
101 |
+
],
|
102 |
+
[
|
103 |
+
'Controllable Generation',
|
104 |
+
download_image(
|
105 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/dcb2fc86f1ce.png?raw=true',
|
106 |
+
os.path.join(cache_dir,
|
107 |
+
'examples/dcb2fc86f1ce.png')), None, None,
|
108 |
+
'Transform and generate an image using mosaic {image} and "Monarch butterflies gracefully perch on vibrant purple flowers, showcasing their striking orange and black wings in a lush garden setting." description',
|
109 |
+
6666
|
110 |
+
],
|
111 |
+
[
|
112 |
+
'Controllable Generation',
|
113 |
+
download_image(
|
114 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/4cd4ee494962.png?raw=true',
|
115 |
+
os.path.join(cache_dir,
|
116 |
+
'examples/4cd4ee494962.png')), None, None,
|
117 |
+
'make this {image} colorful as per the "beautiful sunflowers"',
|
118 |
+
6666
|
119 |
+
],
|
120 |
+
[
|
121 |
+
'Controllable Generation',
|
122 |
+
download_image(
|
123 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a47e3a9cd166.png?raw=true',
|
124 |
+
os.path.join(cache_dir,
|
125 |
+
'examples/a47e3a9cd166.png')), None, None,
|
126 |
+
'Take the edge conscious {image} and the written guideline "A whimsical animated character is depicted holding a delectable cake adorned with blue and white frosting and a drizzle of chocolate. The character wears a yellow headband with a bow, matching a cozy yellow sweater. Her dark hair is styled in a braid, tied with a yellow ribbon. With a golden fork in hand, she stands ready to enjoy a slice, exuding an air of joyful anticipation. The scene is creatively rendered with a charming and playful aesthetic." and produce a realistic image.',
|
127 |
+
613725
|
128 |
+
],
|
129 |
+
[
|
130 |
+
'Controllable Generation',
|
131 |
+
download_image(
|
132 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d890ed8a3ac2.png?raw=true',
|
133 |
+
os.path.join(cache_dir,
|
134 |
+
'examples/d890ed8a3ac2.png')), None, None,
|
135 |
+
'creating a vivid image based on {image} and description "This image features a delicious rectangular tart with a flaky, golden-brown crust. The tart is topped with evenly sliced tomatoes, layered over a creamy cheese filling. Aromatic herbs are sprinkled on top, adding a touch of green and enhancing the visual appeal. The background includes a soft, textured fabric and scattered white flowers, creating an elegant and inviting presentation. Bright red tomatoes in the upper right corner hint at the fresh ingredients used in the dish."',
|
136 |
+
6666
|
137 |
+
],
|
138 |
+
[
|
139 |
+
'Controllable Generation',
|
140 |
+
download_image(
|
141 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/131ca90fd2a9.png?raw=true',
|
142 |
+
os.path.join(cache_dir,
|
143 |
+
'examples/131ca90fd2a9.png')), None, None,
|
144 |
+
'"A person sits contemplatively on the ground, surrounded by falling autumn leaves. Dressed in a green sweater and dark blue pants, they rest their chin on their hand, exuding a relaxed demeanor. Their stylish checkered slip-on shoes add a touch of flair, while a black purse lies in their lap. The backdrop of muted brown enhances the warm, cozy atmosphere of the scene." , generate the image that corresponds to the given scribble {image}.',
|
145 |
+
613725
|
146 |
+
],
|
147 |
+
[
|
148 |
+
'Image Denoising',
|
149 |
+
download_image(
|
150 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/0844a686a179.png?raw=true',
|
151 |
+
os.path.join(cache_dir,
|
152 |
+
'examples/0844a686a179.png')), None, None,
|
153 |
+
'Eliminate noise interference in {image} and maximize the crispness to obtain superior high-definition quality',
|
154 |
+
6666
|
155 |
+
],
|
156 |
+
[
|
157 |
+
'Inpainting',
|
158 |
+
download_image(
|
159 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b.png?raw=true',
|
160 |
+
os.path.join(cache_dir, 'examples/fa91b6b7e59b.png')),
|
161 |
+
download_image(
|
162 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b_mask.png?raw=true',
|
163 |
+
os.path.join(cache_dir,
|
164 |
+
'examples/fa91b6b7e59b_mask.png')), None,
|
165 |
+
'Ensure to overhaul the parts of the {image} indicated by the mask.',
|
166 |
+
6666
|
167 |
+
],
|
168 |
+
[
|
169 |
+
'Inpainting',
|
170 |
+
download_image(
|
171 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26.png?raw=true',
|
172 |
+
os.path.join(cache_dir, 'examples/632899695b26.png')),
|
173 |
+
download_image(
|
174 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26_mask.png?raw=true',
|
175 |
+
os.path.join(cache_dir,
|
176 |
+
'examples/632899695b26_mask.png')), None,
|
177 |
+
'Refashion the mask portion of {image} in accordance with "A yellow egg with a smiling face painted on it"',
|
178 |
+
6666
|
179 |
+
],
|
180 |
+
[
|
181 |
+
'Outpainting',
|
182 |
+
download_image(
|
183 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f.png?raw=true',
|
184 |
+
os.path.join(cache_dir, 'examples/f2b22c08be3f.png')),
|
185 |
+
download_image(
|
186 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f_mask.png?raw=true',
|
187 |
+
os.path.join(cache_dir,
|
188 |
+
'examples/f2b22c08be3f_mask.png')), None,
|
189 |
+
'Could the {image} be widened within the space designated by mask, while retaining the original?',
|
190 |
+
6666
|
191 |
+
],
|
192 |
+
[
|
193 |
+
'General Editing',
|
194 |
+
download_image(
|
195 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/354d17594afe.png?raw=true',
|
196 |
+
os.path.join(cache_dir,
|
197 |
+
'examples/354d17594afe.png')), None, None,
|
198 |
+
'{image} change the dog\'s posture to walking in the water, and change the background to green plants and a pond.',
|
199 |
+
6666
|
200 |
+
],
|
201 |
+
[
|
202 |
+
'General Editing',
|
203 |
+
download_image(
|
204 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/38946455752b.png?raw=true',
|
205 |
+
os.path.join(cache_dir,
|
206 |
+
'examples/38946455752b.png')), None, None,
|
207 |
+
'{image} change the color of the dress from white to red and the model\'s hair color red brown to blonde.Other parts remain unchanged',
|
208 |
+
6669
|
209 |
+
],
|
210 |
+
[
|
211 |
+
'Facial Editing',
|
212 |
+
download_image(
|
213 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3ba5202f0cd8.png?raw=true',
|
214 |
+
os.path.join(cache_dir,
|
215 |
+
'examples/3ba5202f0cd8.png')), None, None,
|
216 |
+
'Keep the same facial feature in @3ba5202f0cd8, change the woman\'s clothing from a Blue denim jacket to a white turtleneck sweater and adjust her posture so that she is supporting her chin with both hands. Other aspects, such as background, hairstyle, facial expression, etc, remain unchanged.',
|
217 |
+
99999
|
218 |
+
],
|
219 |
+
[
|
220 |
+
'Facial Editing',
|
221 |
+
download_image(
|
222 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/369365b94725.png?raw=true',
|
223 |
+
os.path.join(cache_dir, 'examples/369365b94725.png')), None,
|
224 |
+
None, '{image} Make her looking at the camera', 6666
|
225 |
+
],
|
226 |
+
[
|
227 |
+
'Facial Editing',
|
228 |
+
download_image(
|
229 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/92751f2e4a0e.png?raw=true',
|
230 |
+
os.path.join(cache_dir, 'examples/92751f2e4a0e.png')), None,
|
231 |
+
None, '{image} Remove the smile from his face', 9899999
|
232 |
+
],
|
233 |
+
[
|
234 |
+
'Render Text',
|
235 |
+
download_image(
|
236 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48.png?raw=true',
|
237 |
+
os.path.join(cache_dir, 'examples/33e9f27c2c48.png')),
|
238 |
+
download_image(
|
239 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48_mask.png?raw=true',
|
240 |
+
os.path.join(cache_dir,
|
241 |
+
'examples/33e9f27c2c48_mask.png')), None,
|
242 |
+
'Put the text "C A T" at the position marked by mask in the {image}',
|
243 |
+
6666
|
244 |
+
],
|
245 |
+
[
|
246 |
+
'Remove Text',
|
247 |
+
download_image(
|
248 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/8530a6711b2e.png?raw=true',
|
249 |
+
os.path.join(cache_dir, 'examples/8530a6711b2e.png')), None,
|
250 |
+
None, 'Aim to remove any textual element in {image}', 6666
|
251 |
+
],
|
252 |
+
[
|
253 |
+
'Remove Text',
|
254 |
+
download_image(
|
255 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6.png?raw=true',
|
256 |
+
os.path.join(cache_dir, 'examples/c4d7fb28f8f6.png')),
|
257 |
+
download_image(
|
258 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6_mask.png?raw=true',
|
259 |
+
os.path.join(cache_dir,
|
260 |
+
'examples/c4d7fb28f8f6_mask.png')), None,
|
261 |
+
'Rub out any text found in the mask sector of the {image}.', 6666
|
262 |
+
],
|
263 |
+
[
|
264 |
+
'Remove Object',
|
265 |
+
download_image(
|
266 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e2f318fa5e5b.png?raw=true',
|
267 |
+
os.path.join(cache_dir,
|
268 |
+
'examples/e2f318fa5e5b.png')), None, None,
|
269 |
+
'Remove the unicorn in this {image}, ensuring a smooth edit.',
|
270 |
+
99999
|
271 |
+
],
|
272 |
+
[
|
273 |
+
'Remove Object',
|
274 |
+
download_image(
|
275 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00.png?raw=true',
|
276 |
+
os.path.join(cache_dir, 'examples/1ae96d8aca00.png')),
|
277 |
+
download_image(
|
278 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00_mask.png?raw=true',
|
279 |
+
os.path.join(cache_dir, 'examples/1ae96d8aca00_mask.png')),
|
280 |
+
None, 'Discard the contents of the mask area from {image}.', 99999
|
281 |
+
],
|
282 |
+
[
|
283 |
+
'Add Object',
|
284 |
+
download_image(
|
285 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511.png?raw=true',
|
286 |
+
os.path.join(cache_dir, 'examples/80289f48e511.png')),
|
287 |
+
download_image(
|
288 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511_mask.png?raw=true',
|
289 |
+
os.path.join(cache_dir,
|
290 |
+
'examples/80289f48e511_mask.png')), None,
|
291 |
+
'add a Hot Air Balloon into the {image}, per the mask', 613725
|
292 |
+
],
|
293 |
+
[
|
294 |
+
'Style Transfer',
|
295 |
+
download_image(
|
296 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d725cb2009e8.png?raw=true',
|
297 |
+
os.path.join(cache_dir, 'examples/d725cb2009e8.png')), None,
|
298 |
+
None, 'Change the style of {image} to colored pencil style', 99999
|
299 |
+
],
|
300 |
+
[
|
301 |
+
'Style Transfer',
|
302 |
+
download_image(
|
303 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e0f48b3fd010.png?raw=true',
|
304 |
+
os.path.join(cache_dir, 'examples/e0f48b3fd010.png')), None,
|
305 |
+
None, 'make {image} to Walt Disney Animation style', 99999
|
306 |
+
],
|
307 |
+
[
|
308 |
+
'Style Transfer',
|
309 |
+
download_image(
|
310 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/9e73e7eeef55.png?raw=true',
|
311 |
+
os.path.join(cache_dir, 'examples/9e73e7eeef55.png')), None,
|
312 |
+
download_image(
|
313 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/2e02975293d6.png?raw=true',
|
314 |
+
os.path.join(cache_dir, 'examples/2e02975293d6.png')),
|
315 |
+
'edit {image} based on the style of {image1} ', 99999
|
316 |
+
],
|
317 |
+
[
|
318 |
+
'Try On',
|
319 |
+
download_image(
|
320 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96.png?raw=true',
|
321 |
+
os.path.join(cache_dir, 'examples/ee4ca60b8c96.png')),
|
322 |
+
download_image(
|
323 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96_mask.png?raw=true',
|
324 |
+
os.path.join(cache_dir, 'examples/ee4ca60b8c96_mask.png')),
|
325 |
+
download_image(
|
326 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ebe825bbfe3c.png?raw=true',
|
327 |
+
os.path.join(cache_dir, 'examples/ebe825bbfe3c.png')),
|
328 |
+
'Change the cloth in {image} to the one in {image1}', 99999
|
329 |
+
],
|
330 |
+
[
|
331 |
+
'Workflow',
|
332 |
+
download_image(
|
333 |
+
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/cb85353c004b.png?raw=true',
|
334 |
+
os.path.join(cache_dir, 'examples/cb85353c004b.png')), None,
|
335 |
+
None, '<workflow> ice cream {image}', 99999
|
336 |
+
],
|
337 |
+
]
|
338 |
+
print('Finish. Start building UI ...')
|
339 |
+
return examples
|
infer.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import copy
|
4 |
+
import math
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torchvision.transforms.functional as TF
|
13 |
+
|
14 |
+
from scepter.modules.model.registry import DIFFUSIONS
|
15 |
+
from scepter.modules.model.utils.basic_utils import (
|
16 |
+
check_list_of_list,
|
17 |
+
pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
|
18 |
+
to_device,
|
19 |
+
unpack_tensor_into_imagelist
|
20 |
+
)
|
21 |
+
from scepter.modules.utils.distribute import we
|
22 |
+
from scepter.modules.utils.logger import get_logger
|
23 |
+
|
24 |
+
from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
|
25 |
+
|
26 |
+
|
27 |
+
def process_edit_image(images,
|
28 |
+
masks,
|
29 |
+
tasks,
|
30 |
+
max_seq_len=1024,
|
31 |
+
max_aspect_ratio=4,
|
32 |
+
d=16,
|
33 |
+
**kwargs):
|
34 |
+
|
35 |
+
if not isinstance(images, list):
|
36 |
+
images = [images]
|
37 |
+
if not isinstance(masks, list):
|
38 |
+
masks = [masks]
|
39 |
+
if not isinstance(tasks, list):
|
40 |
+
tasks = [tasks]
|
41 |
+
|
42 |
+
img_tensors = []
|
43 |
+
mask_tensors = []
|
44 |
+
for img, mask, task in zip(images, masks, tasks):
|
45 |
+
if mask is None or mask == '':
|
46 |
+
mask = Image.new('L', img.size, 0)
|
47 |
+
W, H = img.size
|
48 |
+
if H / W > max_aspect_ratio:
|
49 |
+
img = TF.center_crop(img, [int(max_aspect_ratio * W), W])
|
50 |
+
mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W])
|
51 |
+
elif W / H > max_aspect_ratio:
|
52 |
+
img = TF.center_crop(img, [H, int(max_aspect_ratio * H)])
|
53 |
+
mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)])
|
54 |
+
|
55 |
+
H, W = img.height, img.width
|
56 |
+
scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
|
57 |
+
rH = int(H * scale) // d * d # ensure divisible by self.d
|
58 |
+
rW = int(W * scale) // d * d
|
59 |
+
|
60 |
+
img = TF.resize(img, (rH, rW),
|
61 |
+
interpolation=TF.InterpolationMode.BICUBIC)
|
62 |
+
mask = TF.resize(mask, (rH, rW),
|
63 |
+
interpolation=TF.InterpolationMode.NEAREST_EXACT)
|
64 |
+
|
65 |
+
mask = np.asarray(mask)
|
66 |
+
mask = np.where(mask > 128, 1, 0)
|
67 |
+
mask = mask.astype(
|
68 |
+
np.float32) if np.any(mask) else np.ones_like(mask).astype(
|
69 |
+
np.float32)
|
70 |
+
|
71 |
+
img_tensor = TF.to_tensor(img).to(we.device_id)
|
72 |
+
img_tensor = TF.normalize(img_tensor,
|
73 |
+
mean=[0.5, 0.5, 0.5],
|
74 |
+
std=[0.5, 0.5, 0.5])
|
75 |
+
mask_tensor = TF.to_tensor(mask).to(we.device_id)
|
76 |
+
if task in ['inpainting', 'Try On', 'Inpainting']:
|
77 |
+
mask_indicator = mask_tensor.repeat(3, 1, 1)
|
78 |
+
img_tensor[mask_indicator == 1] = -1.0
|
79 |
+
img_tensors.append(img_tensor)
|
80 |
+
mask_tensors.append(mask_tensor)
|
81 |
+
return img_tensors, mask_tensors
|
82 |
+
|
83 |
+
|
84 |
+
class TextEmbedding(nn.Module):
|
85 |
+
def __init__(self, embedding_shape):
|
86 |
+
super().__init__()
|
87 |
+
self.pos = nn.Parameter(data=torch.zeros(embedding_shape))
|
88 |
+
|
89 |
+
|
90 |
+
class ACEInference(DiffusionInference):
|
91 |
+
def __init__(self, logger=None):
|
92 |
+
if logger is None:
|
93 |
+
logger = get_logger(name='scepter')
|
94 |
+
self.logger = logger
|
95 |
+
self.loaded_model = {}
|
96 |
+
self.loaded_model_name = [
|
97 |
+
'diffusion_model', 'first_stage_model', 'cond_stage_model'
|
98 |
+
]
|
99 |
+
|
100 |
+
def init_from_cfg(self, cfg):
|
101 |
+
self.name = cfg.NAME
|
102 |
+
self.is_default = cfg.get('IS_DEFAULT', False)
|
103 |
+
module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
|
104 |
+
assert cfg.have('MODEL')
|
105 |
+
|
106 |
+
self.diffusion_model = self.infer_model(
|
107 |
+
cfg.MODEL.DIFFUSION_MODEL, module_paras.get(
|
108 |
+
'DIFFUSION_MODEL',
|
109 |
+
None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
|
110 |
+
self.first_stage_model = self.infer_model(
|
111 |
+
cfg.MODEL.FIRST_STAGE_MODEL,
|
112 |
+
module_paras.get(
|
113 |
+
'FIRST_STAGE_MODEL',
|
114 |
+
None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
|
115 |
+
self.cond_stage_model = self.infer_model(
|
116 |
+
cfg.MODEL.COND_STAGE_MODEL,
|
117 |
+
module_paras.get(
|
118 |
+
'COND_STAGE_MODEL',
|
119 |
+
None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None
|
120 |
+
self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
|
121 |
+
logger=self.logger)
|
122 |
+
|
123 |
+
self.interpolate_func = lambda x: (F.interpolate(
|
124 |
+
x.unsqueeze(0),
|
125 |
+
scale_factor=1 / self.size_factor,
|
126 |
+
mode='nearest-exact') if x is not None else None)
|
127 |
+
self.text_indentifers = cfg.MODEL.get('TEXT_IDENTIFIER', [])
|
128 |
+
self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS',
|
129 |
+
False)
|
130 |
+
if self.use_text_pos_embeddings:
|
131 |
+
self.text_position_embeddings = TextEmbedding(
|
132 |
+
(10, 4096)).eval().requires_grad_(False).to(we.device_id)
|
133 |
+
else:
|
134 |
+
self.text_position_embeddings = None
|
135 |
+
|
136 |
+
self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN
|
137 |
+
self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215)
|
138 |
+
self.size_factor = cfg.get('SIZE_FACTOR', 8)
|
139 |
+
self.decoder_bias = cfg.get('DECODER_BIAS', 0)
|
140 |
+
self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
|
141 |
+
|
142 |
+
@torch.no_grad()
|
143 |
+
def encode_first_stage(self, x, **kwargs):
|
144 |
+
_, dtype = self.get_function_info(self.first_stage_model, 'encode')
|
145 |
+
with torch.autocast('cuda',
|
146 |
+
enabled=(dtype != 'float32'),
|
147 |
+
dtype=getattr(torch, dtype)):
|
148 |
+
z = [
|
149 |
+
self.scale_factor * get_model(self.first_stage_model)._encode(
|
150 |
+
i.unsqueeze(0).to(getattr(torch, dtype))) for i in x
|
151 |
+
]
|
152 |
+
return z
|
153 |
+
|
154 |
+
@torch.no_grad()
|
155 |
+
def decode_first_stage(self, z):
|
156 |
+
_, dtype = self.get_function_info(self.first_stage_model, 'decode')
|
157 |
+
with torch.autocast('cuda',
|
158 |
+
enabled=(dtype != 'float32'),
|
159 |
+
dtype=getattr(torch, dtype)):
|
160 |
+
x = [
|
161 |
+
get_model(self.first_stage_model)._decode(
|
162 |
+
1. / self.scale_factor * i.to(getattr(torch, dtype)))
|
163 |
+
for i in z
|
164 |
+
]
|
165 |
+
return x
|
166 |
+
|
167 |
+
@torch.no_grad()
|
168 |
+
def __call__(self,
|
169 |
+
image=None,
|
170 |
+
mask=None,
|
171 |
+
prompt='',
|
172 |
+
task=None,
|
173 |
+
negative_prompt='',
|
174 |
+
output_height=512,
|
175 |
+
output_width=512,
|
176 |
+
sampler='ddim',
|
177 |
+
sample_steps=20,
|
178 |
+
guide_scale=4.5,
|
179 |
+
guide_rescale=0.5,
|
180 |
+
seed=-1,
|
181 |
+
history_io=None,
|
182 |
+
tar_index=0,
|
183 |
+
**kwargs):
|
184 |
+
input_image, input_mask = image, mask
|
185 |
+
g = torch.Generator(device=we.device_id)
|
186 |
+
seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
|
187 |
+
g.manual_seed(int(seed))
|
188 |
+
|
189 |
+
if input_image is not None:
|
190 |
+
assert isinstance(input_image, list) and isinstance(
|
191 |
+
input_mask, list)
|
192 |
+
if task is None:
|
193 |
+
task = [''] * len(input_image)
|
194 |
+
if not isinstance(prompt, list):
|
195 |
+
prompt = [prompt] * len(input_image)
|
196 |
+
if history_io is not None and len(history_io) > 0:
|
197 |
+
his_image, his_maks, his_prompt, his_task = history_io[
|
198 |
+
'image'], history_io['mask'], history_io[
|
199 |
+
'prompt'], history_io['task']
|
200 |
+
assert len(his_image) == len(his_maks) == len(
|
201 |
+
his_prompt) == len(his_task)
|
202 |
+
input_image = his_image + input_image
|
203 |
+
input_mask = his_maks + input_mask
|
204 |
+
task = his_task + task
|
205 |
+
prompt = his_prompt + [prompt[-1]]
|
206 |
+
prompt = [
|
207 |
+
pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
|
208 |
+
for i, pp in enumerate(prompt)
|
209 |
+
]
|
210 |
+
|
211 |
+
edit_image, edit_image_mask = process_edit_image(
|
212 |
+
input_image, input_mask, task, max_seq_len=self.max_seq_len)
|
213 |
+
|
214 |
+
image, image_mask = edit_image[tar_index], edit_image_mask[
|
215 |
+
tar_index]
|
216 |
+
edit_image, edit_image_mask = [edit_image], [edit_image_mask]
|
217 |
+
|
218 |
+
else:
|
219 |
+
edit_image = edit_image_mask = [[]]
|
220 |
+
image = torch.zeros(
|
221 |
+
size=[3, int(output_height),
|
222 |
+
int(output_width)])
|
223 |
+
image_mask = torch.ones(
|
224 |
+
size=[1, int(output_height),
|
225 |
+
int(output_width)])
|
226 |
+
if not isinstance(prompt, list):
|
227 |
+
prompt = [prompt]
|
228 |
+
|
229 |
+
image, image_mask, prompt = [image], [image_mask], [prompt]
|
230 |
+
assert check_list_of_list(prompt) and check_list_of_list(
|
231 |
+
edit_image) and check_list_of_list(edit_image_mask)
|
232 |
+
# Assign Negative Prompt
|
233 |
+
if isinstance(negative_prompt, list):
|
234 |
+
negative_prompt = negative_prompt[0]
|
235 |
+
assert isinstance(negative_prompt, str)
|
236 |
+
|
237 |
+
n_prompt = copy.deepcopy(prompt)
|
238 |
+
for nn_p_id, nn_p in enumerate(n_prompt):
|
239 |
+
assert isinstance(nn_p, list)
|
240 |
+
n_prompt[nn_p_id][-1] = negative_prompt
|
241 |
+
|
242 |
+
ctx, null_ctx = {}, {}
|
243 |
+
|
244 |
+
# Get Noise Shape
|
245 |
+
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
246 |
+
image = to_device(image)
|
247 |
+
x = self.encode_first_stage(image)
|
248 |
+
self.dynamic_unload(self.first_stage_model,
|
249 |
+
'first_stage_model',
|
250 |
+
skip_loaded=True)
|
251 |
+
noise = [
|
252 |
+
torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
|
253 |
+
for i in x
|
254 |
+
]
|
255 |
+
noise, x_shapes = pack_imagelist_into_tensor(noise)
|
256 |
+
ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes
|
257 |
+
|
258 |
+
image_mask = to_device(image_mask, strict=False)
|
259 |
+
cond_mask = [self.interpolate_func(i) for i in image_mask
|
260 |
+
] if image_mask is not None else [None] * len(image)
|
261 |
+
ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
|
262 |
+
|
263 |
+
# Encode Prompt
|
264 |
+
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
265 |
+
function_name, dtype = self.get_function_info(self.cond_stage_model)
|
266 |
+
cont, cont_mask = getattr(get_model(self.cond_stage_model),
|
267 |
+
function_name)(prompt)
|
268 |
+
cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
|
269 |
+
cont_mask)
|
270 |
+
null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model),
|
271 |
+
function_name)(n_prompt)
|
272 |
+
null_cont, null_cont_mask = self.cond_stage_embeddings(
|
273 |
+
prompt, edit_image, null_cont, null_cont_mask)
|
274 |
+
self.dynamic_unload(self.cond_stage_model,
|
275 |
+
'cond_stage_model',
|
276 |
+
skip_loaded=False)
|
277 |
+
ctx['crossattn'] = cont
|
278 |
+
null_ctx['crossattn'] = null_cont
|
279 |
+
|
280 |
+
# Encode Edit Images
|
281 |
+
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
282 |
+
edit_image = [to_device(i, strict=False) for i in edit_image]
|
283 |
+
edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
|
284 |
+
e_img, e_mask = [], []
|
285 |
+
for u, m in zip(edit_image, edit_image_mask):
|
286 |
+
if u is None:
|
287 |
+
continue
|
288 |
+
if m is None:
|
289 |
+
m = [None] * len(u)
|
290 |
+
e_img.append(self.encode_first_stage(u, **kwargs))
|
291 |
+
e_mask.append([self.interpolate_func(i) for i in m])
|
292 |
+
self.dynamic_unload(self.first_stage_model,
|
293 |
+
'first_stage_model',
|
294 |
+
skip_loaded=True)
|
295 |
+
null_ctx['edit'] = ctx['edit'] = e_img
|
296 |
+
null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
|
297 |
+
|
298 |
+
# Diffusion Process
|
299 |
+
self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
300 |
+
function_name, dtype = self.get_function_info(self.diffusion_model)
|
301 |
+
with torch.autocast('cuda',
|
302 |
+
enabled=dtype in ('float16', 'bfloat16'),
|
303 |
+
dtype=getattr(torch, dtype)):
|
304 |
+
latent = self.diffusion.sample(
|
305 |
+
noise=noise,
|
306 |
+
sampler=sampler,
|
307 |
+
model=get_model(self.diffusion_model),
|
308 |
+
model_kwargs=[{
|
309 |
+
'cond':
|
310 |
+
ctx,
|
311 |
+
'mask':
|
312 |
+
cont_mask,
|
313 |
+
'text_position_embeddings':
|
314 |
+
self.text_position_embeddings.pos if hasattr(
|
315 |
+
self.text_position_embeddings, 'pos') else None
|
316 |
+
}, {
|
317 |
+
'cond':
|
318 |
+
null_ctx,
|
319 |
+
'mask':
|
320 |
+
null_cont_mask,
|
321 |
+
'text_position_embeddings':
|
322 |
+
self.text_position_embeddings.pos if hasattr(
|
323 |
+
self.text_position_embeddings, 'pos') else None
|
324 |
+
}] if guide_scale is not None and guide_scale > 1 else {
|
325 |
+
'cond':
|
326 |
+
null_ctx,
|
327 |
+
'mask':
|
328 |
+
cont_mask,
|
329 |
+
'text_position_embeddings':
|
330 |
+
self.text_position_embeddings.pos if hasattr(
|
331 |
+
self.text_position_embeddings, 'pos') else None
|
332 |
+
},
|
333 |
+
steps=sample_steps,
|
334 |
+
show_progress=True,
|
335 |
+
seed=seed,
|
336 |
+
guide_scale=guide_scale,
|
337 |
+
guide_rescale=guide_rescale,
|
338 |
+
return_intermediate=None,
|
339 |
+
**kwargs)
|
340 |
+
self.dynamic_unload(self.diffusion_model,
|
341 |
+
'diffusion_model',
|
342 |
+
skip_loaded=False)
|
343 |
+
|
344 |
+
# Decode to Pixel Space
|
345 |
+
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
346 |
+
samples = unpack_tensor_into_imagelist(latent, x_shapes)
|
347 |
+
x_samples = self.decode_first_stage(samples)
|
348 |
+
self.dynamic_unload(self.first_stage_model,
|
349 |
+
'first_stage_model',
|
350 |
+
skip_loaded=False)
|
351 |
+
|
352 |
+
imgs = [
|
353 |
+
torch.clamp((x_i + 1.0) / 2.0 + self.decoder_bias / 255,
|
354 |
+
min=0.0,
|
355 |
+
max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
|
356 |
+
for x_i in x_samples
|
357 |
+
]
|
358 |
+
imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
|
359 |
+
return imgs
|
360 |
+
|
361 |
+
def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
|
362 |
+
if self.use_text_pos_embeddings and not torch.sum(
|
363 |
+
self.text_position_embeddings.pos) > 0:
|
364 |
+
identifier_cont, _ = getattr(get_model(self.cond_stage_model),
|
365 |
+
'encode')(self.text_indentifers,
|
366 |
+
return_mask=True)
|
367 |
+
self.text_position_embeddings.load_state_dict(
|
368 |
+
{'pos': identifier_cont[:, 0, :]})
|
369 |
+
|
370 |
+
cont_, cont_mask_ = [], []
|
371 |
+
for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
|
372 |
+
if isinstance(pp, list):
|
373 |
+
cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
|
374 |
+
cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
|
375 |
+
else:
|
376 |
+
raise NotImplementedError
|
377 |
+
|
378 |
+
return cont_, cont_mask_
|
modules/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import data, model, solver
|
modules/data/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import dataset
|
modules/data/dataset/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .dataset import ACEDemoDataset
|
modules/data/dataset/dataset.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
|
4 |
+
import io
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
from collections import defaultdict
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torchvision.transforms as T
|
13 |
+
from PIL import Image
|
14 |
+
from torchvision.transforms.functional import InterpolationMode
|
15 |
+
|
16 |
+
from scepter.modules.data.dataset.base_dataset import BaseDataset
|
17 |
+
from scepter.modules.data.dataset.registry import DATASETS
|
18 |
+
from scepter.modules.transform.io import pillow_convert
|
19 |
+
from scepter.modules.utils.config import dict_to_yaml
|
20 |
+
from scepter.modules.utils.file_system import FS
|
21 |
+
|
22 |
+
Image.MAX_IMAGE_PIXELS = None
|
23 |
+
|
24 |
+
@DATASETS.register_class()
|
25 |
+
class ACEDemoDataset(BaseDataset):
|
26 |
+
para_dict = {
|
27 |
+
'MS_DATASET_NAME': {
|
28 |
+
'value': '',
|
29 |
+
'description': 'Modelscope dataset name.'
|
30 |
+
},
|
31 |
+
'MS_DATASET_NAMESPACE': {
|
32 |
+
'value': '',
|
33 |
+
'description': 'Modelscope dataset namespace.'
|
34 |
+
},
|
35 |
+
'MS_DATASET_SUBNAME': {
|
36 |
+
'value': '',
|
37 |
+
'description': 'Modelscope dataset subname.'
|
38 |
+
},
|
39 |
+
'MS_DATASET_SPLIT': {
|
40 |
+
'value': '',
|
41 |
+
'description':
|
42 |
+
'Modelscope dataset split set name, default is train.'
|
43 |
+
},
|
44 |
+
'MS_REMAP_KEYS': {
|
45 |
+
'value':
|
46 |
+
None,
|
47 |
+
'description':
|
48 |
+
'Modelscope dataset header of list file, the default is Target:FILE; '
|
49 |
+
'If your file is not this header, please set this field, which is a map dict.'
|
50 |
+
"For example, { 'Image:FILE': 'Target:FILE' } will replace the filed Image:FILE to Target:FILE"
|
51 |
+
},
|
52 |
+
'MS_REMAP_PATH': {
|
53 |
+
'value':
|
54 |
+
None,
|
55 |
+
'description':
|
56 |
+
'When modelscope dataset name is not None, that means you use the dataset from modelscope,'
|
57 |
+
' default is None. But if you want to use the datalist from modelscope and the file from '
|
58 |
+
'local device, you can use this field to set the root path of your images. '
|
59 |
+
},
|
60 |
+
'TRIGGER_WORDS': {
|
61 |
+
'value':
|
62 |
+
'',
|
63 |
+
'description':
|
64 |
+
'The words used to describe the common features of your data, especially when you customize a '
|
65 |
+
'tuner. Use these words you can get what you want.'
|
66 |
+
},
|
67 |
+
'HIGHLIGHT_KEYWORDS': {
|
68 |
+
'value':
|
69 |
+
'',
|
70 |
+
'description':
|
71 |
+
'The keywords you want to highlight in prompt, which will be replace by <HIGHLIGHT_KEYWORDS>.'
|
72 |
+
},
|
73 |
+
'KEYWORDS_SIGN': {
|
74 |
+
'value':
|
75 |
+
'',
|
76 |
+
'description':
|
77 |
+
'The keywords sign you want to add, which is like <{HIGHLIGHT_KEYWORDS}{KEYWORDS_SIGN}>'
|
78 |
+
},
|
79 |
+
}
|
80 |
+
|
81 |
+
def __init__(self, cfg, logger=None):
|
82 |
+
super().__init__(cfg=cfg, logger=logger)
|
83 |
+
from modelscope import MsDataset
|
84 |
+
from modelscope.utils.constant import DownloadMode
|
85 |
+
ms_dataset_name = cfg.get('MS_DATASET_NAME', None)
|
86 |
+
ms_dataset_namespace = cfg.get('MS_DATASET_NAMESPACE', None)
|
87 |
+
ms_dataset_subname = cfg.get('MS_DATASET_SUBNAME', None)
|
88 |
+
ms_dataset_split = cfg.get('MS_DATASET_SPLIT', 'train')
|
89 |
+
ms_remap_keys = cfg.get('MS_REMAP_KEYS', None)
|
90 |
+
ms_remap_path = cfg.get('MS_REMAP_PATH', None)
|
91 |
+
|
92 |
+
self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
|
93 |
+
self.max_aspect_ratio = cfg.get('MAX_ASPECT_RATIO', 4)
|
94 |
+
self.d = cfg.get('DOWNSAMPLE_RATIO', 16)
|
95 |
+
self.replace_style = cfg.get('REPLACE_STYLE', False)
|
96 |
+
self.trigger_words = cfg.get('TRIGGER_WORDS', '')
|
97 |
+
self.replace_keywords = cfg.get('HIGHLIGHT_KEYWORDS', '')
|
98 |
+
self.keywords_sign = cfg.get('KEYWORDS_SIGN', '')
|
99 |
+
self.add_indicator = cfg.get('ADD_INDICATOR', False)
|
100 |
+
# Use modelscope dataset
|
101 |
+
if not ms_dataset_name:
|
102 |
+
raise ValueError(
|
103 |
+
'Your must set MS_DATASET_NAME as modelscope dataset or your local dataset orignized '
|
104 |
+
'as modelscope dataset.')
|
105 |
+
if FS.exists(ms_dataset_name):
|
106 |
+
ms_dataset_name = FS.get_dir_to_local_dir(ms_dataset_name)
|
107 |
+
self.ms_dataset_name = ms_dataset_name
|
108 |
+
# ms_remap_path = ms_dataset_name
|
109 |
+
try:
|
110 |
+
self.data = MsDataset.load(str(ms_dataset_name),
|
111 |
+
namespace=ms_dataset_namespace,
|
112 |
+
subset_name=ms_dataset_subname,
|
113 |
+
split=ms_dataset_split)
|
114 |
+
except Exception:
|
115 |
+
self.logger.info(
|
116 |
+
"Load Modelscope dataset failed, retry with download_mode='force_redownload'."
|
117 |
+
)
|
118 |
+
try:
|
119 |
+
self.data = MsDataset.load(
|
120 |
+
str(ms_dataset_name),
|
121 |
+
namespace=ms_dataset_namespace,
|
122 |
+
subset_name=ms_dataset_subname,
|
123 |
+
split=ms_dataset_split,
|
124 |
+
download_mode=DownloadMode.FORCE_REDOWNLOAD)
|
125 |
+
except Exception as sec_e:
|
126 |
+
raise ValueError(f'Load Modelscope dataset failed {sec_e}.')
|
127 |
+
if ms_remap_keys:
|
128 |
+
self.data = self.data.remap_columns(ms_remap_keys.get_dict())
|
129 |
+
|
130 |
+
if ms_remap_path:
|
131 |
+
|
132 |
+
def map_func(example):
|
133 |
+
return {
|
134 |
+
k: os.path.join(ms_remap_path, v)
|
135 |
+
if k.endswith(':FILE') else v
|
136 |
+
for k, v in example.items()
|
137 |
+
}
|
138 |
+
|
139 |
+
self.data = self.data.ds_instance.map(map_func)
|
140 |
+
|
141 |
+
self.transforms = T.Compose([
|
142 |
+
T.ToTensor(),
|
143 |
+
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
144 |
+
])
|
145 |
+
|
146 |
+
def __len__(self):
|
147 |
+
if self.mode == 'train':
|
148 |
+
return sys.maxsize
|
149 |
+
else:
|
150 |
+
return len(self.data)
|
151 |
+
|
152 |
+
def _get(self, index: int):
|
153 |
+
current_data = self.data[index % len(self.data)]
|
154 |
+
|
155 |
+
tar_image_path = current_data.get('Target:FILE', '')
|
156 |
+
src_image_path = current_data.get('Source:FILE', '')
|
157 |
+
|
158 |
+
style = current_data.get('Style', '')
|
159 |
+
prompt = current_data.get('Prompt', current_data.get('prompt', ''))
|
160 |
+
if self.replace_style and not style == '':
|
161 |
+
prompt = prompt.replace(style, f'<{self.keywords_sign}>')
|
162 |
+
|
163 |
+
elif not self.replace_keywords.strip() == '':
|
164 |
+
prompt = prompt.replace(
|
165 |
+
self.replace_keywords,
|
166 |
+
'<' + self.replace_keywords + f'{self.keywords_sign}>')
|
167 |
+
|
168 |
+
if not self.trigger_words == '':
|
169 |
+
prompt = self.trigger_words.strip() + ' ' + prompt
|
170 |
+
|
171 |
+
src_image = self.load_image(self.ms_dataset_name,
|
172 |
+
src_image_path,
|
173 |
+
cvt_type='RGB')
|
174 |
+
tar_image = self.load_image(self.ms_dataset_name,
|
175 |
+
tar_image_path,
|
176 |
+
cvt_type='RGB')
|
177 |
+
src_image = self.image_preprocess(src_image)
|
178 |
+
tar_image = self.image_preprocess(tar_image)
|
179 |
+
|
180 |
+
tar_image = self.transforms(tar_image)
|
181 |
+
src_image = self.transforms(src_image)
|
182 |
+
src_mask = torch.ones_like(src_image[[0]])
|
183 |
+
tar_mask = torch.ones_like(tar_image[[0]])
|
184 |
+
if self.add_indicator:
|
185 |
+
if '{image}' not in prompt:
|
186 |
+
prompt = '{image}, ' + prompt
|
187 |
+
|
188 |
+
return {
|
189 |
+
'edit_image': [src_image],
|
190 |
+
'edit_image_mask': [src_mask],
|
191 |
+
'image': tar_image,
|
192 |
+
'image_mask': tar_mask,
|
193 |
+
'prompt': [prompt],
|
194 |
+
}
|
195 |
+
|
196 |
+
def load_image(self, prefix, img_path, cvt_type=None):
|
197 |
+
if img_path is None or img_path == '':
|
198 |
+
return None
|
199 |
+
img_path = os.path.join(prefix, img_path)
|
200 |
+
with FS.get_object(img_path) as image_bytes:
|
201 |
+
image = Image.open(io.BytesIO(image_bytes))
|
202 |
+
if cvt_type is not None:
|
203 |
+
image = pillow_convert(image, cvt_type)
|
204 |
+
return image
|
205 |
+
|
206 |
+
def image_preprocess(self,
|
207 |
+
img,
|
208 |
+
size=None,
|
209 |
+
interpolation=InterpolationMode.BILINEAR):
|
210 |
+
H, W = img.height, img.width
|
211 |
+
if H / W > self.max_aspect_ratio:
|
212 |
+
img = T.CenterCrop((self.max_aspect_ratio * W, W))(img)
|
213 |
+
elif W / H > self.max_aspect_ratio:
|
214 |
+
img = T.CenterCrop((H, self.max_aspect_ratio * H))(img)
|
215 |
+
|
216 |
+
if size is None:
|
217 |
+
# resize image for max_seq_len, while keep the aspect ratio
|
218 |
+
H, W = img.height, img.width
|
219 |
+
scale = min(
|
220 |
+
1.0,
|
221 |
+
math.sqrt(self.max_seq_len / ((H / self.d) * (W / self.d))))
|
222 |
+
rH = int(
|
223 |
+
H * scale) // self.d * self.d # ensure divisible by self.d
|
224 |
+
rW = int(W * scale) // self.d * self.d
|
225 |
+
else:
|
226 |
+
rH, rW = size
|
227 |
+
img = T.Resize((rH, rW), interpolation=interpolation,
|
228 |
+
antialias=True)(img)
|
229 |
+
return np.array(img, dtype=np.uint8)
|
230 |
+
|
231 |
+
@staticmethod
|
232 |
+
def get_config_template():
|
233 |
+
return dict_to_yaml('DATASet',
|
234 |
+
__class__.__name__,
|
235 |
+
ACEDemoDataset.para_dict,
|
236 |
+
set_name=True)
|
237 |
+
|
238 |
+
@staticmethod
|
239 |
+
def collate_fn(batch):
|
240 |
+
collect = defaultdict(list)
|
241 |
+
for sample in batch:
|
242 |
+
for k, v in sample.items():
|
243 |
+
collect[k].append(v)
|
244 |
+
|
245 |
+
new_batch = dict()
|
246 |
+
for k, v in collect.items():
|
247 |
+
if all([i is None for i in v]):
|
248 |
+
new_batch[k] = None
|
249 |
+
else:
|
250 |
+
new_batch[k] = v
|
251 |
+
|
252 |
+
return new_batch
|
modules/model/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import backbone, embedder, diffusion, network
|
modules/model/backbone/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
from .ace import DiTACE
|
modules/model/backbone/ace.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import re
|
4 |
+
from collections import OrderedDict
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from einops import rearrange
|
10 |
+
from torch.nn.utils.rnn import pad_sequence
|
11 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
12 |
+
|
13 |
+
from scepter.modules.model.base_model import BaseModel
|
14 |
+
from scepter.modules.model.registry import BACKBONES
|
15 |
+
from scepter.modules.utils.config import dict_to_yaml
|
16 |
+
from scepter.modules.utils.file_system import FS
|
17 |
+
|
18 |
+
from .layers import (
|
19 |
+
Mlp,
|
20 |
+
TimestepEmbedder,
|
21 |
+
PatchEmbed,
|
22 |
+
DiTACEBlock,
|
23 |
+
T2IFinalLayer
|
24 |
+
)
|
25 |
+
from .pos_embed import rope_params
|
26 |
+
|
27 |
+
|
28 |
+
@BACKBONES.register_class()
|
29 |
+
class DiTACE(BaseModel):
|
30 |
+
|
31 |
+
para_dict = {
|
32 |
+
'PATCH_SIZE': {
|
33 |
+
'value': 2,
|
34 |
+
'description': ''
|
35 |
+
},
|
36 |
+
'IN_CHANNELS': {
|
37 |
+
'value': 4,
|
38 |
+
'description': ''
|
39 |
+
},
|
40 |
+
'HIDDEN_SIZE': {
|
41 |
+
'value': 1152,
|
42 |
+
'description': ''
|
43 |
+
},
|
44 |
+
'DEPTH': {
|
45 |
+
'value': 28,
|
46 |
+
'description': ''
|
47 |
+
},
|
48 |
+
'NUM_HEADS': {
|
49 |
+
'value': 16,
|
50 |
+
'description': ''
|
51 |
+
},
|
52 |
+
'MLP_RATIO': {
|
53 |
+
'value': 4.0,
|
54 |
+
'description': ''
|
55 |
+
},
|
56 |
+
'PRED_SIGMA': {
|
57 |
+
'value': True,
|
58 |
+
'description': ''
|
59 |
+
},
|
60 |
+
'DROP_PATH': {
|
61 |
+
'value': 0.,
|
62 |
+
'description': ''
|
63 |
+
},
|
64 |
+
'WINDOW_SIZE': {
|
65 |
+
'value': 0,
|
66 |
+
'description': ''
|
67 |
+
},
|
68 |
+
'WINDOW_BLOCK_INDEXES': {
|
69 |
+
'value': None,
|
70 |
+
'description': ''
|
71 |
+
},
|
72 |
+
'Y_CHANNELS': {
|
73 |
+
'value': 4096,
|
74 |
+
'description': ''
|
75 |
+
},
|
76 |
+
'ATTENTION_BACKEND': {
|
77 |
+
'value': None,
|
78 |
+
'description': ''
|
79 |
+
},
|
80 |
+
'QK_NORM': {
|
81 |
+
'value': True,
|
82 |
+
'description': 'Whether to use RMSNorm for query and key.',
|
83 |
+
},
|
84 |
+
}
|
85 |
+
para_dict.update(BaseModel.para_dict)
|
86 |
+
|
87 |
+
def __init__(self, cfg, logger):
|
88 |
+
super().__init__(cfg, logger=logger)
|
89 |
+
self.window_block_indexes = cfg.get('WINDOW_BLOCK_INDEXES', None)
|
90 |
+
if self.window_block_indexes is None:
|
91 |
+
self.window_block_indexes = []
|
92 |
+
self.pred_sigma = cfg.get('PRED_SIGMA', True)
|
93 |
+
self.in_channels = cfg.get('IN_CHANNELS', 4)
|
94 |
+
self.out_channels = self.in_channels * 2 if self.pred_sigma else self.in_channels
|
95 |
+
self.patch_size = cfg.get('PATCH_SIZE', 2)
|
96 |
+
self.num_heads = cfg.get('NUM_HEADS', 16)
|
97 |
+
self.hidden_size = cfg.get('HIDDEN_SIZE', 1152)
|
98 |
+
self.y_channels = cfg.get('Y_CHANNELS', 4096)
|
99 |
+
self.drop_path = cfg.get('DROP_PATH', 0.)
|
100 |
+
self.depth = cfg.get('DEPTH', 28)
|
101 |
+
self.mlp_ratio = cfg.get('MLP_RATIO', 4.0)
|
102 |
+
self.use_grad_checkpoint = cfg.get('USE_GRAD_CHECKPOINT', False)
|
103 |
+
self.attention_backend = cfg.get('ATTENTION_BACKEND', None)
|
104 |
+
self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
|
105 |
+
self.qk_norm = cfg.get('QK_NORM', False)
|
106 |
+
self.ignore_keys = cfg.get('IGNORE_KEYS', [])
|
107 |
+
assert (self.hidden_size % self.num_heads
|
108 |
+
) == 0 and (self.hidden_size // self.num_heads) % 2 == 0
|
109 |
+
d = self.hidden_size // self.num_heads
|
110 |
+
self.freqs = torch.cat(
|
111 |
+
[
|
112 |
+
rope_params(self.max_seq_len, d - 4 * (d // 6)), # T (~1/3)
|
113 |
+
rope_params(self.max_seq_len, 2 * (d // 6)), # H (~1/3)
|
114 |
+
rope_params(self.max_seq_len, 2 * (d // 6)) # W (~1/3)
|
115 |
+
],
|
116 |
+
dim=1)
|
117 |
+
|
118 |
+
# init embedder
|
119 |
+
self.x_embedder = PatchEmbed(self.patch_size,
|
120 |
+
self.in_channels + 1,
|
121 |
+
self.hidden_size,
|
122 |
+
bias=True,
|
123 |
+
flatten=False)
|
124 |
+
self.t_embedder = TimestepEmbedder(self.hidden_size)
|
125 |
+
self.y_embedder = Mlp(in_features=self.y_channels,
|
126 |
+
hidden_features=self.hidden_size,
|
127 |
+
out_features=self.hidden_size,
|
128 |
+
act_layer=lambda: nn.GELU(approximate='tanh'),
|
129 |
+
drop=0)
|
130 |
+
self.t_block = nn.Sequential(
|
131 |
+
nn.SiLU(),
|
132 |
+
nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True))
|
133 |
+
# init blocks
|
134 |
+
drop_path = [
|
135 |
+
x.item() for x in torch.linspace(0, self.drop_path, self.depth)
|
136 |
+
]
|
137 |
+
self.blocks = nn.ModuleList([
|
138 |
+
DiTACEBlock(self.hidden_size,
|
139 |
+
self.num_heads,
|
140 |
+
mlp_ratio=self.mlp_ratio,
|
141 |
+
drop_path=drop_path[i],
|
142 |
+
window_size=self.window_size
|
143 |
+
if i in self.window_block_indexes else 0,
|
144 |
+
backend=self.attention_backend,
|
145 |
+
use_condition=True,
|
146 |
+
qk_norm=self.qk_norm) for i in range(self.depth)
|
147 |
+
])
|
148 |
+
self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size,
|
149 |
+
self.out_channels)
|
150 |
+
self.initialize_weights()
|
151 |
+
|
152 |
+
def load_pretrained_model(self, pretrained_model):
|
153 |
+
if pretrained_model:
|
154 |
+
with FS.get_from(pretrained_model, wait_finish=True) as local_path:
|
155 |
+
model = torch.load(local_path, map_location='cpu')
|
156 |
+
if 'state_dict' in model:
|
157 |
+
model = model['state_dict']
|
158 |
+
new_ckpt = OrderedDict()
|
159 |
+
for k, v in model.items():
|
160 |
+
if self.ignore_keys is not None:
|
161 |
+
if (isinstance(self.ignore_keys, str) and re.match(self.ignore_keys, k)) or \
|
162 |
+
(isinstance(self.ignore_keys, list) and k in self.ignore_keys):
|
163 |
+
continue
|
164 |
+
k = k.replace('.cross_attn.q_linear.', '.cross_attn.q.')
|
165 |
+
k = k.replace('.cross_attn.proj.',
|
166 |
+
'.cross_attn.o.').replace(
|
167 |
+
'.attn.proj.', '.attn.o.')
|
168 |
+
if '.cross_attn.kv_linear.' in k:
|
169 |
+
k_p, v_p = torch.split(v, v.shape[0] // 2)
|
170 |
+
new_ckpt[k.replace('.cross_attn.kv_linear.',
|
171 |
+
'.cross_attn.k.')] = k_p
|
172 |
+
new_ckpt[k.replace('.cross_attn.kv_linear.',
|
173 |
+
'.cross_attn.v.')] = v_p
|
174 |
+
elif '.attn.qkv.' in k:
|
175 |
+
q_p, k_p, v_p = torch.split(v, v.shape[0] // 3)
|
176 |
+
new_ckpt[k.replace('.attn.qkv.', '.attn.q.')] = q_p
|
177 |
+
new_ckpt[k.replace('.attn.qkv.', '.attn.k.')] = k_p
|
178 |
+
new_ckpt[k.replace('.attn.qkv.', '.attn.v.')] = v_p
|
179 |
+
elif 'y_embedder.y_proj.' in k:
|
180 |
+
new_ckpt[k.replace('y_embedder.y_proj.',
|
181 |
+
'y_embedder.')] = v
|
182 |
+
elif k in ('x_embedder.proj.weight'):
|
183 |
+
model_p = self.state_dict()[k]
|
184 |
+
if v.shape != model_p.shape:
|
185 |
+
model_p.zero_()
|
186 |
+
model_p[:, :4, :, :].copy_(v)
|
187 |
+
new_ckpt[k] = torch.nn.parameter.Parameter(model_p)
|
188 |
+
else:
|
189 |
+
new_ckpt[k] = v
|
190 |
+
elif k in ('x_embedder.proj.bias'):
|
191 |
+
new_ckpt[k] = v
|
192 |
+
else:
|
193 |
+
new_ckpt[k] = v
|
194 |
+
missing, unexpected = self.load_state_dict(new_ckpt,
|
195 |
+
strict=False)
|
196 |
+
print(
|
197 |
+
f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
|
198 |
+
)
|
199 |
+
if len(missing) > 0:
|
200 |
+
print(f'Missing Keys:\n {missing}')
|
201 |
+
if len(unexpected) > 0:
|
202 |
+
print(f'\nUnexpected Keys:\n {unexpected}')
|
203 |
+
|
204 |
+
def forward(self,
|
205 |
+
x,
|
206 |
+
t=None,
|
207 |
+
cond=dict(),
|
208 |
+
mask=None,
|
209 |
+
text_position_embeddings=None,
|
210 |
+
gc_seg=-1,
|
211 |
+
**kwargs):
|
212 |
+
if self.freqs.device != x.device:
|
213 |
+
self.freqs = self.freqs.to(x.device)
|
214 |
+
if isinstance(cond, dict):
|
215 |
+
context = cond.get('crossattn', None)
|
216 |
+
else:
|
217 |
+
context = cond
|
218 |
+
if text_position_embeddings is not None:
|
219 |
+
# default use the text_position_embeddings in state_dict
|
220 |
+
# if state_dict doesn't including this key, use the arg: text_position_embeddings
|
221 |
+
proj_position_embeddings = self.y_embedder(
|
222 |
+
text_position_embeddings)
|
223 |
+
else:
|
224 |
+
proj_position_embeddings = None
|
225 |
+
|
226 |
+
ctx_batch, txt_lens = [], []
|
227 |
+
if mask is not None and isinstance(mask, list):
|
228 |
+
for ctx, ctx_mask in zip(context, mask):
|
229 |
+
for frame_id, one_ctx in enumerate(zip(ctx, ctx_mask)):
|
230 |
+
u, m = one_ctx
|
231 |
+
t_len = m.flatten().sum() # l
|
232 |
+
u = u[:t_len]
|
233 |
+
u = self.y_embedder(u)
|
234 |
+
if frame_id == 0:
|
235 |
+
u = u + proj_position_embeddings[
|
236 |
+
len(ctx) -
|
237 |
+
1] if proj_position_embeddings is not None else u
|
238 |
+
else:
|
239 |
+
u = u + proj_position_embeddings[
|
240 |
+
frame_id -
|
241 |
+
1] if proj_position_embeddings is not None else u
|
242 |
+
ctx_batch.append(u)
|
243 |
+
txt_lens.append(t_len)
|
244 |
+
else:
|
245 |
+
raise TypeError
|
246 |
+
y = torch.cat(ctx_batch, dim=0)
|
247 |
+
txt_lens = torch.LongTensor(txt_lens).to(x.device, non_blocking=True)
|
248 |
+
|
249 |
+
batch_frames = []
|
250 |
+
for u, shape, m in zip(x, cond['x_shapes'], cond['x_mask']):
|
251 |
+
u = u[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
|
252 |
+
m = torch.ones_like(u[[0], :, :]) if m is None else m.squeeze(0)
|
253 |
+
batch_frames.append([torch.cat([u, m], dim=0).unsqueeze(0)])
|
254 |
+
if 'edit' in cond:
|
255 |
+
for i, (edit, edit_mask) in enumerate(
|
256 |
+
zip(cond['edit'], cond['edit_mask'])):
|
257 |
+
if edit is None:
|
258 |
+
continue
|
259 |
+
for u, m in zip(edit, edit_mask):
|
260 |
+
u = u.squeeze(0)
|
261 |
+
m = torch.ones_like(
|
262 |
+
u[[0], :, :]) if m is None else m.squeeze(0)
|
263 |
+
batch_frames[i].append(
|
264 |
+
torch.cat([u, m], dim=0).unsqueeze(0))
|
265 |
+
|
266 |
+
patch_batch, shape_batch, self_x_len, cross_x_len = [], [], [], []
|
267 |
+
for frames in batch_frames:
|
268 |
+
patches, patch_shapes = [], []
|
269 |
+
self_x_len.append(0)
|
270 |
+
for frame_id, u in enumerate(frames):
|
271 |
+
u = self.x_embedder(u)
|
272 |
+
h, w = u.size(2), u.size(3)
|
273 |
+
u = rearrange(u, '1 c h w -> (h w) c')
|
274 |
+
if frame_id == 0:
|
275 |
+
u = u + proj_position_embeddings[
|
276 |
+
len(frames) -
|
277 |
+
1] if proj_position_embeddings is not None else u
|
278 |
+
else:
|
279 |
+
u = u + proj_position_embeddings[
|
280 |
+
frame_id -
|
281 |
+
1] if proj_position_embeddings is not None else u
|
282 |
+
patches.append(u)
|
283 |
+
patch_shapes.append([h, w])
|
284 |
+
cross_x_len.append(h * w) # b*s, 1
|
285 |
+
self_x_len[-1] += h * w # b, 1
|
286 |
+
# u = torch.cat(patches, dim=0)
|
287 |
+
patch_batch.extend(patches)
|
288 |
+
shape_batch.append(
|
289 |
+
torch.LongTensor(patch_shapes).to(x.device, non_blocking=True))
|
290 |
+
# repeat t to align with x
|
291 |
+
t = torch.cat([t[i].repeat(l) for i, l in enumerate(self_x_len)])
|
292 |
+
self_x_len, cross_x_len = (torch.LongTensor(self_x_len).to(
|
293 |
+
x.device, non_blocking=True), torch.LongTensor(cross_x_len).to(
|
294 |
+
x.device, non_blocking=True))
|
295 |
+
# x = pad_sequence(tuple(patch_batch), batch_first=True) # b, s*max(cl), c
|
296 |
+
x = torch.cat(patch_batch, dim=0)
|
297 |
+
x_shapes = pad_sequence(tuple(shape_batch),
|
298 |
+
batch_first=True) # b, max(len(frames)), 2
|
299 |
+
t = self.t_embedder(t) # (N, D)
|
300 |
+
t0 = self.t_block(t)
|
301 |
+
# y = self.y_embedder(context)
|
302 |
+
|
303 |
+
kwargs = dict(y=y,
|
304 |
+
t=t0,
|
305 |
+
x_shapes=x_shapes,
|
306 |
+
self_x_len=self_x_len,
|
307 |
+
cross_x_len=cross_x_len,
|
308 |
+
freqs=self.freqs,
|
309 |
+
txt_lens=txt_lens)
|
310 |
+
if self.use_grad_checkpoint and gc_seg >= 0:
|
311 |
+
x = checkpoint_sequential(
|
312 |
+
functions=[partial(block, **kwargs) for block in self.blocks],
|
313 |
+
segments=gc_seg if gc_seg > 0 else len(self.blocks),
|
314 |
+
input=x,
|
315 |
+
use_reentrant=False)
|
316 |
+
else:
|
317 |
+
for block in self.blocks:
|
318 |
+
x = block(x, **kwargs)
|
319 |
+
x = self.final_layer(x, t) # b*s*n, d
|
320 |
+
outs, cur_length = [], 0
|
321 |
+
p = self.patch_size
|
322 |
+
for seq_length, shape in zip(self_x_len, shape_batch):
|
323 |
+
x_i = x[cur_length:cur_length + seq_length]
|
324 |
+
h, w = shape[0].tolist()
|
325 |
+
u = x_i[:h * w].view(h, w, p, p, -1)
|
326 |
+
u = rearrange(u, 'h w p q c -> (h p w q) c'
|
327 |
+
) # dump into sequence for following tensor ops
|
328 |
+
cur_length = cur_length + seq_length
|
329 |
+
outs.append(u)
|
330 |
+
x = pad_sequence(tuple(outs), batch_first=True).permute(0, 2, 1)
|
331 |
+
if self.pred_sigma:
|
332 |
+
return x.chunk(2, dim=1)[0]
|
333 |
+
else:
|
334 |
+
return x
|
335 |
+
|
336 |
+
def initialize_weights(self):
|
337 |
+
# Initialize transformer layers:
|
338 |
+
def _basic_init(module):
|
339 |
+
if isinstance(module, nn.Linear):
|
340 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
341 |
+
if module.bias is not None:
|
342 |
+
nn.init.constant_(module.bias, 0)
|
343 |
+
|
344 |
+
self.apply(_basic_init)
|
345 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
346 |
+
w = self.x_embedder.proj.weight.data
|
347 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
348 |
+
# Initialize timestep embedding MLP:
|
349 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
350 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
351 |
+
nn.init.normal_(self.t_block[1].weight, std=0.02)
|
352 |
+
# Initialize caption embedding MLP:
|
353 |
+
if hasattr(self, 'y_embedder'):
|
354 |
+
nn.init.normal_(self.y_embedder.fc1.weight, std=0.02)
|
355 |
+
nn.init.normal_(self.y_embedder.fc2.weight, std=0.02)
|
356 |
+
# Zero-out adaLN modulation layers
|
357 |
+
for block in self.blocks:
|
358 |
+
nn.init.constant_(block.cross_attn.o.weight, 0)
|
359 |
+
nn.init.constant_(block.cross_attn.o.bias, 0)
|
360 |
+
# Zero-out output layers:
|
361 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
362 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
363 |
+
|
364 |
+
@property
|
365 |
+
def dtype(self):
|
366 |
+
return next(self.parameters()).dtype
|
367 |
+
|
368 |
+
@staticmethod
|
369 |
+
def get_config_template():
|
370 |
+
return dict_to_yaml('BACKBONE',
|
371 |
+
__class__.__name__,
|
372 |
+
DiTACE.para_dict,
|
373 |
+
set_name=True)
|
modules/model/backbone/layers.py
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import math
|
4 |
+
import warnings
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from .pos_embed import rope_apply_multires as rope_apply
|
8 |
+
|
9 |
+
try:
|
10 |
+
from flash_attn import (flash_attn_varlen_func)
|
11 |
+
FLASHATTN_IS_AVAILABLE = True
|
12 |
+
except ImportError as e:
|
13 |
+
FLASHATTN_IS_AVAILABLE = False
|
14 |
+
flash_attn_varlen_func = None
|
15 |
+
warnings.warn(f'{e}')
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
"drop_path",
|
19 |
+
"modulate",
|
20 |
+
"PatchEmbed",
|
21 |
+
"DropPath",
|
22 |
+
"RMSNorm",
|
23 |
+
"Mlp",
|
24 |
+
"TimestepEmbedder",
|
25 |
+
"DiTEditBlock",
|
26 |
+
"MultiHeadAttentionDiTEdit",
|
27 |
+
"T2IFinalLayer",
|
28 |
+
]
|
29 |
+
|
30 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
31 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
32 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
33 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
34 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
35 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
36 |
+
'survival rate' as the argument.
|
37 |
+
"""
|
38 |
+
if drop_prob == 0. or not training:
|
39 |
+
return x
|
40 |
+
keep_prob = 1 - drop_prob
|
41 |
+
shape = (x.shape[0], ) + (1, ) * (
|
42 |
+
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
43 |
+
random_tensor = keep_prob + torch.rand(
|
44 |
+
shape, dtype=x.dtype, device=x.device)
|
45 |
+
random_tensor.floor_() # binarize
|
46 |
+
output = x.div(keep_prob) * random_tensor
|
47 |
+
return output
|
48 |
+
|
49 |
+
|
50 |
+
def modulate(x, shift, scale, unsqueeze=False):
|
51 |
+
if unsqueeze:
|
52 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
53 |
+
else:
|
54 |
+
return x * (1 + scale) + shift
|
55 |
+
|
56 |
+
|
57 |
+
class PatchEmbed(nn.Module):
|
58 |
+
""" 2D Image to Patch Embedding
|
59 |
+
"""
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
patch_size=16,
|
63 |
+
in_chans=3,
|
64 |
+
embed_dim=768,
|
65 |
+
norm_layer=None,
|
66 |
+
flatten=True,
|
67 |
+
bias=True,
|
68 |
+
):
|
69 |
+
super().__init__()
|
70 |
+
self.flatten = flatten
|
71 |
+
self.proj = nn.Conv2d(in_chans,
|
72 |
+
embed_dim,
|
73 |
+
kernel_size=patch_size,
|
74 |
+
stride=patch_size,
|
75 |
+
bias=bias)
|
76 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
x = self.proj(x)
|
80 |
+
if self.flatten:
|
81 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
82 |
+
x = self.norm(x)
|
83 |
+
return x
|
84 |
+
|
85 |
+
|
86 |
+
class DropPath(nn.Module):
|
87 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
88 |
+
"""
|
89 |
+
def __init__(self, drop_prob=None):
|
90 |
+
super(DropPath, self).__init__()
|
91 |
+
self.drop_prob = drop_prob
|
92 |
+
|
93 |
+
def forward(self, x):
|
94 |
+
return drop_path(x, self.drop_prob, self.training)
|
95 |
+
|
96 |
+
|
97 |
+
class RMSNorm(nn.Module):
|
98 |
+
def __init__(self, dim, eps=1e-6):
|
99 |
+
super().__init__()
|
100 |
+
self.dim = dim
|
101 |
+
self.eps = eps
|
102 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
return self._norm(x.float()).type_as(x) * self.weight
|
106 |
+
|
107 |
+
def _norm(self, x):
|
108 |
+
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
109 |
+
|
110 |
+
|
111 |
+
class Mlp(nn.Module):
|
112 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
113 |
+
"""
|
114 |
+
def __init__(self,
|
115 |
+
in_features,
|
116 |
+
hidden_features=None,
|
117 |
+
out_features=None,
|
118 |
+
act_layer=nn.GELU,
|
119 |
+
drop=0.):
|
120 |
+
super().__init__()
|
121 |
+
out_features = out_features or in_features
|
122 |
+
hidden_features = hidden_features or in_features
|
123 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
124 |
+
self.act = act_layer()
|
125 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
126 |
+
self.drop = nn.Dropout(drop)
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
x = self.fc1(x)
|
130 |
+
x = self.act(x)
|
131 |
+
x = self.drop(x)
|
132 |
+
x = self.fc2(x)
|
133 |
+
x = self.drop(x)
|
134 |
+
return x
|
135 |
+
|
136 |
+
|
137 |
+
class TimestepEmbedder(nn.Module):
|
138 |
+
"""
|
139 |
+
Embeds scalar timesteps into vector representations.
|
140 |
+
"""
|
141 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
142 |
+
super().__init__()
|
143 |
+
self.mlp = nn.Sequential(
|
144 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
145 |
+
nn.SiLU(),
|
146 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
147 |
+
)
|
148 |
+
self.frequency_embedding_size = frequency_embedding_size
|
149 |
+
|
150 |
+
@staticmethod
|
151 |
+
def timestep_embedding(t, dim, max_period=10000):
|
152 |
+
"""
|
153 |
+
Create sinusoidal timestep embeddings.
|
154 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
155 |
+
These may be fractional.
|
156 |
+
:param dim: the dimension of the output.
|
157 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
158 |
+
:return: an (N, D) Tensor of positional embeddings.
|
159 |
+
"""
|
160 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
161 |
+
half = dim // 2
|
162 |
+
freqs = torch.exp(
|
163 |
+
-math.log(max_period) *
|
164 |
+
torch.arange(start=0, end=half, dtype=torch.float32) /
|
165 |
+
half).to(device=t.device)
|
166 |
+
args = t[:, None].float() * freqs[None]
|
167 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
168 |
+
if dim % 2:
|
169 |
+
embedding = torch.cat(
|
170 |
+
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
171 |
+
return embedding
|
172 |
+
|
173 |
+
def forward(self, t):
|
174 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
175 |
+
t_emb = self.mlp(t_freq)
|
176 |
+
return t_emb
|
177 |
+
|
178 |
+
|
179 |
+
class DiTACEBlock(nn.Module):
|
180 |
+
def __init__(self,
|
181 |
+
hidden_size,
|
182 |
+
num_heads,
|
183 |
+
mlp_ratio=4.0,
|
184 |
+
drop_path=0.,
|
185 |
+
window_size=0,
|
186 |
+
backend=None,
|
187 |
+
use_condition=True,
|
188 |
+
qk_norm=False,
|
189 |
+
**block_kwargs):
|
190 |
+
super().__init__()
|
191 |
+
self.hidden_size = hidden_size
|
192 |
+
self.use_condition = use_condition
|
193 |
+
self.norm1 = nn.LayerNorm(hidden_size,
|
194 |
+
elementwise_affine=False,
|
195 |
+
eps=1e-6)
|
196 |
+
self.attn = MultiHeadAttention(hidden_size,
|
197 |
+
num_heads=num_heads,
|
198 |
+
qkv_bias=True,
|
199 |
+
backend=backend,
|
200 |
+
qk_norm=qk_norm,
|
201 |
+
**block_kwargs)
|
202 |
+
if self.use_condition:
|
203 |
+
self.cross_attn = MultiHeadAttention(
|
204 |
+
hidden_size,
|
205 |
+
context_dim=hidden_size,
|
206 |
+
num_heads=num_heads,
|
207 |
+
qkv_bias=True,
|
208 |
+
backend=backend,
|
209 |
+
qk_norm=qk_norm,
|
210 |
+
**block_kwargs)
|
211 |
+
self.norm2 = nn.LayerNorm(hidden_size,
|
212 |
+
elementwise_affine=False,
|
213 |
+
eps=1e-6)
|
214 |
+
# to be compatible with lower version pytorch
|
215 |
+
approx_gelu = lambda: nn.GELU(approximate='tanh')
|
216 |
+
self.mlp = Mlp(in_features=hidden_size,
|
217 |
+
hidden_features=int(hidden_size * mlp_ratio),
|
218 |
+
act_layer=approx_gelu,
|
219 |
+
drop=0)
|
220 |
+
self.drop_path = DropPath(
|
221 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
222 |
+
self.window_size = window_size
|
223 |
+
self.scale_shift_table = nn.Parameter(
|
224 |
+
torch.randn(6, hidden_size) / hidden_size**0.5)
|
225 |
+
|
226 |
+
def forward(self, x, y, t, **kwargs):
|
227 |
+
B = x.size(0)
|
228 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
229 |
+
self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
|
230 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
231 |
+
shift_msa.squeeze(1), scale_msa.squeeze(1), gate_msa.squeeze(1),
|
232 |
+
shift_mlp.squeeze(1), scale_mlp.squeeze(1), gate_mlp.squeeze(1))
|
233 |
+
x = x + self.drop_path(gate_msa * self.attn(
|
234 |
+
modulate(self.norm1(x), shift_msa, scale_msa, unsqueeze=False), **
|
235 |
+
kwargs))
|
236 |
+
if self.use_condition:
|
237 |
+
x = x + self.cross_attn(x, context=y, **kwargs)
|
238 |
+
|
239 |
+
x = x + self.drop_path(gate_mlp * self.mlp(
|
240 |
+
modulate(self.norm2(x), shift_mlp, scale_mlp, unsqueeze=False)))
|
241 |
+
return x
|
242 |
+
|
243 |
+
|
244 |
+
class MultiHeadAttention(nn.Module):
|
245 |
+
def __init__(self,
|
246 |
+
dim,
|
247 |
+
context_dim=None,
|
248 |
+
num_heads=None,
|
249 |
+
head_dim=None,
|
250 |
+
attn_drop=0.0,
|
251 |
+
qkv_bias=False,
|
252 |
+
dropout=0.0,
|
253 |
+
backend=None,
|
254 |
+
qk_norm=False,
|
255 |
+
eps=1e-6,
|
256 |
+
**block_kwargs):
|
257 |
+
super().__init__()
|
258 |
+
# consider head_dim first, then num_heads
|
259 |
+
num_heads = dim // head_dim if head_dim else num_heads
|
260 |
+
head_dim = dim // num_heads
|
261 |
+
assert num_heads * head_dim == dim
|
262 |
+
context_dim = context_dim or dim
|
263 |
+
self.dim = dim
|
264 |
+
self.context_dim = context_dim
|
265 |
+
self.num_heads = num_heads
|
266 |
+
self.head_dim = head_dim
|
267 |
+
self.scale = math.pow(head_dim, -0.25)
|
268 |
+
# layers
|
269 |
+
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
270 |
+
self.k = nn.Linear(context_dim, dim, bias=qkv_bias)
|
271 |
+
self.v = nn.Linear(context_dim, dim, bias=qkv_bias)
|
272 |
+
self.o = nn.Linear(dim, dim)
|
273 |
+
self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
274 |
+
self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
275 |
+
|
276 |
+
self.dropout = nn.Dropout(dropout)
|
277 |
+
self.attention_op = None
|
278 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
279 |
+
self.backend = backend
|
280 |
+
assert self.backend in ('flash_attn', 'xformer_attn', 'pytorch_attn',
|
281 |
+
None)
|
282 |
+
if FLASHATTN_IS_AVAILABLE and self.backend in ('flash_attn', None):
|
283 |
+
self.backend = 'flash_attn'
|
284 |
+
self.softmax_scale = block_kwargs.get('softmax_scale', None)
|
285 |
+
self.causal = block_kwargs.get('causal', False)
|
286 |
+
self.window_size = block_kwargs.get('window_size', (-1, -1))
|
287 |
+
self.deterministic = block_kwargs.get('deterministic', False)
|
288 |
+
else:
|
289 |
+
raise NotImplementedError
|
290 |
+
|
291 |
+
def flash_attn(self, x, context=None, **kwargs):
|
292 |
+
'''
|
293 |
+
The implementation will be very slow when mask is not None,
|
294 |
+
because we need rearange the x/context features according to mask.
|
295 |
+
Args:
|
296 |
+
x:
|
297 |
+
context:
|
298 |
+
mask:
|
299 |
+
**kwargs:
|
300 |
+
Returns: x
|
301 |
+
'''
|
302 |
+
dtype = kwargs.get('dtype', torch.float16)
|
303 |
+
|
304 |
+
def half(x):
|
305 |
+
return x if x.dtype in [torch.float16, torch.bfloat16
|
306 |
+
] else x.to(dtype)
|
307 |
+
|
308 |
+
x_shapes = kwargs['x_shapes']
|
309 |
+
freqs = kwargs['freqs']
|
310 |
+
self_x_len = kwargs['self_x_len']
|
311 |
+
cross_x_len = kwargs['cross_x_len']
|
312 |
+
txt_lens = kwargs['txt_lens']
|
313 |
+
n, d = self.num_heads, self.head_dim
|
314 |
+
|
315 |
+
if context is None:
|
316 |
+
# self-attn
|
317 |
+
q = self.norm_q(self.q(x)).view(-1, n, d)
|
318 |
+
k = self.norm_q(self.k(x)).view(-1, n, d)
|
319 |
+
v = self.v(x).view(-1, n, d)
|
320 |
+
q = rope_apply(q, self_x_len, x_shapes, freqs, pad=False)
|
321 |
+
k = rope_apply(k, self_x_len, x_shapes, freqs, pad=False)
|
322 |
+
q_lens = k_lens = self_x_len
|
323 |
+
else:
|
324 |
+
# cross-attn
|
325 |
+
q = self.norm_q(self.q(x)).view(-1, n, d)
|
326 |
+
k = self.norm_q(self.k(context)).view(-1, n, d)
|
327 |
+
v = self.v(context).view(-1, n, d)
|
328 |
+
q_lens = cross_x_len
|
329 |
+
k_lens = txt_lens
|
330 |
+
|
331 |
+
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]),
|
332 |
+
q_lens]).cumsum(0, dtype=torch.int32)
|
333 |
+
cu_seqlens_k = torch.cat([k_lens.new_zeros([1]),
|
334 |
+
k_lens]).cumsum(0, dtype=torch.int32)
|
335 |
+
max_seqlen_q = q_lens.max()
|
336 |
+
max_seqlen_k = k_lens.max()
|
337 |
+
|
338 |
+
out_dtype = q.dtype
|
339 |
+
q, k, v = half(q), half(k), half(v)
|
340 |
+
x = flash_attn_varlen_func(q,
|
341 |
+
k,
|
342 |
+
v,
|
343 |
+
cu_seqlens_q=cu_seqlens_q,
|
344 |
+
cu_seqlens_k=cu_seqlens_k,
|
345 |
+
max_seqlen_q=max_seqlen_q,
|
346 |
+
max_seqlen_k=max_seqlen_k,
|
347 |
+
dropout_p=self.attn_drop.p,
|
348 |
+
softmax_scale=self.softmax_scale,
|
349 |
+
causal=self.causal,
|
350 |
+
window_size=self.window_size,
|
351 |
+
deterministic=self.deterministic)
|
352 |
+
|
353 |
+
x = x.type(out_dtype)
|
354 |
+
x = x.reshape(-1, n * d)
|
355 |
+
x = self.o(x)
|
356 |
+
x = self.dropout(x)
|
357 |
+
return x
|
358 |
+
|
359 |
+
def forward(self, x, context=None, **kwargs):
|
360 |
+
x = getattr(self, self.backend)(x, context=context, **kwargs)
|
361 |
+
return x
|
362 |
+
|
363 |
+
|
364 |
+
class T2IFinalLayer(nn.Module):
|
365 |
+
"""
|
366 |
+
The final layer of PixArt.
|
367 |
+
"""
|
368 |
+
def __init__(self, hidden_size, patch_size, out_channels):
|
369 |
+
super().__init__()
|
370 |
+
self.norm_final = nn.LayerNorm(hidden_size,
|
371 |
+
elementwise_affine=False,
|
372 |
+
eps=1e-6)
|
373 |
+
self.linear = nn.Linear(hidden_size,
|
374 |
+
patch_size * patch_size * out_channels,
|
375 |
+
bias=True)
|
376 |
+
self.scale_shift_table = nn.Parameter(
|
377 |
+
torch.randn(2, hidden_size) / hidden_size**0.5)
|
378 |
+
self.out_channels = out_channels
|
379 |
+
|
380 |
+
def forward(self, x, t):
|
381 |
+
shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2,
|
382 |
+
dim=1)
|
383 |
+
shift, scale = shift.squeeze(1), scale.squeeze(1)
|
384 |
+
x = modulate(self.norm_final(x), shift, scale)
|
385 |
+
x = self.linear(x)
|
386 |
+
return x
|
modules/model/backbone/pos_embed.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from einops import rearrange
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.cuda.amp as amp
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.nn.utils.rnn import pad_sequence
|
8 |
+
|
9 |
+
def frame_pad(x, seq_len, shapes):
|
10 |
+
max_h, max_w = np.max(shapes, 0)
|
11 |
+
frames = []
|
12 |
+
cur_len = 0
|
13 |
+
for h, w in shapes:
|
14 |
+
frame_len = h * w
|
15 |
+
frames.append(
|
16 |
+
F.pad(
|
17 |
+
x[cur_len:cur_len + frame_len].view(h, w, -1),
|
18 |
+
(0, 0, 0, max_w - w, 0, max_h - h)) # .view(max_h * max_w, -1)
|
19 |
+
)
|
20 |
+
cur_len += frame_len
|
21 |
+
if cur_len >= seq_len:
|
22 |
+
break
|
23 |
+
return torch.stack(frames)
|
24 |
+
|
25 |
+
|
26 |
+
def frame_unpad(x, shapes):
|
27 |
+
max_h, max_w = np.max(shapes, 0)
|
28 |
+
x = rearrange(x, '(b h w) n c -> b h w n c', h=max_h, w=max_w)
|
29 |
+
frames = []
|
30 |
+
for i, (h, w) in enumerate(shapes):
|
31 |
+
if i >= len(x):
|
32 |
+
break
|
33 |
+
frames.append(rearrange(x[i, :h, :w], 'h w n c -> (h w) n c'))
|
34 |
+
return torch.concat(frames)
|
35 |
+
|
36 |
+
|
37 |
+
@amp.autocast(enabled=False)
|
38 |
+
def rope_apply_multires(x, x_lens, x_shapes, freqs, pad=True):
|
39 |
+
"""
|
40 |
+
x: [B*L, N, C].
|
41 |
+
x_lens: [B].
|
42 |
+
x_shapes: [B, F, 2].
|
43 |
+
freqs: [M, C // 2].
|
44 |
+
"""
|
45 |
+
n, c = x.size(1), x.size(2) // 2
|
46 |
+
# split freqs
|
47 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
48 |
+
# loop over samples
|
49 |
+
output = []
|
50 |
+
st = 0
|
51 |
+
for i, (seq_len,
|
52 |
+
shapes) in enumerate(zip(x_lens.tolist(), x_shapes.tolist())):
|
53 |
+
x_i = frame_pad(x[st:st + seq_len], seq_len, shapes) # f, h, w, c
|
54 |
+
f, h, w = x_i.shape[:3]
|
55 |
+
pad_seq_len = f * h * w
|
56 |
+
# precompute multipliers
|
57 |
+
x_i = torch.view_as_complex(
|
58 |
+
x_i.to(torch.float64).reshape(pad_seq_len, n, -1, 2))
|
59 |
+
freqs_i = torch.cat([
|
60 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
61 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
62 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
63 |
+
],
|
64 |
+
dim=-1).reshape(pad_seq_len, 1, -1)
|
65 |
+
# apply rotary embedding
|
66 |
+
x_i = torch.view_as_real(x_i * freqs_i).flatten(2).type_as(x)
|
67 |
+
x_i = frame_unpad(x_i, shapes)
|
68 |
+
# append to collection
|
69 |
+
output.append(x_i)
|
70 |
+
st += seq_len
|
71 |
+
return pad_sequence(output) if pad else torch.concat(output)
|
72 |
+
|
73 |
+
|
74 |
+
@amp.autocast(enabled=False)
|
75 |
+
def rope_params(max_seq_len, dim, theta=10000):
|
76 |
+
"""
|
77 |
+
Precompute the frequency tensor for complex exponentials.
|
78 |
+
"""
|
79 |
+
assert dim % 2 == 0
|
80 |
+
freqs = torch.outer(
|
81 |
+
torch.arange(max_seq_len),
|
82 |
+
1.0 / torch.pow(theta,
|
83 |
+
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
|
84 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
85 |
+
return freqs
|
modules/model/diffusion/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
|
4 |
+
from .diffusions import ACEDiffusion
|
5 |
+
from .samplers import DDIMSampler
|
6 |
+
from .schedules import LinearScheduler
|
modules/model/diffusion/diffusions.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
from collections import OrderedDict
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from tqdm import trange
|
9 |
+
|
10 |
+
from scepter.modules.model.registry import (DIFFUSION_SAMPLERS, DIFFUSIONS,
|
11 |
+
NOISE_SCHEDULERS)
|
12 |
+
from scepter.modules.utils.config import Config, dict_to_yaml
|
13 |
+
from scepter.modules.utils.distribute import we
|
14 |
+
from scepter.modules.utils.file_system import FS
|
15 |
+
|
16 |
+
|
17 |
+
@DIFFUSIONS.register_class()
|
18 |
+
class ACEDiffusion(object):
|
19 |
+
para_dict = {
|
20 |
+
'NOISE_SCHEDULER': {},
|
21 |
+
'SAMPLER_SCHEDULER': {},
|
22 |
+
'MIN_SNR_GAMMA': {
|
23 |
+
'value': None,
|
24 |
+
'description': 'The minimum SNR gamma value for the loss function.'
|
25 |
+
},
|
26 |
+
'PREDICTION_TYPE': {
|
27 |
+
'value': 'eps',
|
28 |
+
'description':
|
29 |
+
'The type of prediction to use for the loss function.'
|
30 |
+
}
|
31 |
+
}
|
32 |
+
|
33 |
+
def __init__(self, cfg, logger=None):
|
34 |
+
super(ACEDiffusion, self).__init__()
|
35 |
+
self.logger = logger
|
36 |
+
self.cfg = cfg
|
37 |
+
self.init_params()
|
38 |
+
|
39 |
+
def init_params(self):
|
40 |
+
self.min_snr_gamma = self.cfg.get('MIN_SNR_GAMMA', None)
|
41 |
+
self.prediction_type = self.cfg.get('PREDICTION_TYPE', 'eps')
|
42 |
+
self.noise_scheduler = NOISE_SCHEDULERS.build(self.cfg.NOISE_SCHEDULER,
|
43 |
+
logger=self.logger)
|
44 |
+
self.sampler_scheduler = NOISE_SCHEDULERS.build(self.cfg.get(
|
45 |
+
'SAMPLER_SCHEDULER', self.cfg.NOISE_SCHEDULER),
|
46 |
+
logger=self.logger)
|
47 |
+
self.num_timesteps = self.noise_scheduler.num_timesteps
|
48 |
+
if self.cfg.have('WORK_DIR') and we.rank == 0:
|
49 |
+
schedule_visualization = os.path.join(self.cfg.WORK_DIR,
|
50 |
+
'noise_schedule.png')
|
51 |
+
with FS.put_to(schedule_visualization) as local_path:
|
52 |
+
self.noise_scheduler.plot_noise_sampling_map(local_path)
|
53 |
+
schedule_visualization = os.path.join(self.cfg.WORK_DIR,
|
54 |
+
'sampler_schedule.png')
|
55 |
+
with FS.put_to(schedule_visualization) as local_path:
|
56 |
+
self.sampler_scheduler.plot_noise_sampling_map(local_path)
|
57 |
+
|
58 |
+
def sample(self,
|
59 |
+
noise,
|
60 |
+
model,
|
61 |
+
model_kwargs={},
|
62 |
+
steps=20,
|
63 |
+
sampler=None,
|
64 |
+
use_dynamic_cfg=False,
|
65 |
+
guide_scale=None,
|
66 |
+
guide_rescale=None,
|
67 |
+
show_progress=False,
|
68 |
+
return_intermediate=None,
|
69 |
+
intermediate_callback=None,
|
70 |
+
**kwargs):
|
71 |
+
assert isinstance(steps, (int, torch.LongTensor))
|
72 |
+
assert return_intermediate in (None, 'x0', 'xt')
|
73 |
+
assert isinstance(sampler, (str, dict, Config))
|
74 |
+
intermediates = []
|
75 |
+
|
76 |
+
def callback_fn(x_t, t, sigma=None, alpha=None):
|
77 |
+
timestamp = t
|
78 |
+
t = t.repeat(len(x_t)).round().long().to(x_t.device)
|
79 |
+
sigma = sigma.repeat(len(x_t), *([1] * (len(sigma.shape) - 1)))
|
80 |
+
alpha = alpha.repeat(len(x_t), *([1] * (len(alpha.shape) - 1)))
|
81 |
+
|
82 |
+
if guide_scale is None or guide_scale == 1.0:
|
83 |
+
out = model(x=x_t, t=t, **model_kwargs)
|
84 |
+
else:
|
85 |
+
if use_dynamic_cfg:
|
86 |
+
guidance_scale = 1 + guide_scale * (
|
87 |
+
(1 - math.cos(math.pi * (
|
88 |
+
(steps - timestamp.item()) / steps)**5.0)) / 2)
|
89 |
+
else:
|
90 |
+
guidance_scale = guide_scale
|
91 |
+
y_out = model(x=x_t, t=t, **model_kwargs[0])
|
92 |
+
u_out = model(x=x_t, t=t, **model_kwargs[1])
|
93 |
+
out = u_out + guidance_scale * (y_out - u_out)
|
94 |
+
if guide_rescale is not None and guide_rescale > 0.0:
|
95 |
+
ratio = (
|
96 |
+
y_out.flatten(1).std(dim=1) /
|
97 |
+
(out.flatten(1).std(dim=1) + 1e-12)).view((-1, ) + (1, ) *
|
98 |
+
(y_out.ndim - 1))
|
99 |
+
out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
|
100 |
+
|
101 |
+
if self.prediction_type == 'x0':
|
102 |
+
x0 = out
|
103 |
+
elif self.prediction_type == 'eps':
|
104 |
+
x0 = (x_t - sigma * out) / alpha
|
105 |
+
elif self.prediction_type == 'v':
|
106 |
+
x0 = alpha * x_t - sigma * out
|
107 |
+
else:
|
108 |
+
raise NotImplementedError(
|
109 |
+
f'prediction_type {self.prediction_type} not implemented')
|
110 |
+
|
111 |
+
return x0
|
112 |
+
|
113 |
+
sampler_ins = self.get_sampler(sampler)
|
114 |
+
|
115 |
+
# this is ignored for schnell
|
116 |
+
sampler_output = sampler_ins.preprare_sampler(
|
117 |
+
noise,
|
118 |
+
steps=steps,
|
119 |
+
prediction_type=self.prediction_type,
|
120 |
+
scheduler_ins=self.sampler_scheduler,
|
121 |
+
callback_fn=callback_fn)
|
122 |
+
|
123 |
+
for _ in trange(steps, disable=not show_progress):
|
124 |
+
trange.desc = sampler_output.msg
|
125 |
+
sampler_output = sampler_ins.step(sampler_output)
|
126 |
+
if return_intermediate == 'x_0':
|
127 |
+
intermediates.append(sampler_output.x_0)
|
128 |
+
elif return_intermediate == 'x_t':
|
129 |
+
intermediates.append(sampler_output.x_t)
|
130 |
+
if intermediate_callback is not None:
|
131 |
+
intermediate_callback(intermediates[-1])
|
132 |
+
return (sampler_output.x_0, intermediates
|
133 |
+
) if return_intermediate is not None else sampler_output.x_0
|
134 |
+
|
135 |
+
def loss(self,
|
136 |
+
x_0,
|
137 |
+
model,
|
138 |
+
model_kwargs={},
|
139 |
+
reduction='mean',
|
140 |
+
noise=None,
|
141 |
+
**kwargs):
|
142 |
+
# use noise scheduler to add noise
|
143 |
+
if noise is None:
|
144 |
+
noise = torch.randn_like(x_0)
|
145 |
+
schedule_output = self.noise_scheduler.add_noise(x_0, noise, **kwargs)
|
146 |
+
x_t, t, sigma, alpha = schedule_output.x_t, schedule_output.t, schedule_output.sigma, schedule_output.alpha
|
147 |
+
out = model(x=x_t, t=t, **model_kwargs)
|
148 |
+
|
149 |
+
# mse loss
|
150 |
+
target = {
|
151 |
+
'eps': noise,
|
152 |
+
'x0': x_0,
|
153 |
+
'v': alpha * noise - sigma * x_0
|
154 |
+
}[self.prediction_type]
|
155 |
+
|
156 |
+
loss = (out - target).pow(2)
|
157 |
+
if reduction == 'mean':
|
158 |
+
loss = loss.flatten(1).mean(dim=1)
|
159 |
+
|
160 |
+
if self.min_snr_gamma is not None:
|
161 |
+
alphas = self.noise_scheduler.alphas.to(x_0.device)[t]
|
162 |
+
sigmas = self.noise_scheduler.sigmas.pow(2).to(x_0.device)[t]
|
163 |
+
snrs = (alphas / sigmas).clamp(min=1e-20)
|
164 |
+
min_snrs = snrs.clamp(max=self.min_snr_gamma)
|
165 |
+
weights = min_snrs / snrs
|
166 |
+
else:
|
167 |
+
weights = 1
|
168 |
+
|
169 |
+
loss = loss * weights
|
170 |
+
return loss
|
171 |
+
|
172 |
+
def get_sampler(self, sampler):
|
173 |
+
if isinstance(sampler, str):
|
174 |
+
if sampler not in DIFFUSION_SAMPLERS.class_map:
|
175 |
+
if self.logger is not None:
|
176 |
+
self.logger.info(
|
177 |
+
f'{sampler} not in the defined samplers list {DIFFUSION_SAMPLERS.class_map.keys()}'
|
178 |
+
)
|
179 |
+
else:
|
180 |
+
print(
|
181 |
+
f'{sampler} not in the defined samplers list {DIFFUSION_SAMPLERS.class_map.keys()}'
|
182 |
+
)
|
183 |
+
return None
|
184 |
+
sampler_cfg = Config(cfg_dict={'NAME': sampler}, load=False)
|
185 |
+
sampler_ins = DIFFUSION_SAMPLERS.build(sampler_cfg,
|
186 |
+
logger=self.logger)
|
187 |
+
elif isinstance(sampler, (Config, dict, OrderedDict)):
|
188 |
+
if isinstance(sampler, (dict, OrderedDict)):
|
189 |
+
sampler = Config(
|
190 |
+
cfg_dict={k.upper(): v
|
191 |
+
for k, v in dict(sampler).items()},
|
192 |
+
load=False)
|
193 |
+
sampler_ins = DIFFUSION_SAMPLERS.build(sampler, logger=self.logger)
|
194 |
+
else:
|
195 |
+
raise NotImplementedError
|
196 |
+
return sampler_ins
|
197 |
+
|
198 |
+
def __repr__(self) -> str:
|
199 |
+
return f'{self.__class__.__name__}' + ' ' + super().__repr__()
|
200 |
+
|
201 |
+
@staticmethod
|
202 |
+
def get_config_template():
|
203 |
+
return dict_to_yaml('DIFFUSIONS',
|
204 |
+
__class__.__name__,
|
205 |
+
ACEDiffusion.para_dict,
|
206 |
+
set_name=True)
|
modules/model/diffusion/samplers.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from scepter.modules.model.registry import DIFFUSION_SAMPLERS
|
6 |
+
from scepter.modules.model.diffusion.samplers import BaseDiffusionSampler
|
7 |
+
from scepter.modules.model.diffusion.util import _i
|
8 |
+
|
9 |
+
def _i(tensor, t, x):
|
10 |
+
"""
|
11 |
+
Index tensor using t and format the output according to x.
|
12 |
+
"""
|
13 |
+
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
|
14 |
+
if isinstance(t, torch.Tensor):
|
15 |
+
t = t.to(tensor.device)
|
16 |
+
return tensor[t].view(shape).to(x.device)
|
17 |
+
|
18 |
+
|
19 |
+
@DIFFUSION_SAMPLERS.register_class('ddim')
|
20 |
+
class DDIMSampler(BaseDiffusionSampler):
|
21 |
+
def init_params(self):
|
22 |
+
super().init_params()
|
23 |
+
self.eta = self.cfg.get('ETA', 0.)
|
24 |
+
self.discretization_type = self.cfg.get('DISCRETIZATION_TYPE',
|
25 |
+
'trailing')
|
26 |
+
|
27 |
+
def preprare_sampler(self,
|
28 |
+
noise,
|
29 |
+
steps=20,
|
30 |
+
scheduler_ins=None,
|
31 |
+
prediction_type='',
|
32 |
+
sigmas=None,
|
33 |
+
betas=None,
|
34 |
+
alphas=None,
|
35 |
+
callback_fn=None,
|
36 |
+
**kwargs):
|
37 |
+
output = super().preprare_sampler(noise, steps, scheduler_ins,
|
38 |
+
prediction_type, sigmas, betas,
|
39 |
+
alphas, callback_fn, **kwargs)
|
40 |
+
sigmas = output.sigmas
|
41 |
+
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
42 |
+
sigmas_vp = (sigmas**2 / (1 + sigmas**2))**0.5
|
43 |
+
sigmas_vp[sigmas == float('inf')] = 1.
|
44 |
+
output.add_custom_field('sigmas_vp', sigmas_vp)
|
45 |
+
return output
|
46 |
+
|
47 |
+
def step(self, sampler_output):
|
48 |
+
x_t = sampler_output.x_t
|
49 |
+
step = sampler_output.step
|
50 |
+
t = sampler_output.ts[step]
|
51 |
+
sigmas_vp = sampler_output.sigmas_vp.to(x_t.device)
|
52 |
+
alpha_init = _i(sampler_output.alphas_init, step, x_t[:1])
|
53 |
+
sigma_init = _i(sampler_output.sigmas_init, step, x_t[:1])
|
54 |
+
|
55 |
+
x = sampler_output.callback_fn(x_t, t, sigma_init, alpha_init)
|
56 |
+
noise_factor = self.eta * (sigmas_vp[step + 1]**2 /
|
57 |
+
sigmas_vp[step]**2 *
|
58 |
+
(1 - (1 - sigmas_vp[step]**2) /
|
59 |
+
(1 - sigmas_vp[step + 1]**2)))
|
60 |
+
d = (x_t - (1 - sigmas_vp[step]**2)**0.5 * x) / sigmas_vp[step]
|
61 |
+
x = (1 - sigmas_vp[step + 1] ** 2) ** 0.5 * x + \
|
62 |
+
(sigmas_vp[step + 1] ** 2 - noise_factor ** 2) ** 0.5 * d
|
63 |
+
sampler_output.x_0 = x
|
64 |
+
if sigmas_vp[step + 1] > 0:
|
65 |
+
x += noise_factor * torch.randn_like(x)
|
66 |
+
sampler_output.x_t = x
|
67 |
+
sampler_output.step += 1
|
68 |
+
sampler_output.msg = f'step {step}'
|
69 |
+
return sampler_output
|
modules/model/diffusion/schedules.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from scepter.modules.model.registry import NOISE_SCHEDULERS
|
6 |
+
from scepter.modules.model.diffusion.schedules import BaseNoiseScheduler
|
7 |
+
|
8 |
+
|
9 |
+
@NOISE_SCHEDULERS.register_class()
|
10 |
+
class LinearScheduler(BaseNoiseScheduler):
|
11 |
+
para_dict = {}
|
12 |
+
|
13 |
+
def init_params(self):
|
14 |
+
super().init_params()
|
15 |
+
self.beta_min = self.cfg.get('BETA_MIN', 0.00085)
|
16 |
+
self.beta_max = self.cfg.get('BETA_MAX', 0.012)
|
17 |
+
|
18 |
+
def betas_to_sigmas(self, betas):
|
19 |
+
return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
|
20 |
+
|
21 |
+
def get_schedule(self):
|
22 |
+
betas = torch.linspace(self.beta_min,
|
23 |
+
self.beta_max,
|
24 |
+
self.num_timesteps,
|
25 |
+
dtype=torch.float32)
|
26 |
+
sigmas = self.betas_to_sigmas(betas)
|
27 |
+
self._sigmas = sigmas
|
28 |
+
self._betas = betas
|
29 |
+
self._alphas = torch.sqrt(1 - sigmas**2)
|
30 |
+
self._timesteps = torch.arange(len(sigmas), dtype=torch.float32)
|
modules/model/embedder/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .embedder import ACETextEmbedder
|
modules/model/embedder/embedder.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import warnings
|
4 |
+
from contextlib import nullcontext
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torch.utils.dlpack
|
9 |
+
from scepter.modules.model.embedder.base_embedder import BaseEmbedder
|
10 |
+
from scepter.modules.model.registry import EMBEDDERS
|
11 |
+
from scepter.modules.model.tokenizer.tokenizer_component import (
|
12 |
+
basic_clean, canonicalize, heavy_clean, whitespace_clean)
|
13 |
+
from scepter.modules.utils.config import dict_to_yaml
|
14 |
+
from scepter.modules.utils.distribute import we
|
15 |
+
from scepter.modules.utils.file_system import FS
|
16 |
+
|
17 |
+
try:
|
18 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
19 |
+
except Exception as e:
|
20 |
+
warnings.warn(
|
21 |
+
f'Import transformers error, please deal with this problem: {e}')
|
22 |
+
|
23 |
+
|
24 |
+
@EMBEDDERS.register_class()
|
25 |
+
class ACETextEmbedder(BaseEmbedder):
|
26 |
+
"""
|
27 |
+
Uses the OpenCLIP transformer encoder for text
|
28 |
+
"""
|
29 |
+
"""
|
30 |
+
Uses the OpenCLIP transformer encoder for text
|
31 |
+
"""
|
32 |
+
para_dict = {
|
33 |
+
'PRETRAINED_MODEL': {
|
34 |
+
'value':
|
35 |
+
'google/umt5-small',
|
36 |
+
'description':
|
37 |
+
'Pretrained Model for umt5, modelcard path or local path.'
|
38 |
+
},
|
39 |
+
'TOKENIZER_PATH': {
|
40 |
+
'value': 'google/umt5-small',
|
41 |
+
'description':
|
42 |
+
'Tokenizer Path for umt5, modelcard path or local path.'
|
43 |
+
},
|
44 |
+
'FREEZE': {
|
45 |
+
'value': True,
|
46 |
+
'description': ''
|
47 |
+
},
|
48 |
+
'USE_GRAD': {
|
49 |
+
'value': False,
|
50 |
+
'description': 'Compute grad or not.'
|
51 |
+
},
|
52 |
+
'CLEAN': {
|
53 |
+
'value':
|
54 |
+
'whitespace',
|
55 |
+
'description':
|
56 |
+
'Set the clean strtegy for tokenizer, used when TOKENIZER_PATH is not None.'
|
57 |
+
},
|
58 |
+
'LAYER': {
|
59 |
+
'value': 'last',
|
60 |
+
'description': ''
|
61 |
+
},
|
62 |
+
'LEGACY': {
|
63 |
+
'value':
|
64 |
+
True,
|
65 |
+
'description':
|
66 |
+
'Whether use legacy returnd feature or not ,default True.'
|
67 |
+
}
|
68 |
+
}
|
69 |
+
|
70 |
+
def __init__(self, cfg, logger=None):
|
71 |
+
super().__init__(cfg, logger=logger)
|
72 |
+
pretrained_path = cfg.get('PRETRAINED_MODEL', None)
|
73 |
+
self.t5_dtype = cfg.get('T5_DTYPE', 'float32')
|
74 |
+
assert pretrained_path
|
75 |
+
with FS.get_dir_to_local_dir(pretrained_path,
|
76 |
+
wait_finish=True) as local_path:
|
77 |
+
self.model = T5EncoderModel.from_pretrained(
|
78 |
+
local_path,
|
79 |
+
torch_dtype=getattr(
|
80 |
+
torch,
|
81 |
+
'float' if self.t5_dtype == 'float32' else self.t5_dtype))
|
82 |
+
tokenizer_path = cfg.get('TOKENIZER_PATH', None)
|
83 |
+
self.length = cfg.get('LENGTH', 77)
|
84 |
+
|
85 |
+
self.use_grad = cfg.get('USE_GRAD', False)
|
86 |
+
self.clean = cfg.get('CLEAN', 'whitespace')
|
87 |
+
self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
|
88 |
+
if tokenizer_path:
|
89 |
+
self.tokenize_kargs = {'return_tensors': 'pt'}
|
90 |
+
with FS.get_dir_to_local_dir(tokenizer_path,
|
91 |
+
wait_finish=True) as local_path:
|
92 |
+
if self.added_identifier is not None and isinstance(
|
93 |
+
self.added_identifier, list):
|
94 |
+
self.tokenizer = AutoTokenizer.from_pretrained(local_path)
|
95 |
+
else:
|
96 |
+
self.tokenizer = AutoTokenizer.from_pretrained(local_path)
|
97 |
+
if self.length is not None:
|
98 |
+
self.tokenize_kargs.update({
|
99 |
+
'padding': 'max_length',
|
100 |
+
'truncation': True,
|
101 |
+
'max_length': self.length
|
102 |
+
})
|
103 |
+
self.eos_token = self.tokenizer(
|
104 |
+
self.tokenizer.eos_token)['input_ids'][0]
|
105 |
+
else:
|
106 |
+
self.tokenizer = None
|
107 |
+
self.tokenize_kargs = {}
|
108 |
+
|
109 |
+
self.use_grad = cfg.get('USE_GRAD', False)
|
110 |
+
self.clean = cfg.get('CLEAN', 'whitespace')
|
111 |
+
|
112 |
+
def freeze(self):
|
113 |
+
self.model = self.model.eval()
|
114 |
+
for param in self.parameters():
|
115 |
+
param.requires_grad = False
|
116 |
+
|
117 |
+
# encode && encode_text
|
118 |
+
def forward(self, tokens, return_mask=False, use_mask=True):
|
119 |
+
# tokenization
|
120 |
+
embedding_context = nullcontext if self.use_grad else torch.no_grad
|
121 |
+
with embedding_context():
|
122 |
+
if use_mask:
|
123 |
+
x = self.model(tokens.input_ids.to(we.device_id),
|
124 |
+
tokens.attention_mask.to(we.device_id))
|
125 |
+
else:
|
126 |
+
x = self.model(tokens.input_ids.to(we.device_id))
|
127 |
+
x = x.last_hidden_state
|
128 |
+
|
129 |
+
if return_mask:
|
130 |
+
return x.detach() + 0.0, tokens.attention_mask.to(we.device_id)
|
131 |
+
else:
|
132 |
+
return x.detach() + 0.0, None
|
133 |
+
|
134 |
+
def _clean(self, text):
|
135 |
+
if self.clean == 'whitespace':
|
136 |
+
text = whitespace_clean(basic_clean(text))
|
137 |
+
elif self.clean == 'lower':
|
138 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
139 |
+
elif self.clean == 'canonicalize':
|
140 |
+
text = canonicalize(basic_clean(text))
|
141 |
+
elif self.clean == 'heavy':
|
142 |
+
text = heavy_clean(basic_clean(text))
|
143 |
+
return text
|
144 |
+
|
145 |
+
def encode(self, text, return_mask=False, use_mask=True):
|
146 |
+
if isinstance(text, str):
|
147 |
+
text = [text]
|
148 |
+
if self.clean:
|
149 |
+
text = [self._clean(u) for u in text]
|
150 |
+
assert self.tokenizer is not None
|
151 |
+
cont, mask = [], []
|
152 |
+
with torch.autocast(device_type='cuda',
|
153 |
+
enabled=self.t5_dtype in ('float16', 'bfloat16'),
|
154 |
+
dtype=getattr(torch, self.t5_dtype)):
|
155 |
+
for tt in text:
|
156 |
+
tokens = self.tokenizer([tt], **self.tokenize_kargs)
|
157 |
+
one_cont, one_mask = self(tokens,
|
158 |
+
return_mask=return_mask,
|
159 |
+
use_mask=use_mask)
|
160 |
+
cont.append(one_cont)
|
161 |
+
mask.append(one_mask)
|
162 |
+
if return_mask:
|
163 |
+
return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
|
164 |
+
else:
|
165 |
+
return torch.cat(cont, dim=0)
|
166 |
+
|
167 |
+
def encode_list(self, text_list, return_mask=True):
|
168 |
+
cont_list = []
|
169 |
+
mask_list = []
|
170 |
+
for pp in text_list:
|
171 |
+
cont, cont_mask = self.encode(pp, return_mask=return_mask)
|
172 |
+
cont_list.append(cont)
|
173 |
+
mask_list.append(cont_mask)
|
174 |
+
if return_mask:
|
175 |
+
return cont_list, mask_list
|
176 |
+
else:
|
177 |
+
return cont_list
|
178 |
+
|
179 |
+
@staticmethod
|
180 |
+
def get_config_template():
|
181 |
+
return dict_to_yaml('MODELS',
|
182 |
+
__class__.__name__,
|
183 |
+
ACETextEmbedder.para_dict,
|
184 |
+
set_name=True)
|
modules/model/network/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .ldm_ace import LdmACE
|
modules/model/network/ldm_ace.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import copy
|
4 |
+
import random
|
5 |
+
from contextlib import nullcontext
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from scepter.modules.model.network.ldm import LatentDiffusion
|
12 |
+
from scepter.modules.model.registry import MODELS
|
13 |
+
from scepter.modules.utils.config import dict_to_yaml
|
14 |
+
from scepter.modules.utils.distribute import we
|
15 |
+
|
16 |
+
from ..utils.basic_utils import (
|
17 |
+
check_list_of_list,
|
18 |
+
pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
|
19 |
+
to_device,
|
20 |
+
unpack_tensor_into_imagelist
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
class TextEmbedding(nn.Module):
|
25 |
+
def __init__(self, embedding_shape):
|
26 |
+
super().__init__()
|
27 |
+
self.pos = nn.Parameter(data=torch.zeros(embedding_shape))
|
28 |
+
|
29 |
+
|
30 |
+
@MODELS.register_class()
|
31 |
+
class LdmACE(LatentDiffusion):
|
32 |
+
para_dict = LatentDiffusion.para_dict
|
33 |
+
para_dict['DECODER_BIAS'] = {'value': 0, 'description': ''}
|
34 |
+
|
35 |
+
def __init__(self, cfg, logger=None):
|
36 |
+
super().__init__(cfg, logger=logger)
|
37 |
+
self.interpolate_func = lambda x: (F.interpolate(
|
38 |
+
x.unsqueeze(0),
|
39 |
+
scale_factor=1 / self.size_factor,
|
40 |
+
mode='nearest-exact') if x is not None else None)
|
41 |
+
|
42 |
+
self.text_indentifers = cfg.get('TEXT_IDENTIFIER', [])
|
43 |
+
self.use_text_pos_embeddings = cfg.get('USE_TEXT_POS_EMBEDDINGS',
|
44 |
+
False)
|
45 |
+
if self.use_text_pos_embeddings:
|
46 |
+
self.text_position_embeddings = TextEmbedding(
|
47 |
+
(10, 4096)).eval().requires_grad_(False)
|
48 |
+
else:
|
49 |
+
self.text_position_embeddings = None
|
50 |
+
|
51 |
+
self.logger.info(self.model)
|
52 |
+
|
53 |
+
@torch.no_grad()
|
54 |
+
def encode_first_stage(self, x, **kwargs):
|
55 |
+
return [
|
56 |
+
self.scale_factor *
|
57 |
+
self.first_stage_model._encode(i.unsqueeze(0).to(torch.float16))
|
58 |
+
for i in x
|
59 |
+
]
|
60 |
+
|
61 |
+
@torch.no_grad()
|
62 |
+
def decode_first_stage(self, z):
|
63 |
+
return [
|
64 |
+
self.first_stage_model._decode(1. / self.scale_factor *
|
65 |
+
i.to(torch.float16)) for i in z
|
66 |
+
]
|
67 |
+
|
68 |
+
def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
|
69 |
+
if self.use_text_pos_embeddings and not torch.sum(
|
70 |
+
self.text_position_embeddings.pos) > 0:
|
71 |
+
identifier_cont, identifier_cont_mask = getattr(
|
72 |
+
self.cond_stage_model, 'encode')(self.text_indentifers,
|
73 |
+
return_mask=True)
|
74 |
+
self.text_position_embeddings.load_state_dict(
|
75 |
+
{'pos': identifier_cont[:, 0, :]})
|
76 |
+
cont_, cont_mask_ = [], []
|
77 |
+
for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
|
78 |
+
if isinstance(pp, list):
|
79 |
+
cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
|
80 |
+
cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
|
81 |
+
else:
|
82 |
+
raise NotImplementedError
|
83 |
+
|
84 |
+
return cont_, cont_mask_
|
85 |
+
|
86 |
+
def limit_batch_data(self, batch_data_list, log_num):
|
87 |
+
if log_num and log_num > 0:
|
88 |
+
batch_data_list_limited = []
|
89 |
+
for sub_data in batch_data_list:
|
90 |
+
if sub_data is not None:
|
91 |
+
sub_data = sub_data[:log_num]
|
92 |
+
batch_data_list_limited.append(sub_data)
|
93 |
+
return batch_data_list_limited
|
94 |
+
else:
|
95 |
+
return batch_data_list
|
96 |
+
|
97 |
+
def forward_train(self,
|
98 |
+
edit_image=[],
|
99 |
+
edit_image_mask=[],
|
100 |
+
image=None,
|
101 |
+
image_mask=None,
|
102 |
+
noise=None,
|
103 |
+
prompt=[],
|
104 |
+
**kwargs):
|
105 |
+
'''
|
106 |
+
Args:
|
107 |
+
edit_image: list of list of edit_image
|
108 |
+
edit_image_mask: list of list of edit_image_mask
|
109 |
+
image: target image
|
110 |
+
image_mask: target image mask
|
111 |
+
noise: default is None, generate automaticly
|
112 |
+
prompt: list of list of text
|
113 |
+
**kwargs:
|
114 |
+
Returns:
|
115 |
+
'''
|
116 |
+
assert check_list_of_list(prompt) and check_list_of_list(
|
117 |
+
edit_image) and check_list_of_list(edit_image_mask)
|
118 |
+
assert len(edit_image) == len(edit_image_mask) == len(prompt)
|
119 |
+
assert self.cond_stage_model is not None
|
120 |
+
gc_seg = kwargs.pop('gc_seg', [])
|
121 |
+
gc_seg = int(gc_seg[0]) if len(gc_seg) > 0 else 0
|
122 |
+
context = {}
|
123 |
+
|
124 |
+
# process image
|
125 |
+
image = to_device(image)
|
126 |
+
x_start = self.encode_first_stage(image, **kwargs)
|
127 |
+
x_start, x_shapes = pack_imagelist_into_tensor(x_start) # B, C, L
|
128 |
+
n, _, _ = x_start.shape
|
129 |
+
t = torch.randint(0, self.num_timesteps, (n, ),
|
130 |
+
device=x_start.device).long()
|
131 |
+
context['x_shapes'] = x_shapes
|
132 |
+
|
133 |
+
# process image mask
|
134 |
+
image_mask = to_device(image_mask, strict=False)
|
135 |
+
context['x_mask'] = [self.interpolate_func(i) for i in image_mask
|
136 |
+
] if image_mask is not None else [None] * n
|
137 |
+
|
138 |
+
# process text
|
139 |
+
# with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
|
140 |
+
prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
|
141 |
+
try:
|
142 |
+
cont, cont_mask = getattr(self.cond_stage_model,
|
143 |
+
'encode_list')(prompt_, return_mask=True)
|
144 |
+
except Exception as e:
|
145 |
+
print(e, prompt_)
|
146 |
+
cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
|
147 |
+
cont_mask)
|
148 |
+
context['crossattn'] = cont
|
149 |
+
|
150 |
+
# process edit image & edit image mask
|
151 |
+
edit_image = [to_device(i, strict=False) for i in edit_image]
|
152 |
+
edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
|
153 |
+
e_img, e_mask = [], []
|
154 |
+
for u, m in zip(edit_image, edit_image_mask):
|
155 |
+
if m is None:
|
156 |
+
m = [None] * len(u) if u is not None else [None]
|
157 |
+
e_img.append(
|
158 |
+
self.encode_first_stage(u, **kwargs) if u is not None else u)
|
159 |
+
e_mask.append([
|
160 |
+
self.interpolate_func(i) if i is not None else None for i in m
|
161 |
+
])
|
162 |
+
context['edit'], context['edit_mask'] = e_img, e_mask
|
163 |
+
|
164 |
+
# process loss
|
165 |
+
loss = self.diffusion.loss(
|
166 |
+
x_0=x_start,
|
167 |
+
t=t,
|
168 |
+
noise=noise,
|
169 |
+
model=self.model,
|
170 |
+
model_kwargs={
|
171 |
+
'cond':
|
172 |
+
context,
|
173 |
+
'mask':
|
174 |
+
cont_mask,
|
175 |
+
'gc_seg':
|
176 |
+
gc_seg,
|
177 |
+
'text_position_embeddings':
|
178 |
+
self.text_position_embeddings.pos if hasattr(
|
179 |
+
self.text_position_embeddings, 'pos') else None
|
180 |
+
},
|
181 |
+
**kwargs)
|
182 |
+
loss = loss.mean()
|
183 |
+
ret = {'loss': loss, 'probe_data': {'prompt': prompt}}
|
184 |
+
return ret
|
185 |
+
|
186 |
+
@torch.no_grad()
|
187 |
+
def forward_test(self,
|
188 |
+
edit_image=[],
|
189 |
+
edit_image_mask=[],
|
190 |
+
image=None,
|
191 |
+
image_mask=None,
|
192 |
+
prompt=[],
|
193 |
+
n_prompt=[],
|
194 |
+
sampler='ddim',
|
195 |
+
sample_steps=20,
|
196 |
+
guide_scale=4.5,
|
197 |
+
guide_rescale=0.5,
|
198 |
+
log_num=-1,
|
199 |
+
seed=2024,
|
200 |
+
**kwargs):
|
201 |
+
|
202 |
+
assert check_list_of_list(prompt) and check_list_of_list(
|
203 |
+
edit_image) and check_list_of_list(edit_image_mask)
|
204 |
+
assert len(edit_image) == len(edit_image_mask) == len(prompt)
|
205 |
+
assert self.cond_stage_model is not None
|
206 |
+
# gc_seg is unused
|
207 |
+
kwargs.pop('gc_seg', -1)
|
208 |
+
# prepare data
|
209 |
+
context, null_context = {}, {}
|
210 |
+
|
211 |
+
prompt, n_prompt, image, image_mask, edit_image, edit_image_mask = self.limit_batch_data(
|
212 |
+
[prompt, n_prompt, image, image_mask, edit_image, edit_image_mask],
|
213 |
+
log_num)
|
214 |
+
g = torch.Generator(device=we.device_id)
|
215 |
+
seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
|
216 |
+
g.manual_seed(seed)
|
217 |
+
n_prompt = copy.deepcopy(prompt)
|
218 |
+
# only modify the last prompt to be zero
|
219 |
+
for nn_p_id, nn_p in enumerate(n_prompt):
|
220 |
+
if isinstance(nn_p, str):
|
221 |
+
n_prompt[nn_p_id] = ['']
|
222 |
+
elif isinstance(nn_p, list):
|
223 |
+
n_prompt[nn_p_id][-1] = ''
|
224 |
+
else:
|
225 |
+
raise NotImplementedError
|
226 |
+
# process image
|
227 |
+
image = to_device(image)
|
228 |
+
x = self.encode_first_stage(image, **kwargs)
|
229 |
+
noise = [
|
230 |
+
torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
|
231 |
+
for i in x
|
232 |
+
]
|
233 |
+
noise, x_shapes = pack_imagelist_into_tensor(noise)
|
234 |
+
context['x_shapes'] = null_context['x_shapes'] = x_shapes
|
235 |
+
|
236 |
+
# process image mask
|
237 |
+
image_mask = to_device(image_mask, strict=False)
|
238 |
+
cond_mask = [self.interpolate_func(i) for i in image_mask
|
239 |
+
] if image_mask is not None else [None] * len(image)
|
240 |
+
context['x_mask'] = null_context['x_mask'] = cond_mask
|
241 |
+
# process text
|
242 |
+
# with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
|
243 |
+
prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
|
244 |
+
cont, cont_mask = getattr(self.cond_stage_model,
|
245 |
+
'encode_list')(prompt_, return_mask=True)
|
246 |
+
cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
|
247 |
+
cont_mask)
|
248 |
+
null_cont, null_cont_mask = getattr(self.cond_stage_model,
|
249 |
+
'encode_list')(n_prompt,
|
250 |
+
return_mask=True)
|
251 |
+
null_cont, null_cont_mask = self.cond_stage_embeddings(
|
252 |
+
prompt, edit_image, null_cont, null_cont_mask)
|
253 |
+
context['crossattn'] = cont
|
254 |
+
null_context['crossattn'] = null_cont
|
255 |
+
|
256 |
+
# processe edit image & edit image mask
|
257 |
+
edit_image = [to_device(i, strict=False) for i in edit_image]
|
258 |
+
edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
|
259 |
+
e_img, e_mask = [], []
|
260 |
+
for u, m in zip(edit_image, edit_image_mask):
|
261 |
+
if u is None:
|
262 |
+
continue
|
263 |
+
if m is None:
|
264 |
+
m = [None] * len(u)
|
265 |
+
e_img.append(self.encode_first_stage(u, **kwargs))
|
266 |
+
e_mask.append([self.interpolate_func(i) for i in m])
|
267 |
+
null_context['edit'] = context['edit'] = e_img
|
268 |
+
null_context['edit_mask'] = context['edit_mask'] = e_mask
|
269 |
+
|
270 |
+
# process sample
|
271 |
+
model = self.model_ema if self.use_ema and self.eval_ema else self.model
|
272 |
+
embedding_context = model.no_sync if isinstance(model, torch.distributed.fsdp.FullyShardedDataParallel) \
|
273 |
+
else nullcontext
|
274 |
+
with embedding_context():
|
275 |
+
samples = self.diffusion.sample(
|
276 |
+
sampler=sampler,
|
277 |
+
noise=noise,
|
278 |
+
model=model,
|
279 |
+
model_kwargs=[{
|
280 |
+
'cond':
|
281 |
+
context,
|
282 |
+
'mask':
|
283 |
+
cont_mask,
|
284 |
+
'text_position_embeddings':
|
285 |
+
self.text_position_embeddings.pos if hasattr(
|
286 |
+
self.text_position_embeddings, 'pos') else None
|
287 |
+
}, {
|
288 |
+
'cond':
|
289 |
+
null_context,
|
290 |
+
'mask':
|
291 |
+
null_cont_mask,
|
292 |
+
'text_position_embeddings':
|
293 |
+
self.text_position_embeddings.pos if hasattr(
|
294 |
+
self.text_position_embeddings, 'pos') else None
|
295 |
+
}] if guide_scale is not None and guide_scale > 1 else {
|
296 |
+
'cond':
|
297 |
+
context,
|
298 |
+
'mask':
|
299 |
+
cont_mask,
|
300 |
+
'text_position_embeddings':
|
301 |
+
self.text_position_embeddings.pos if hasattr(
|
302 |
+
self.text_position_embeddings, 'pos') else None
|
303 |
+
},
|
304 |
+
steps=sample_steps,
|
305 |
+
guide_scale=guide_scale,
|
306 |
+
guide_rescale=guide_rescale,
|
307 |
+
show_progress=True,
|
308 |
+
**kwargs)
|
309 |
+
|
310 |
+
samples = unpack_tensor_into_imagelist(samples, x_shapes)
|
311 |
+
x_samples = self.decode_first_stage(samples)
|
312 |
+
outputs = list()
|
313 |
+
for i in range(len(prompt)):
|
314 |
+
rec_img = torch.clamp(
|
315 |
+
(x_samples[i] + 1.0) / 2.0 + self.decoder_bias / 255,
|
316 |
+
min=0.0,
|
317 |
+
max=1.0)
|
318 |
+
rec_img = rec_img.squeeze(0)
|
319 |
+
edit_imgs, edit_img_masks = [], []
|
320 |
+
if edit_image is not None and edit_image[i] is not None:
|
321 |
+
if edit_image_mask[i] is None:
|
322 |
+
edit_image_mask[i] = [None] * len(edit_image[i])
|
323 |
+
for edit_img, edit_mask in zip(edit_image[i],
|
324 |
+
edit_image_mask[i]):
|
325 |
+
edit_img = torch.clamp((edit_img + 1.0) / 2.0,
|
326 |
+
min=0.0,
|
327 |
+
max=1.0)
|
328 |
+
edit_imgs.append(edit_img.squeeze(0))
|
329 |
+
if edit_mask is None:
|
330 |
+
edit_mask = torch.ones_like(edit_img[[0], :, :])
|
331 |
+
edit_img_masks.append(edit_mask)
|
332 |
+
one_tup = {
|
333 |
+
'reconstruct_image': rec_img,
|
334 |
+
'instruction': prompt[i],
|
335 |
+
'edit_image': edit_imgs if len(edit_imgs) > 0 else None,
|
336 |
+
'edit_mask': edit_img_masks if len(edit_imgs) > 0 else None
|
337 |
+
}
|
338 |
+
if image is not None:
|
339 |
+
if image_mask is None:
|
340 |
+
image_mask = [None] * len(image)
|
341 |
+
ori_img = torch.clamp((image[i] + 1.0) / 2.0, min=0.0, max=1.0)
|
342 |
+
one_tup['target_image'] = ori_img.squeeze(0)
|
343 |
+
one_tup['target_mask'] = image_mask[i] if image_mask[
|
344 |
+
i] is not None else torch.ones_like(ori_img[[0], :, :])
|
345 |
+
outputs.append(one_tup)
|
346 |
+
return outputs
|
347 |
+
|
348 |
+
@staticmethod
|
349 |
+
def get_config_template():
|
350 |
+
return dict_to_yaml('MODEL',
|
351 |
+
__class__.__name__,
|
352 |
+
LdmACE.para_dict,
|
353 |
+
set_name=True)
|
modules/model/utils/basic_utils.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
from inspect import isfunction
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch.nn.utils.rnn import pad_sequence
|
7 |
+
|
8 |
+
from scepter.modules.utils.distribute import we
|
9 |
+
|
10 |
+
|
11 |
+
def exists(x):
|
12 |
+
return x is not None
|
13 |
+
|
14 |
+
|
15 |
+
def default(val, d):
|
16 |
+
if exists(val):
|
17 |
+
return val
|
18 |
+
return d() if isfunction(d) else d
|
19 |
+
|
20 |
+
|
21 |
+
def disabled_train(self, mode=True):
|
22 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
23 |
+
does not change anymore."""
|
24 |
+
return self
|
25 |
+
|
26 |
+
|
27 |
+
def transfer_size(para_num):
|
28 |
+
if para_num > 1000 * 1000 * 1000 * 1000:
|
29 |
+
bill = para_num / (1000 * 1000 * 1000 * 1000)
|
30 |
+
return '{:.2f}T'.format(bill)
|
31 |
+
elif para_num > 1000 * 1000 * 1000:
|
32 |
+
gyte = para_num / (1000 * 1000 * 1000)
|
33 |
+
return '{:.2f}B'.format(gyte)
|
34 |
+
elif para_num > (1000 * 1000):
|
35 |
+
meta = para_num / (1000 * 1000)
|
36 |
+
return '{:.2f}M'.format(meta)
|
37 |
+
elif para_num > 1000:
|
38 |
+
kelo = para_num / 1000
|
39 |
+
return '{:.2f}K'.format(kelo)
|
40 |
+
else:
|
41 |
+
return para_num
|
42 |
+
|
43 |
+
|
44 |
+
def count_params(model):
|
45 |
+
total_params = sum(p.numel() for p in model.parameters())
|
46 |
+
return transfer_size(total_params)
|
47 |
+
|
48 |
+
|
49 |
+
def expand_dims_like(x, y):
|
50 |
+
while x.dim() != y.dim():
|
51 |
+
x = x.unsqueeze(-1)
|
52 |
+
return x
|
53 |
+
|
54 |
+
|
55 |
+
def unpack_tensor_into_imagelist(image_tensor, shapes):
|
56 |
+
image_list = []
|
57 |
+
for img, shape in zip(image_tensor, shapes):
|
58 |
+
h, w = shape[0], shape[1]
|
59 |
+
image_list.append(img[:, :h * w].view(1, -1, h, w))
|
60 |
+
|
61 |
+
return image_list
|
62 |
+
|
63 |
+
|
64 |
+
def find_example(tensor_list, image_list):
|
65 |
+
for i in tensor_list:
|
66 |
+
if isinstance(i, torch.Tensor):
|
67 |
+
return torch.zeros_like(i)
|
68 |
+
for i in image_list:
|
69 |
+
if isinstance(i, torch.Tensor):
|
70 |
+
_, c, h, w = i.size()
|
71 |
+
return torch.zeros_like(i.view(c, h * w).transpose(1, 0))
|
72 |
+
return None
|
73 |
+
|
74 |
+
|
75 |
+
def pack_imagelist_into_tensor_v2(image_list):
|
76 |
+
# allow None
|
77 |
+
example = None
|
78 |
+
image_tensor, shapes = [], []
|
79 |
+
for img in image_list:
|
80 |
+
if img is None:
|
81 |
+
example = find_example(image_tensor,
|
82 |
+
image_list) if example is None else example
|
83 |
+
image_tensor.append(example)
|
84 |
+
shapes.append(None)
|
85 |
+
continue
|
86 |
+
_, c, h, w = img.size()
|
87 |
+
image_tensor.append(img.view(c, h * w).transpose(1, 0)) # h*w, c
|
88 |
+
shapes.append((h, w))
|
89 |
+
|
90 |
+
image_tensor = pad_sequence(image_tensor,
|
91 |
+
batch_first=True).permute(0, 2, 1) # b, c, l
|
92 |
+
return image_tensor, shapes
|
93 |
+
|
94 |
+
|
95 |
+
def to_device(inputs, strict=True):
|
96 |
+
if inputs is None:
|
97 |
+
return None
|
98 |
+
if strict:
|
99 |
+
assert all(isinstance(i, torch.Tensor) for i in inputs)
|
100 |
+
return [i.to(we.device_id) if i is not None else None for i in inputs]
|
101 |
+
|
102 |
+
|
103 |
+
def check_list_of_list(ll):
|
104 |
+
return isinstance(ll, list) and all(isinstance(i, list) for i in ll)
|
modules/solver/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .ace_solver import ACESolverV1
|
modules/solver/ace_solver.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
from scepter.modules.utils.data import transfer_data_to_cuda
|
8 |
+
from scepter.modules.utils.distribute import we
|
9 |
+
from scepter.modules.utils.probe import ProbeData
|
10 |
+
from scepter.modules.solver.registry import SOLVERS
|
11 |
+
from scepter.modules.solver.diffusion_solver import LatentDiffusionSolver
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
@SOLVERS.register_class()
|
16 |
+
class ACESolverV1(LatentDiffusionSolver):
|
17 |
+
def __init__(self, cfg, logger=None):
|
18 |
+
super().__init__(cfg, logger=logger)
|
19 |
+
self.log_train_num = cfg.get('LOG_TRAIN_NUM', -1)
|
20 |
+
|
21 |
+
def save_results(self, results):
|
22 |
+
log_data, log_label = [], []
|
23 |
+
for result in results:
|
24 |
+
ret_images, ret_labels = [], []
|
25 |
+
edit_image = result.get('edit_image', None)
|
26 |
+
edit_mask = result.get('edit_mask', None)
|
27 |
+
if edit_image is not None:
|
28 |
+
for i, edit_img in enumerate(result['edit_image']):
|
29 |
+
if edit_img is None:
|
30 |
+
continue
|
31 |
+
ret_images.append(
|
32 |
+
(edit_img.permute(1, 2, 0).cpu().numpy() * 255).astype(
|
33 |
+
np.uint8))
|
34 |
+
ret_labels.append(f'edit_image{i}; ')
|
35 |
+
if edit_mask is not None:
|
36 |
+
ret_images.append(
|
37 |
+
(edit_mask[i].permute(1, 2, 0).cpu().numpy() *
|
38 |
+
255).astype(np.uint8))
|
39 |
+
ret_labels.append(f'edit_mask{i}; ')
|
40 |
+
|
41 |
+
target_image = result.get('target_image', None)
|
42 |
+
target_mask = result.get('target_mask', None)
|
43 |
+
if target_image is not None:
|
44 |
+
ret_images.append(
|
45 |
+
(target_image.permute(1, 2, 0).cpu().numpy() * 255).astype(
|
46 |
+
np.uint8))
|
47 |
+
ret_labels.append('target_image; ')
|
48 |
+
if target_mask is not None:
|
49 |
+
ret_images.append(
|
50 |
+
(target_mask.permute(1, 2, 0).cpu().numpy() *
|
51 |
+
255).astype(np.uint8))
|
52 |
+
ret_labels.append('target_mask; ')
|
53 |
+
|
54 |
+
reconstruct_image = result.get('reconstruct_image', None)
|
55 |
+
if reconstruct_image is not None:
|
56 |
+
ret_images.append(
|
57 |
+
(reconstruct_image.permute(1, 2, 0).cpu().numpy() *
|
58 |
+
255).astype(np.uint8))
|
59 |
+
ret_labels.append(f"{result['instruction']}")
|
60 |
+
log_data.append(ret_images)
|
61 |
+
log_label.append(ret_labels)
|
62 |
+
return log_data, log_label
|
63 |
+
|
64 |
+
@torch.no_grad()
|
65 |
+
def run_eval(self):
|
66 |
+
self.eval_mode()
|
67 |
+
self.before_all_iter(self.hooks_dict[self._mode])
|
68 |
+
all_results = []
|
69 |
+
for batch_idx, batch_data in tqdm(
|
70 |
+
enumerate(self.datas[self._mode].dataloader)):
|
71 |
+
self.before_iter(self.hooks_dict[self._mode])
|
72 |
+
if self.sample_args:
|
73 |
+
batch_data.update(self.sample_args.get_lowercase_dict())
|
74 |
+
with torch.autocast(device_type='cuda',
|
75 |
+
enabled=self.use_amp,
|
76 |
+
dtype=self.dtype):
|
77 |
+
results = self.run_step_eval(transfer_data_to_cuda(batch_data),
|
78 |
+
batch_idx,
|
79 |
+
step=self.total_iter,
|
80 |
+
rank=we.rank)
|
81 |
+
all_results.extend(results)
|
82 |
+
self.after_iter(self.hooks_dict[self._mode])
|
83 |
+
log_data, log_label = self.save_results(all_results)
|
84 |
+
self.register_probe({'eval_label': log_label})
|
85 |
+
self.register_probe({
|
86 |
+
'eval_image':
|
87 |
+
ProbeData(log_data,
|
88 |
+
is_image=True,
|
89 |
+
build_html=True,
|
90 |
+
build_label=log_label)
|
91 |
+
})
|
92 |
+
self.after_all_iter(self.hooks_dict[self._mode])
|
93 |
+
|
94 |
+
@torch.no_grad()
|
95 |
+
def run_test(self):
|
96 |
+
self.test_mode()
|
97 |
+
self.before_all_iter(self.hooks_dict[self._mode])
|
98 |
+
all_results = []
|
99 |
+
for batch_idx, batch_data in tqdm(
|
100 |
+
enumerate(self.datas[self._mode].dataloader)):
|
101 |
+
self.before_iter(self.hooks_dict[self._mode])
|
102 |
+
if self.sample_args:
|
103 |
+
batch_data.update(self.sample_args.get_lowercase_dict())
|
104 |
+
with torch.autocast(device_type='cuda',
|
105 |
+
enabled=self.use_amp,
|
106 |
+
dtype=self.dtype):
|
107 |
+
results = self.run_step_eval(transfer_data_to_cuda(batch_data),
|
108 |
+
batch_idx,
|
109 |
+
step=self.total_iter,
|
110 |
+
rank=we.rank)
|
111 |
+
all_results.extend(results)
|
112 |
+
self.after_iter(self.hooks_dict[self._mode])
|
113 |
+
log_data, log_label = self.save_results(all_results)
|
114 |
+
self.register_probe({'test_label': log_label})
|
115 |
+
self.register_probe({
|
116 |
+
'test_image':
|
117 |
+
ProbeData(log_data,
|
118 |
+
is_image=True,
|
119 |
+
build_html=True,
|
120 |
+
build_label=log_label)
|
121 |
+
})
|
122 |
+
|
123 |
+
self.after_all_iter(self.hooks_dict[self._mode])
|
124 |
+
|
125 |
+
@property
|
126 |
+
def probe_data(self):
|
127 |
+
if not we.debug and self.mode == 'train':
|
128 |
+
batch_data = transfer_data_to_cuda(
|
129 |
+
self.current_batch_data[self.mode])
|
130 |
+
self.eval_mode()
|
131 |
+
with torch.autocast(device_type='cuda',
|
132 |
+
enabled=self.use_amp,
|
133 |
+
dtype=self.dtype):
|
134 |
+
batch_data['log_num'] = self.log_train_num
|
135 |
+
results = self.run_step_eval(batch_data)
|
136 |
+
self.train_mode()
|
137 |
+
log_data, log_label = self.save_results(results)
|
138 |
+
self.register_probe({
|
139 |
+
'train_image':
|
140 |
+
ProbeData(log_data,
|
141 |
+
is_image=True,
|
142 |
+
build_html=True,
|
143 |
+
build_label=log_label)
|
144 |
+
})
|
145 |
+
self.register_probe({'train_label': log_label})
|
146 |
+
return super(LatentDiffusionSolver, self).probe_data
|
requirements.txt
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
huggingface_hub==0.25.2
|
|
|
|
1 |
+
huggingface_hub==0.25.2
|
2 |
+
scepter>=1.2.0
|
utils.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import torch
|
4 |
+
import torchvision.transforms as T
|
5 |
+
from PIL import Image
|
6 |
+
from torchvision.transforms.functional import InterpolationMode
|
7 |
+
|
8 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
9 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
10 |
+
|
11 |
+
|
12 |
+
def build_transform(input_size):
|
13 |
+
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
14 |
+
transform = T.Compose([
|
15 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
16 |
+
T.Resize((input_size, input_size),
|
17 |
+
interpolation=InterpolationMode.BICUBIC),
|
18 |
+
T.ToTensor(),
|
19 |
+
T.Normalize(mean=MEAN, std=STD)
|
20 |
+
])
|
21 |
+
return transform
|
22 |
+
|
23 |
+
|
24 |
+
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
|
25 |
+
image_size):
|
26 |
+
best_ratio_diff = float('inf')
|
27 |
+
best_ratio = (1, 1)
|
28 |
+
area = width * height
|
29 |
+
for ratio in target_ratios:
|
30 |
+
target_aspect_ratio = ratio[0] / ratio[1]
|
31 |
+
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
32 |
+
if ratio_diff < best_ratio_diff:
|
33 |
+
best_ratio_diff = ratio_diff
|
34 |
+
best_ratio = ratio
|
35 |
+
elif ratio_diff == best_ratio_diff:
|
36 |
+
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
37 |
+
best_ratio = ratio
|
38 |
+
return best_ratio
|
39 |
+
|
40 |
+
|
41 |
+
def dynamic_preprocess(image,
|
42 |
+
min_num=1,
|
43 |
+
max_num=12,
|
44 |
+
image_size=448,
|
45 |
+
use_thumbnail=False):
|
46 |
+
orig_width, orig_height = image.size
|
47 |
+
aspect_ratio = orig_width / orig_height
|
48 |
+
|
49 |
+
# calculate the existing image aspect ratio
|
50 |
+
target_ratios = set((i, j) for n in range(min_num, max_num + 1)
|
51 |
+
for i in range(1, n + 1) for j in range(1, n + 1)
|
52 |
+
if i * j <= max_num and i * j >= min_num)
|
53 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
54 |
+
|
55 |
+
# find the closest aspect ratio to the target
|
56 |
+
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
|
57 |
+
target_ratios, orig_width,
|
58 |
+
orig_height, image_size)
|
59 |
+
|
60 |
+
# calculate the target width and height
|
61 |
+
target_width = image_size * target_aspect_ratio[0]
|
62 |
+
target_height = image_size * target_aspect_ratio[1]
|
63 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
64 |
+
|
65 |
+
# resize the image
|
66 |
+
resized_img = image.resize((target_width, target_height))
|
67 |
+
processed_images = []
|
68 |
+
for i in range(blocks):
|
69 |
+
box = ((i % (target_width // image_size)) * image_size,
|
70 |
+
(i // (target_width // image_size)) * image_size,
|
71 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
72 |
+
((i // (target_width // image_size)) + 1) * image_size)
|
73 |
+
# split the image
|
74 |
+
split_img = resized_img.crop(box)
|
75 |
+
processed_images.append(split_img)
|
76 |
+
assert len(processed_images) == blocks
|
77 |
+
if use_thumbnail and len(processed_images) != 1:
|
78 |
+
thumbnail_img = image.resize((image_size, image_size))
|
79 |
+
processed_images.append(thumbnail_img)
|
80 |
+
return processed_images
|
81 |
+
|
82 |
+
|
83 |
+
def load_image(image_file, input_size=448, max_num=12):
|
84 |
+
if isinstance(image_file, str):
|
85 |
+
image = Image.open(image_file).convert('RGB')
|
86 |
+
else:
|
87 |
+
image = image_file
|
88 |
+
transform = build_transform(input_size=input_size)
|
89 |
+
images = dynamic_preprocess(image,
|
90 |
+
image_size=input_size,
|
91 |
+
use_thumbnail=True,
|
92 |
+
max_num=max_num)
|
93 |
+
pixel_values = [transform(image) for image in images]
|
94 |
+
pixel_values = torch.stack(pixel_values)
|
95 |
+
return pixel_values
|