xianbao HF staff commited on
Commit
72895aa
·
1 Parent(s): bd81c55

Upload with huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .ipynb_checkpoints/env-checkpoint.py +13 -0
  2. README.md +6 -6
  3. app.py +1677 -0
  4. env.py +13 -0
  5. ppdiffusers/__init__.py +162 -0
  6. ppdiffusers/__pycache__/__init__.cpython-37.pyc +0 -0
  7. ppdiffusers/__pycache__/configuration_utils.cpython-37.pyc +0 -0
  8. ppdiffusers/__pycache__/download_utils.cpython-37.pyc +0 -0
  9. ppdiffusers/__pycache__/fastdeploy_utils.cpython-37.pyc +0 -0
  10. ppdiffusers/__pycache__/initializer.cpython-37.pyc +0 -0
  11. ppdiffusers/__pycache__/loaders.cpython-37.pyc +0 -0
  12. ppdiffusers/__pycache__/modeling_utils.cpython-37.pyc +0 -0
  13. ppdiffusers/__pycache__/optimization.cpython-37.pyc +0 -0
  14. ppdiffusers/__pycache__/pipeline_utils.cpython-37.pyc +0 -0
  15. ppdiffusers/__pycache__/ppnlp_patch_utils.cpython-37.pyc +0 -0
  16. ppdiffusers/__pycache__/training_utils.cpython-37.pyc +0 -0
  17. ppdiffusers/__pycache__/version.cpython-37.pyc +0 -0
  18. ppdiffusers/commands/__init__.py +28 -0
  19. ppdiffusers/commands/env.py +67 -0
  20. ppdiffusers/commands/ppdiffusers_cli.py +41 -0
  21. ppdiffusers/configuration_utils.py +591 -0
  22. ppdiffusers/download_utils.py +44 -0
  23. ppdiffusers/experimental/README.md +6 -0
  24. ppdiffusers/experimental/__init__.py +17 -0
  25. ppdiffusers/experimental/rl/__init__.py +17 -0
  26. ppdiffusers/experimental/rl/value_guided_sampling.py +146 -0
  27. ppdiffusers/fastdeploy_utils.py +260 -0
  28. ppdiffusers/initializer.py +303 -0
  29. ppdiffusers/loaders.py +190 -0
  30. ppdiffusers/modeling_paddle_pytorch_utils.py +106 -0
  31. ppdiffusers/modeling_utils.py +619 -0
  32. ppdiffusers/models/__init__.py +25 -0
  33. ppdiffusers/models/__pycache__/__init__.cpython-37.pyc +0 -0
  34. ppdiffusers/models/__pycache__/attention.cpython-37.pyc +0 -0
  35. ppdiffusers/models/__pycache__/cross_attention.cpython-37.pyc +0 -0
  36. ppdiffusers/models/__pycache__/embeddings.cpython-37.pyc +0 -0
  37. ppdiffusers/models/__pycache__/prior_transformer.cpython-37.pyc +0 -0
  38. ppdiffusers/models/__pycache__/resnet.cpython-37.pyc +0 -0
  39. ppdiffusers/models/__pycache__/unet_1d.cpython-37.pyc +0 -0
  40. ppdiffusers/models/__pycache__/unet_1d_blocks.cpython-37.pyc +0 -0
  41. ppdiffusers/models/__pycache__/unet_2d.cpython-37.pyc +0 -0
  42. ppdiffusers/models/__pycache__/unet_2d_blocks.cpython-37.pyc +0 -0
  43. ppdiffusers/models/__pycache__/unet_2d_condition.cpython-37.pyc +0 -0
  44. ppdiffusers/models/__pycache__/vae.cpython-37.pyc +0 -0
  45. ppdiffusers/models/attention.py +683 -0
  46. ppdiffusers/models/cross_attention.py +435 -0
  47. ppdiffusers/models/ema.py +103 -0
  48. ppdiffusers/models/embeddings.py +199 -0
  49. ppdiffusers/models/prior_transformer.py +220 -0
  50. ppdiffusers/models/resnet.py +716 -0
.ipynb_checkpoints/env-checkpoint.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ############################################################################################################################
2
+ # 修改下面的参数
3
+ # (1)BASE_MODEL_NAME 代表你训练的基础模型
4
+ BASE_MODEL_NAME = "runwayml/stable-diffusion-v1-5"
5
+
6
+ # 是否开启lora
7
+ # (2)LORA_WEIGHTS_PATH 代码你上传到huggingface后的lora权重。
8
+ # LORA_WEIGHTS_PATH = None 表示不适应lora
9
+ LORA_WEIGHTS_PATH = "xianbao/demo_test"
10
+
11
+ # (3)PROMPTS 需要展示的prompt文本
12
+ PROMPTS = "A photo of sks dog in a bucket"
13
+ ############################################################################################################################
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: Demo Test
3
- emoji: 💻
4
- colorFrom: blue
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: LoRa ppdiffusers dreambooth
3
+ emoji: 🎨🎞️
4
+ colorFrom: pink
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.18.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import gradio as gr
16
+ from env import BASE_MODEL_NAME, LORA_WEIGHTS_PATH, PROMPTS
17
+
18
+ examples = [
19
+ [
20
+ PROMPTS,
21
+ 'low quality',
22
+ 7.5,
23
+ 512,
24
+ 512,
25
+ 25,
26
+ "DPMSolver"
27
+ ],
28
+ ]
29
+ import inspect
30
+ import os
31
+ import random
32
+ import re
33
+ import time
34
+ from typing import Callable, List, Optional, Union
35
+
36
+ import numpy as np
37
+ import paddle
38
+ import PIL
39
+ import PIL.Image
40
+ from packaging import version
41
+
42
+ from paddlenlp.transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
43
+
44
+ from ppdiffusers.configuration_utils import FrozenDict
45
+ from ppdiffusers.models import AutoencoderKL, UNet2DConditionModel
46
+ from ppdiffusers.pipeline_utils import DiffusionPipeline
47
+ from ppdiffusers.schedulers import (
48
+ DDIMScheduler,
49
+ DPMSolverMultistepScheduler,
50
+ EulerAncestralDiscreteScheduler,
51
+ EulerDiscreteScheduler,
52
+ LMSDiscreteScheduler,
53
+ PNDMScheduler,
54
+ HeunDiscreteScheduler,
55
+ KDPM2AncestralDiscreteScheduler,
56
+ KDPM2DiscreteScheduler,
57
+
58
+ )
59
+ from ppdiffusers.utils import PIL_INTERPOLATION, deprecate, logging
60
+ from ppdiffusers.utils.testing_utils import load_image
61
+ from ppdiffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
62
+ from ppdiffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
63
+
64
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
65
+
66
+
67
+ def save_all(images, FORMAT="jpg", OUTDIR="./outputs/"):
68
+ if not isinstance(images, (list, tuple)):
69
+ images = [images]
70
+ for image in images:
71
+ PRECISION = "fp32"
72
+ argument = image.argument
73
+ os.makedirs(OUTDIR, exist_ok=True)
74
+ epoch_time = argument["epoch_time"]
75
+ PROMPT = argument["prompt"]
76
+ NEGPROMPT = argument["negative_prompt"]
77
+ HEIGHT = argument["height"]
78
+ WIDTH = argument["width"]
79
+ SEED = argument["seed"]
80
+ STRENGTH = argument.get("strength", 1)
81
+ INFERENCE_STEPS = argument["num_inference_steps"]
82
+ GUIDANCE_SCALE = argument["guidance_scale"]
83
+
84
+ filename = f"{str(epoch_time)}_scale_{GUIDANCE_SCALE}_steps_{INFERENCE_STEPS}_seed_{SEED}.{FORMAT}"
85
+ filedir = f"{OUTDIR}/{filename}"
86
+ image.save(filedir)
87
+ with open(f"{OUTDIR}/{epoch_time}_prompt.txt", "w") as file:
88
+ file.write(
89
+ f"PROMPT: {PROMPT}\nNEG_PROMPT: {NEGPROMPT}\n\nINFERENCE_STEPS: {INFERENCE_STEPS}\nHeight: {HEIGHT}\nWidth: {WIDTH}\nSeed: {SEED}\n\nPrecision: {PRECISION}\nSTRENGTH: {STRENGTH}\nGUIDANCE_SCALE: {GUIDANCE_SCALE}"
90
+ )
91
+
92
+
93
+ re_attention = re.compile(
94
+ r"""
95
+ \\\(|
96
+ \\\)|
97
+ \\\[|
98
+ \\]|
99
+ \\\\|
100
+ \\|
101
+ \(|
102
+ \[|
103
+ :([+-]?[.\d]+)\)|
104
+ \)|
105
+ ]|
106
+ [^\\()\[\]:]+|
107
+ :
108
+ """,
109
+ re.X,
110
+ )
111
+
112
+
113
+ def parse_prompt_attention(text):
114
+ """
115
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
116
+ Accepted tokens are:
117
+ (abc) - increases attention to abc by a multiplier of 1.1
118
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
119
+ [abc] - decreases attention to abc by a multiplier of 1.1
120
+ \( - literal character '('
121
+ \[ - literal character '['
122
+ \) - literal character ')'
123
+ \] - literal character ']'
124
+ \\ - literal character '\'
125
+ anything else - just text
126
+ >>> parse_prompt_attention('normal text')
127
+ [['normal text', 1.0]]
128
+ >>> parse_prompt_attention('an (important) word')
129
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
130
+ >>> parse_prompt_attention('(unbalanced')
131
+ [['unbalanced', 1.1]]
132
+ >>> parse_prompt_attention('\(literal\]')
133
+ [['(literal]', 1.0]]
134
+ >>> parse_prompt_attention('(unnecessary)(parens)')
135
+ [['unnecessaryparens', 1.1]]
136
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
137
+ [['a ', 1.0],
138
+ ['house', 1.5730000000000004],
139
+ [' ', 1.1],
140
+ ['on', 1.0],
141
+ [' a ', 1.1],
142
+ ['hill', 0.55],
143
+ [', sun, ', 1.1],
144
+ ['sky', 1.4641000000000006],
145
+ ['.', 1.1]]
146
+ """
147
+
148
+ res = []
149
+ round_brackets = []
150
+ square_brackets = []
151
+
152
+ round_bracket_multiplier = 1.1
153
+ square_bracket_multiplier = 1 / 1.1
154
+
155
+ def multiply_range(start_position, multiplier):
156
+ for p in range(start_position, len(res)):
157
+ res[p][1] *= multiplier
158
+
159
+ for m in re_attention.finditer(text):
160
+ text = m.group(0)
161
+ weight = m.group(1)
162
+
163
+ if text.startswith("\\"):
164
+ res.append([text[1:], 1.0])
165
+ elif text == "(":
166
+ round_brackets.append(len(res))
167
+ elif text == "[":
168
+ square_brackets.append(len(res))
169
+ elif weight is not None and len(round_brackets) > 0:
170
+ multiply_range(round_brackets.pop(), float(weight))
171
+ elif text == ")" and len(round_brackets) > 0:
172
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
173
+ elif text == "]" and len(square_brackets) > 0:
174
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
175
+ else:
176
+ res.append([text, 1.0])
177
+
178
+ for pos in round_brackets:
179
+ multiply_range(pos, round_bracket_multiplier)
180
+
181
+ for pos in square_brackets:
182
+ multiply_range(pos, square_bracket_multiplier)
183
+
184
+ if len(res) == 0:
185
+ res = [["", 1.0]]
186
+
187
+ # merge runs of identical weights
188
+ i = 0
189
+ while i + 1 < len(res):
190
+ if res[i][1] == res[i + 1][1]:
191
+ res[i][0] += res[i + 1][0]
192
+ res.pop(i + 1)
193
+ else:
194
+ i += 1
195
+
196
+ return res
197
+
198
+
199
+ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
200
+ r"""
201
+ Tokenize a list of prompts and return its tokens with weights of each token.
202
+
203
+ No padding, starting or ending token is included.
204
+ """
205
+ tokens = []
206
+ weights = []
207
+ for text in prompt:
208
+ texts_and_weights = parse_prompt_attention(text)
209
+ text_token = []
210
+ text_weight = []
211
+ for word, weight in texts_and_weights:
212
+ # tokenize and discard the starting and the ending token
213
+ token = pipe.tokenizer(word).input_ids[1:-1]
214
+ text_token += token
215
+
216
+ # copy the weight by length of token
217
+ text_weight += [weight] * len(token)
218
+
219
+ # stop if the text is too long (longer than truncation limit)
220
+ if len(text_token) > max_length:
221
+ break
222
+
223
+ # truncate
224
+ if len(text_token) > max_length:
225
+ text_token = text_token[:max_length]
226
+ text_weight = text_weight[:max_length]
227
+
228
+ tokens.append(text_token)
229
+ weights.append(text_weight)
230
+ return tokens, weights
231
+
232
+
233
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
234
+ r"""
235
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
236
+ """
237
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
238
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
239
+ for i in range(len(tokens)):
240
+ tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
241
+ if no_boseos_middle:
242
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
243
+ else:
244
+ w = []
245
+ if len(weights[i]) == 0:
246
+ w = [1.0] * weights_length
247
+ else:
248
+ for j in range((len(weights[i]) - 1) // chunk_length + 1):
249
+ w.append(1.0) # weight for starting token in this chunk
250
+ w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
251
+ w.append(1.0) # weight for ending token in this chunk
252
+ w += [1.0] * (weights_length - len(w))
253
+ weights[i] = w[:]
254
+
255
+ return tokens, weights
256
+
257
+
258
+ def get_unweighted_text_embeddings(
259
+ pipe: DiffusionPipeline, text_input: paddle.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
260
+ ):
261
+ """
262
+ When the length of tokens is a multiple of the capacity of the text encoder,
263
+ it should be split into chunks and sent to the text encoder individually.
264
+ """
265
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
266
+ if max_embeddings_multiples > 1:
267
+ text_embeddings = []
268
+ for i in range(max_embeddings_multiples):
269
+ # extract the i-th chunk
270
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
271
+
272
+ # cover the head and the tail by the starting and the ending tokens
273
+ text_input_chunk[:, 0] = text_input[0, 0]
274
+ text_input_chunk[:, -1] = text_input[0, -1]
275
+
276
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
277
+
278
+ if no_boseos_middle:
279
+ if i == 0:
280
+ # discard the ending token
281
+ text_embedding = text_embedding[:, :-1]
282
+ elif i == max_embeddings_multiples - 1:
283
+ # discard the starting token
284
+ text_embedding = text_embedding[:, 1:]
285
+ else:
286
+ # discard both starting and ending tokens
287
+ text_embedding = text_embedding[:, 1:-1]
288
+
289
+ text_embeddings.append(text_embedding)
290
+ text_embeddings = paddle.concat(text_embeddings, axis=1)
291
+ else:
292
+ text_embeddings = pipe.text_encoder(text_input)[0]
293
+ return text_embeddings
294
+
295
+
296
+ def get_weighted_text_embeddings(
297
+ pipe: DiffusionPipeline,
298
+ prompt: Union[str, List[str]],
299
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
300
+ max_embeddings_multiples: Optional[int] = 1,
301
+ no_boseos_middle: Optional[bool] = False,
302
+ skip_parsing: Optional[bool] = False,
303
+ skip_weighting: Optional[bool] = False,
304
+ **kwargs
305
+ ):
306
+ r"""
307
+ Prompts can be assigned with local weights using brackets. For example,
308
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
309
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
310
+
311
+ Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
312
+
313
+ Args:
314
+ pipe (`DiffusionPipeline`):
315
+ Pipe to provide access to the tokenizer and the text encoder.
316
+ prompt (`str` or `List[str]`):
317
+ The prompt or prompts to guide the image generation.
318
+ uncond_prompt (`str` or `List[str]`):
319
+ The unconditional prompt or prompts for guide the image generation. If unconditional prompt
320
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
321
+ max_embeddings_multiples (`int`, *optional*, defaults to `1`):
322
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
323
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
324
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
325
+ ending token in each of the chunk in the middle.
326
+ skip_parsing (`bool`, *optional*, defaults to `False`):
327
+ Skip the parsing of brackets.
328
+ skip_weighting (`bool`, *optional*, defaults to `False`):
329
+ Skip the weighting. When the parsing is skipped, it is forced True.
330
+ """
331
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
332
+ if isinstance(prompt, str):
333
+ prompt = [prompt]
334
+
335
+ if not skip_parsing:
336
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
337
+ if uncond_prompt is not None:
338
+ if isinstance(uncond_prompt, str):
339
+ uncond_prompt = [uncond_prompt]
340
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
341
+ else:
342
+ prompt_tokens = [
343
+ token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
344
+ ]
345
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
346
+ if uncond_prompt is not None:
347
+ if isinstance(uncond_prompt, str):
348
+ uncond_prompt = [uncond_prompt]
349
+ uncond_tokens = [
350
+ token[1:-1]
351
+ for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
352
+ ]
353
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
354
+
355
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
356
+ max_length = max([len(token) for token in prompt_tokens])
357
+ if uncond_prompt is not None:
358
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
359
+
360
+ max_embeddings_multiples = min(
361
+ max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
362
+ )
363
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
364
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
365
+
366
+ # pad the length of tokens and weights
367
+ # support bert tokenizer
368
+ bos = pipe.tokenizer.bos_token_id if pipe.tokenizer.bos_token_id is not None else pipe.tokenizer.cls_token_id
369
+ eos = pipe.tokenizer.eos_token_id if pipe.tokenizer.eos_token_id is not None else pipe.tokenizer.sep_token_id
370
+ pad = pipe.tokenizer.pad_token_id
371
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
372
+ prompt_tokens,
373
+ prompt_weights,
374
+ max_length,
375
+ bos,
376
+ eos,
377
+ pad,
378
+ no_boseos_middle=no_boseos_middle,
379
+ chunk_length=pipe.tokenizer.model_max_length,
380
+ )
381
+ prompt_tokens = paddle.to_tensor(prompt_tokens)
382
+ if uncond_prompt is not None:
383
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
384
+ uncond_tokens,
385
+ uncond_weights,
386
+ max_length,
387
+ bos,
388
+ eos,
389
+ pad,
390
+ no_boseos_middle=no_boseos_middle,
391
+ chunk_length=pipe.tokenizer.model_max_length,
392
+ )
393
+ uncond_tokens = paddle.to_tensor(uncond_tokens)
394
+
395
+ # get the embeddings
396
+ text_embeddings = get_unweighted_text_embeddings(
397
+ pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
398
+ )
399
+ prompt_weights = paddle.to_tensor(prompt_weights, dtype=text_embeddings.dtype)
400
+ if uncond_prompt is not None:
401
+ uncond_embeddings = get_unweighted_text_embeddings(
402
+ pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
403
+ )
404
+ uncond_weights = paddle.to_tensor(uncond_weights, dtype=uncond_embeddings.dtype)
405
+
406
+ # assign weights to the prompts and normalize in the sense of mean
407
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
408
+ if (not skip_parsing) and (not skip_weighting):
409
+ previous_mean = text_embeddings.mean(axis=[-2, -1])
410
+ text_embeddings *= prompt_weights.unsqueeze(-1)
411
+ text_embeddings *= previous_mean / text_embeddings.mean(axis=[-2, -1])
412
+ if uncond_prompt is not None:
413
+ previous_mean = uncond_embeddings.mean(axis=[-2, -1])
414
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
415
+ uncond_embeddings *= previous_mean / uncond_embeddings.mean(axis=[-2, -1])
416
+
417
+ # For classifier free guidance, we need to do two forward passes.
418
+ # Here we concatenate the unconditional and text embeddings into a single batch
419
+ # to avoid doing two forward passes
420
+ if uncond_prompt is not None:
421
+ text_embeddings = paddle.concat([uncond_embeddings, text_embeddings])
422
+
423
+ return text_embeddings
424
+
425
+
426
+ def preprocess_image(image):
427
+ w, h = image.size
428
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
429
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
430
+ image = np.array(image).astype(np.float32) / 255.0
431
+ image = image[None].transpose(0, 3, 1, 2)
432
+ image = paddle.to_tensor(image)
433
+ return 2.0 * image - 1.0
434
+
435
+
436
+ def preprocess_mask(mask):
437
+ mask = mask.convert("L")
438
+ w, h = mask.size
439
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
440
+ mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
441
+ mask = np.array(mask).astype(np.float32) / 255.0
442
+ mask = np.tile(mask, (4, 1, 1))
443
+ mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
444
+ mask = 1 - mask # repaint white, keep black
445
+ mask = paddle.to_tensor(mask)
446
+ return mask
447
+
448
+
449
+ class StableDiffusionPipelineAllinOne(DiffusionPipeline):
450
+ r"""
451
+ Pipeline for text-to-image image-to-image inpainting generation using Stable Diffusion.
452
+
453
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
454
+ library implements for all the pipelines (such as downloading or saving, running on a particular xxxx, etc.)
455
+
456
+ Args:
457
+ vae ([`AutoencoderKL`]):
458
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
459
+ text_encoder ([`CLIPTextModel`]):
460
+ Frozen text-encoder. Stable Diffusion uses the text portion of
461
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
462
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
463
+ tokenizer (`CLIPTokenizer`):
464
+ Tokenizer of class
465
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
466
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
467
+ scheduler ([`SchedulerMixin`]):
468
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
469
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`PNDMScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`]
470
+ or [`DPMSolverMultistepScheduler`].
471
+ safety_checker ([`StableDiffusionSafetyChecker`]):
472
+ Classification module that estimates whether generated images could be considered offensive or harmful.
473
+ Please, refer to the [model card](https://huggingface.co/junnyu/stable-diffusion-v1-4-paddle) for details.
474
+ feature_extractor ([`CLIPFeatureExtractor`]):
475
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
476
+ """
477
+ _optional_components = ["safety_checker", "feature_extractor"]
478
+
479
+ def __init__(
480
+ self,
481
+ vae: AutoencoderKL,
482
+ text_encoder: CLIPTextModel,
483
+ tokenizer: CLIPTokenizer,
484
+ unet: UNet2DConditionModel,
485
+ scheduler: Union[
486
+ DDIMScheduler,
487
+ PNDMScheduler,
488
+ LMSDiscreteScheduler,
489
+ EulerDiscreteScheduler,
490
+ EulerAncestralDiscreteScheduler,
491
+ DPMSolverMultistepScheduler,
492
+ ],
493
+ safety_checker: StableDiffusionSafetyChecker,
494
+ feature_extractor: CLIPFeatureExtractor,
495
+ requires_safety_checker: bool = False,
496
+ ):
497
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
498
+ deprecation_message = (
499
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
500
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
501
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
502
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
503
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
504
+ " file"
505
+ )
506
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
507
+ new_config = dict(scheduler.config)
508
+ new_config["steps_offset"] = 1
509
+ scheduler._internal_dict = FrozenDict(new_config)
510
+
511
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
512
+ deprecation_message = (
513
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
514
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
515
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
516
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
517
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
518
+ )
519
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
520
+ new_config = dict(scheduler.config)
521
+ new_config["clip_sample"] = False
522
+ scheduler._internal_dict = FrozenDict(new_config)
523
+
524
+ if safety_checker is None and requires_safety_checker:
525
+ logger.warning(
526
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
527
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
528
+ " results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
529
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
530
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
531
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
532
+ )
533
+ if safety_checker is not None and feature_extractor is None:
534
+ raise ValueError(
535
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
536
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
537
+ )
538
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_ppdiffusers_version") and version.parse(
539
+ version.parse(unet.config._ppdiffusers_version).base_version
540
+ ) < version.parse("0.9.0.dev0")
541
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
542
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
543
+ deprecation_message = (
544
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
545
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
546
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
547
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
548
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
549
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
550
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
551
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
552
+ " the `unet/config.json` file"
553
+ )
554
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
555
+ new_config = dict(unet.config)
556
+ new_config["sample_size"] = 64
557
+ unet._internal_dict = FrozenDict(new_config)
558
+
559
+ self.register_modules(
560
+ vae=vae,
561
+ text_encoder=text_encoder,
562
+ tokenizer=tokenizer,
563
+ unet=unet,
564
+ scheduler=scheduler,
565
+ safety_checker=safety_checker,
566
+ feature_extractor=feature_extractor,
567
+ )
568
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
569
+
570
+ def create_scheduler(self, name="DPMSolver"):
571
+ config = self.scheduler.config
572
+ if name == "DPMSolver":
573
+ return DPMSolverMultistepScheduler.from_config(
574
+ config,
575
+ thresholding=False,
576
+ algorithm_type="dpmsolver++",
577
+ solver_type="midpoint",
578
+ lower_order_final=True,
579
+ )
580
+ if name == "EulerDiscrete":
581
+ return EulerDiscreteScheduler.from_config(config)
582
+ elif name == "EulerAncestralDiscrete":
583
+ return EulerAncestralDiscreteScheduler.from_config(config)
584
+ elif name == "PNDM":
585
+ return PNDMScheduler.from_config(config)
586
+ elif name == "DDIM":
587
+ return DDIMScheduler.from_config(config)
588
+ elif name == "LMSDiscrete":
589
+ return LMSDiscreteScheduler.from_config(config)
590
+ elif name == "HeunDiscrete":
591
+ return HeunDiscreteScheduler.from_config(config)
592
+ elif name == "KDPM2AncestralDiscrete":
593
+ return KDPM2AncestralDiscreteScheduler.from_config(config)
594
+ elif name == "KDPM2Discrete":
595
+ return KDPM2DiscreteScheduler.from_config(config)
596
+ else:
597
+ raise NotImplementedError
598
+
599
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
600
+ r"""
601
+ Enable sliced attention computation.
602
+
603
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
604
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
605
+
606
+ Args:
607
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
608
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
609
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
610
+ `attention_head_dim` must be a multiple of `slice_size`.
611
+ """
612
+ if slice_size == "auto":
613
+ if isinstance(self.unet.config.attention_head_dim, int):
614
+ # half the attention head size is usually a good trade-off between
615
+ # speed and memory
616
+ slice_size = self.unet.config.attention_head_dim // 2
617
+ else:
618
+ # if `attention_head_dim` is a list, take the smallest head size
619
+ slice_size = min(self.unet.config.attention_head_dim)
620
+ self.unet.set_attention_slice(slice_size)
621
+
622
+ def disable_attention_slicing(self):
623
+ r"""
624
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
625
+ back to computing attention in one step.
626
+ """
627
+ # set slice_size = `None` to disable `attention slicing`
628
+ self.enable_attention_slicing(None)
629
+
630
+ def __call__(self, *args, **kwargs):
631
+ return self.text2image(*args, **kwargs)
632
+
633
+ def text2img(self, *args, **kwargs):
634
+ return self.text2image(*args, **kwargs)
635
+
636
+ def _encode_prompt(
637
+ self,
638
+ prompt,
639
+ negative_prompt,
640
+ max_embeddings_multiples,
641
+ no_boseos_middle,
642
+ skip_parsing,
643
+ skip_weighting,
644
+ do_classifier_free_guidance,
645
+ num_images_per_prompt,
646
+ ):
647
+ if do_classifier_free_guidance and negative_prompt is None:
648
+ negative_prompt = ""
649
+ text_embeddings = get_weighted_text_embeddings(
650
+ self, prompt, negative_prompt, max_embeddings_multiples, no_boseos_middle, skip_parsing, skip_weighting
651
+ )
652
+
653
+ bs_embed, seq_len, _ = text_embeddings.shape
654
+ text_embeddings = text_embeddings.tile([1, num_images_per_prompt, 1])
655
+ text_embeddings = text_embeddings.reshape([bs_embed * num_images_per_prompt, seq_len, -1])
656
+ return text_embeddings
657
+
658
+ def run_safety_checker(self, image, dtype):
659
+ if self.safety_checker is not None:
660
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pd")
661
+ image, has_nsfw_concept = self.safety_checker(
662
+ images=image, clip_input=safety_checker_input.pixel_values.cast(dtype)
663
+ )
664
+ else:
665
+ has_nsfw_concept = None
666
+ return image, has_nsfw_concept
667
+
668
+ def decode_latents(self, latents):
669
+ latents = 1 / 0.18215 * latents
670
+ image = self.vae.decode(latents).sample
671
+ image = (image / 2 + 0.5).clip(0, 1)
672
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
673
+ image = image.transpose([0, 2, 3, 1]).cast("float32").numpy()
674
+ return image
675
+
676
+ def prepare_extra_step_kwargs(self, eta, scheduler):
677
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
678
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
679
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
680
+ # and should be between [0, 1]
681
+
682
+ accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
683
+ extra_step_kwargs = {}
684
+ if accepts_eta:
685
+ extra_step_kwargs["eta"] = eta
686
+
687
+ return extra_step_kwargs
688
+
689
+ def check_inputs_text2img(self, prompt, height, width, callback_steps):
690
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
691
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
692
+
693
+ if height % 8 != 0 or width % 8 != 0:
694
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
695
+
696
+ if (callback_steps is None) or (
697
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
698
+ ):
699
+ raise ValueError(
700
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
701
+ f" {type(callback_steps)}."
702
+ )
703
+
704
+ def check_inputs_img2img_inpaint(self, prompt, strength, callback_steps):
705
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
706
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
707
+
708
+ if strength < 0 or strength > 1:
709
+ raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
710
+
711
+ if (callback_steps is None) or (
712
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
713
+ ):
714
+ raise ValueError(
715
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
716
+ f" {type(callback_steps)}."
717
+ )
718
+
719
+ def prepare_latents_text2img(self, batch_size, num_channels_latents, height, width, dtype, latents=None, scheduler=None):
720
+ shape = [batch_size, num_channels_latents, height // 8, width // 8]
721
+ if latents is None:
722
+ latents = paddle.randn(shape, dtype=dtype)
723
+ else:
724
+ if latents.shape != shape:
725
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
726
+
727
+ # scale the initial noise by the standard deviation required by the scheduler
728
+ latents = latents * scheduler.init_noise_sigma
729
+ return latents
730
+
731
+ def prepare_latents_img2img(self, image, timestep, num_images_per_prompt, dtype, scheduler):
732
+ image = image.cast(dtype=dtype)
733
+ init_latent_dist = self.vae.encode(image).latent_dist
734
+ init_latents = init_latent_dist.sample()
735
+ init_latents = 0.18215 * init_latents
736
+
737
+ b, c, h, w = init_latents.shape
738
+ init_latents = init_latents.tile([1, num_images_per_prompt, 1, 1])
739
+ init_latents = init_latents.reshape([b * num_images_per_prompt, c, h, w])
740
+
741
+ # add noise to latents using the timesteps
742
+ noise = paddle.randn(init_latents.shape, dtype=dtype)
743
+
744
+ # get latents
745
+ init_latents = scheduler.add_noise(init_latents, noise, timestep)
746
+ latents = init_latents
747
+
748
+ return latents
749
+
750
+ def get_timesteps(self, num_inference_steps, strength, scheduler):
751
+ # get the original timestep using init_timestep
752
+ offset = scheduler.config.get("steps_offset", 0)
753
+ init_timestep = int(num_inference_steps * strength) + offset
754
+ init_timestep = min(init_timestep, num_inference_steps)
755
+
756
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
757
+ timesteps = scheduler.timesteps[t_start:]
758
+
759
+ return timesteps, num_inference_steps - t_start
760
+
761
+ def prepare_latents_inpaint(self, image, timestep, num_images_per_prompt, dtype, scheduler):
762
+ image = image.cast(dtype)
763
+ init_latent_dist = self.vae.encode(image).latent_dist
764
+ init_latents = init_latent_dist.sample()
765
+ init_latents = 0.18215 * init_latents
766
+
767
+ b, c, h, w = init_latents.shape
768
+ init_latents = init_latents.tile([1, num_images_per_prompt, 1, 1])
769
+ init_latents = init_latents.reshape([b * num_images_per_prompt, c, h, w])
770
+
771
+ init_latents_orig = init_latents
772
+
773
+ # add noise to latents using the timesteps
774
+ noise = paddle.randn(init_latents.shape, dtype=dtype)
775
+ init_latents = scheduler.add_noise(init_latents, noise, timestep)
776
+ latents = init_latents
777
+ return latents, init_latents_orig, noise
778
+
779
+ @paddle.no_grad()
780
+ def text2image(
781
+ self,
782
+ prompt: Union[str, List[str]],
783
+ height: int = 512,
784
+ width: int = 512,
785
+ num_inference_steps: int = 50,
786
+ guidance_scale: float = 7.5,
787
+ negative_prompt: Optional[Union[str, List[str]]] = None,
788
+ num_images_per_prompt: Optional[int] = 1,
789
+ eta: float = 0.0,
790
+ seed: Optional[int] = None,
791
+ latents: Optional[paddle.Tensor] = None,
792
+ output_type: Optional[str] = "pil",
793
+ return_dict: bool = True,
794
+ callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
795
+ callback_steps: Optional[int] = 1,
796
+ # new add
797
+ max_embeddings_multiples: Optional[int] = 1,
798
+ no_boseos_middle: Optional[bool] = False,
799
+ skip_parsing: Optional[bool] = False,
800
+ skip_weighting: Optional[bool] = False,
801
+ scheduler=None,
802
+ **kwargs,
803
+ ):
804
+ r"""
805
+ Function invoked when calling the pipeline for generation.
806
+
807
+ Args:
808
+ prompt (`str` or `List[str]`):
809
+ The prompt or prompts to guide the image generation.
810
+ height (`int`, *optional*, defaults to 512):
811
+ The height in pixels of the generated image.
812
+ width (`int`, *optional*, defaults to 512):
813
+ The width in pixels of the generated image.
814
+ num_inference_steps (`int`, *optional*, defaults to 50):
815
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
816
+ expense of slower inference.
817
+ guidance_scale (`float`, *optional*, defaults to 7.5):
818
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
819
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
820
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
821
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
822
+ usually at the expense of lower image quality.
823
+ negative_prompt (`str` or `List[str]`, *optional*):
824
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
825
+ if `guidance_scale` is less than `1`).
826
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
827
+ The number of images to generate per prompt.
828
+ eta (`float`, *optional*, defaults to 0.0):
829
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
830
+ [`schedulers.DDIMScheduler`], will be ignored for others.
831
+ seed (`int`, *optional*):
832
+ Random number seed.
833
+ latents (`paddle.Tensor`, *optional*):
834
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
835
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
836
+ tensor will ge generated by sampling using the supplied random `seed`.
837
+ output_type (`str`, *optional*, defaults to `"pil"`):
838
+ The output format of the generate image. Choose between
839
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
840
+ return_dict (`bool`, *optional*, defaults to `True`):
841
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
842
+ plain tuple.
843
+ callback (`Callable`, *optional*):
844
+ A function that will be called every `callback_steps` steps during inference. The function will be
845
+ called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
846
+ callback_steps (`int`, *optional*, defaults to 1):
847
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
848
+ called at every step.
849
+
850
+ Returns:
851
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
852
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
853
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
854
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
855
+ (nsfw) content, according to the `safety_checker`.
856
+ """
857
+ if scheduler is None:
858
+ scheduler = self.scheduler
859
+ seed = random.randint(0, 2**32) if seed is None else seed
860
+ argument = dict(
861
+ prompt=prompt,
862
+ negative_prompt=negative_prompt,
863
+ height=height,
864
+ width=width,
865
+ num_inference_steps=num_inference_steps,
866
+ guidance_scale=guidance_scale,
867
+ num_images_per_prompt=num_images_per_prompt,
868
+ eta=eta,
869
+ seed=seed,
870
+ latents=latents,
871
+ max_embeddings_multiples=max_embeddings_multiples,
872
+ no_boseos_middle=no_boseos_middle,
873
+ skip_parsing=skip_parsing,
874
+ skip_weighting=skip_weighting,
875
+ epoch_time=time.time(),
876
+ )
877
+ paddle.seed(seed)
878
+ # 1. Check inputs. Raise error if not correct
879
+ self.check_inputs_text2img(prompt, height, width, callback_steps)
880
+
881
+ # 2. Define call parameters
882
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
883
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
884
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
885
+ # corresponds to doing no classifier free guidance.
886
+ do_classifier_free_guidance = guidance_scale > 1.0
887
+
888
+ # 3. Encode input prompt
889
+ text_embeddings = self._encode_prompt(
890
+ prompt,
891
+ negative_prompt,
892
+ max_embeddings_multiples,
893
+ no_boseos_middle,
894
+ skip_parsing,
895
+ skip_weighting,
896
+ do_classifier_free_guidance,
897
+ num_images_per_prompt,
898
+ )
899
+
900
+ # 4. Prepare timesteps
901
+ scheduler.set_timesteps(num_inference_steps)
902
+ timesteps = scheduler.timesteps
903
+
904
+ # 5. Prepare latent variables
905
+ num_channels_latents = self.unet.in_channels
906
+ latents = self.prepare_latents_text2img(
907
+ batch_size * num_images_per_prompt,
908
+ num_channels_latents,
909
+ height,
910
+ width,
911
+ text_embeddings.dtype,
912
+ latents,
913
+ scheduler=scheduler,
914
+ )
915
+
916
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
917
+ extra_step_kwargs = self.prepare_extra_step_kwargs(eta, scheduler)
918
+
919
+ # 7. Denoising loop
920
+ num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order
921
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
922
+ for i, t in enumerate(timesteps):
923
+ # expand the latents if we are doing classifier free guidance
924
+ latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
925
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
926
+
927
+ # predict the noise residual
928
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
929
+
930
+ # perform guidance
931
+ if do_classifier_free_guidance:
932
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
933
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
934
+
935
+ # compute the previous noisy sample x_t -> x_t-1
936
+ latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
937
+
938
+ # call the callback, if provided
939
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
940
+ progress_bar.update()
941
+ if callback is not None and i % callback_steps == 0:
942
+ callback(progress_bar.n, progress_bar.total, progress_bar)
943
+
944
+ # 8. Post-processing
945
+ image = self.decode_latents(latents)
946
+
947
+ # 9. Run safety checker
948
+ image, has_nsfw_concept = self.run_safety_checker(image, text_embeddings.dtype)
949
+
950
+ # 10. Convert to PIL
951
+ if output_type == "pil":
952
+ image = self.numpy_to_pil(image, argument=argument)
953
+
954
+ if not return_dict:
955
+ return (image, has_nsfw_concept)
956
+
957
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
958
+
959
+ @paddle.no_grad()
960
+ def img2img(
961
+ self,
962
+ prompt: Union[str, List[str]],
963
+ image: Union[paddle.Tensor, PIL.Image.Image],
964
+ strength: float = 0.8,
965
+ height=None,
966
+ width=None,
967
+ num_inference_steps: Optional[int] = 50,
968
+ guidance_scale: Optional[float] = 7.5,
969
+ negative_prompt: Optional[Union[str, List[str]]] = None,
970
+ num_images_per_prompt: Optional[int] = 1,
971
+ eta: Optional[float] = 0.0,
972
+ seed: Optional[int] = None,
973
+ output_type: Optional[str] = "pil",
974
+ return_dict: bool = True,
975
+ callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
976
+ callback_steps: Optional[int] = 1,
977
+ # new add
978
+ max_embeddings_multiples: Optional[int] = 1,
979
+ no_boseos_middle: Optional[bool] = False,
980
+ skip_parsing: Optional[bool] = False,
981
+ skip_weighting: Optional[bool] = False,
982
+ scheduler=None,
983
+ **kwargs,
984
+ ):
985
+ r"""
986
+ Function invoked when calling the pipeline for generation.
987
+
988
+ Args:
989
+ prompt (`str` or `List[str]`):
990
+ The prompt or prompts to guide the image generation.
991
+ image (`paddle.Tensor` or `PIL.Image.Image`):
992
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
993
+ process.
994
+ strength (`float`, *optional*, defaults to 0.8):
995
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
996
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
997
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
998
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
999
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1000
+ num_inference_steps (`int`, *optional*, defaults to 50):
1001
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1002
+ expense of slower inference. This parameter will be modulated by `strength`.
1003
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1004
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1005
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1006
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1007
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1008
+ usually at the expense of lower image quality.
1009
+ negative_prompt (`str` or `List[str]`, *optional*):
1010
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1011
+ if `guidance_scale` is less than `1`).
1012
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1013
+ The number of images to generate per prompt.
1014
+ eta (`float`, *optional*, defaults to 0.0):
1015
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1016
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1017
+ seed (`int`, *optional*):
1018
+ A random seed.
1019
+ output_type (`str`, *optional*, defaults to `"pil"`):
1020
+ The output format of the generate image. Choose between
1021
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1022
+ return_dict (`bool`, *optional*, defaults to `True`):
1023
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1024
+ plain tuple.
1025
+ callback (`Callable`, *optional*):
1026
+ A function that will be called every `callback_steps` steps during inference. The function will be
1027
+ called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
1028
+ callback_steps (`int`, *optional*, defaults to 1):
1029
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1030
+ called at every step.
1031
+
1032
+ Returns:
1033
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1034
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1035
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1036
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1037
+ (nsfw) content, according to the `safety_checker`.
1038
+ """
1039
+ if scheduler is None:
1040
+ scheduler = self.scheduler
1041
+ seed = random.randint(0, 2**32) if seed is None else seed
1042
+ image_str = image
1043
+ if isinstance(image_str, str):
1044
+ image = load_image(image_str)
1045
+
1046
+ if height is None and width is None:
1047
+ width = (image.size[0] // 8) * 8
1048
+ height = (image.size[1] // 8) * 8
1049
+ elif height is None and width is not None:
1050
+ height = (image.size[1] // 8) * 8
1051
+ elif width is None and height is not None:
1052
+ width = (image.size[0] // 8) * 8
1053
+ else:
1054
+ height = height
1055
+ width = width
1056
+
1057
+ argument = dict(
1058
+ prompt=prompt,
1059
+ image=image_str,
1060
+ negative_prompt=negative_prompt,
1061
+ height=height,
1062
+ width=width,
1063
+ strength=strength,
1064
+ num_inference_steps=num_inference_steps,
1065
+ guidance_scale=guidance_scale,
1066
+ num_images_per_prompt=num_images_per_prompt,
1067
+ eta=eta,
1068
+ seed=seed,
1069
+ max_embeddings_multiples=max_embeddings_multiples,
1070
+ no_boseos_middle=no_boseos_middle,
1071
+ skip_parsing=skip_parsing,
1072
+ skip_weighting=skip_weighting,
1073
+ epoch_time=time.time(),
1074
+ )
1075
+ paddle.seed(seed)
1076
+
1077
+ # 1. Check inputs
1078
+ self.check_inputs_img2img_inpaint(prompt, strength, callback_steps)
1079
+
1080
+ # 2. Define call parameters
1081
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
1082
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1083
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1084
+ # corresponds to doing no classifier free guidance.
1085
+ do_classifier_free_guidance = guidance_scale > 1.0
1086
+
1087
+ # 3. Encode input prompt
1088
+ text_embeddings = self._encode_prompt(
1089
+ prompt,
1090
+ negative_prompt,
1091
+ max_embeddings_multiples,
1092
+ no_boseos_middle,
1093
+ skip_parsing,
1094
+ skip_weighting,
1095
+ do_classifier_free_guidance,
1096
+ num_images_per_prompt,
1097
+ )
1098
+
1099
+ # 4. Preprocess image
1100
+ if isinstance(image, PIL.Image.Image):
1101
+ image = image.resize((width, height))
1102
+ image = preprocess_image(image)
1103
+
1104
+ # 5. set timesteps
1105
+ scheduler.set_timesteps(num_inference_steps)
1106
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
1107
+ latent_timestep = timesteps[:1].tile([batch_size * num_images_per_prompt])
1108
+
1109
+ # 6. Prepare latent variables
1110
+ latents = self.prepare_latents_img2img(image, latent_timestep, num_images_per_prompt, text_embeddings.dtype, scheduler)
1111
+
1112
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1113
+ extra_step_kwargs = self.prepare_extra_step_kwargs(eta, scheduler)
1114
+
1115
+ # 8. Denoising loop
1116
+ num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order
1117
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1118
+ for i, t in enumerate(timesteps):
1119
+ # expand the latents if we are doing classifier free guidance
1120
+ latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
1121
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
1122
+
1123
+ # predict the noise residual
1124
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
1125
+
1126
+ # perform guidance
1127
+ if do_classifier_free_guidance:
1128
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1129
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1130
+
1131
+ # compute the previous noisy sample x_t -> x_t-1
1132
+ latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1133
+
1134
+ # call the callback, if provided
1135
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
1136
+ progress_bar.update()
1137
+ if callback is not None and i % callback_steps == 0:
1138
+ callback(progress_bar.n, progress_bar.total, progress_bar)
1139
+
1140
+ # 9. Post-processing
1141
+ image = self.decode_latents(latents)
1142
+
1143
+ # 10. Run safety checker
1144
+ image, has_nsfw_concept = self.run_safety_checker(image, text_embeddings.dtype)
1145
+
1146
+ # 11. Convert to PIL
1147
+ if output_type == "pil":
1148
+ image = self.numpy_to_pil(image, argument=argument)
1149
+
1150
+ if not return_dict:
1151
+ return (image, has_nsfw_concept)
1152
+
1153
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
1154
+
1155
+ @paddle.no_grad()
1156
+ def inpaint(
1157
+ self,
1158
+ prompt: Union[str, List[str]],
1159
+ image: Union[paddle.Tensor, PIL.Image.Image],
1160
+ mask_image: Union[paddle.Tensor, PIL.Image.Image],
1161
+ height=None,
1162
+ width=None,
1163
+ strength: float = 0.8,
1164
+ num_inference_steps: Optional[int] = 50,
1165
+ guidance_scale: Optional[float] = 7.5,
1166
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1167
+ num_images_per_prompt: Optional[int] = 1,
1168
+ eta: Optional[float] = 0.0,
1169
+ seed: Optional[int] = None,
1170
+ output_type: Optional[str] = "pil",
1171
+ return_dict: bool = True,
1172
+ callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
1173
+ callback_steps: Optional[int] = 1,
1174
+ # new add
1175
+ max_embeddings_multiples: Optional[int] = 1,
1176
+ no_boseos_middle: Optional[bool] = False,
1177
+ skip_parsing: Optional[bool] = False,
1178
+ skip_weighting: Optional[bool] = False,
1179
+ scheduler=None,
1180
+ **kwargs,
1181
+ ):
1182
+ r"""
1183
+ Function invoked when calling the pipeline for generation.
1184
+
1185
+ Args:
1186
+ prompt (`str` or `List[str]`):
1187
+ The prompt or prompts to guide the image generation.
1188
+ image (`paddle.Tensor` or `PIL.Image.Image`):
1189
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1190
+ process. This is the image whose masked region will be inpainted.
1191
+ mask_image (`paddle.Tensor` or `PIL.Image.Image`):
1192
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1193
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1194
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1195
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1196
+ strength (`float`, *optional*, defaults to 0.8):
1197
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1198
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
1199
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1200
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1201
+ num_inference_steps (`int`, *optional*, defaults to 50):
1202
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1203
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1204
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1205
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1206
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1207
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1208
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1209
+ usually at the expense of lower image quality.
1210
+ negative_prompt (`str` or `List[str]`, *optional*):
1211
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1212
+ if `guidance_scale` is less than `1`).
1213
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1214
+ The number of images to generate per prompt.
1215
+ eta (`float`, *optional*, defaults to 0.0):
1216
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1217
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1218
+ seed (`int`, *optional*):
1219
+ A random seed.
1220
+ output_type (`str`, *optional*, defaults to `"pil"`):
1221
+ The output format of the generate image. Choose between
1222
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1223
+ return_dict (`bool`, *optional*, defaults to `True`):
1224
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1225
+ plain tuple.
1226
+ callback (`Callable`, *optional*):
1227
+ A function that will be called every `callback_steps` steps during inference. The function will be
1228
+ called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
1229
+ callback_steps (`int`, *optional*, defaults to 1):
1230
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1231
+ called at every step.
1232
+
1233
+ Returns:
1234
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1235
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1236
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1237
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1238
+ (nsfw) content, according to the `safety_checker`.
1239
+ """
1240
+ if scheduler is None:
1241
+ scheduler = self.scheduler
1242
+ seed = random.randint(0, 2**32) if seed is None else seed
1243
+ image_str = image
1244
+ mask_image_str = mask_image
1245
+
1246
+ if isinstance(image_str, str):
1247
+ image = load_image(image_str)
1248
+ if isinstance(mask_image_str, str):
1249
+ mask_image = load_image(mask_image_str)
1250
+
1251
+ if height is None and width is None:
1252
+ width = (image.size[0] // 8) * 8
1253
+ height = (image.size[1] // 8) * 8
1254
+ elif height is None and width is not None:
1255
+ height = (image.size[1] // 8) * 8
1256
+ elif width is None and height is not None:
1257
+ width = (image.size[0] // 8) * 8
1258
+ else:
1259
+ height = height
1260
+ width = width
1261
+
1262
+ argument = dict(
1263
+ prompt=prompt,
1264
+ image=image_str,
1265
+ mask_image=mask_image_str,
1266
+ negative_prompt=negative_prompt,
1267
+ height=height,
1268
+ width=width,
1269
+ strength=strength,
1270
+ num_inference_steps=num_inference_steps,
1271
+ guidance_scale=guidance_scale,
1272
+ num_images_per_prompt=num_images_per_prompt,
1273
+ eta=eta,
1274
+ seed=seed,
1275
+ max_embeddings_multiples=max_embeddings_multiples,
1276
+ no_boseos_middle=no_boseos_middle,
1277
+ skip_parsing=skip_parsing,
1278
+ skip_weighting=skip_weighting,
1279
+ epoch_time=time.time(),
1280
+ )
1281
+ paddle.seed(seed)
1282
+
1283
+ # 1. Check inputs
1284
+ self.check_inputs_img2img_inpaint(prompt, strength, callback_steps)
1285
+
1286
+ # 2. Define call parameters
1287
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
1288
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1289
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1290
+ # corresponds to doing no classifier free guidance.
1291
+ do_classifier_free_guidance = guidance_scale > 1.0
1292
+
1293
+ # 3. Encode input prompt
1294
+ text_embeddings = self._encode_prompt(
1295
+ prompt,
1296
+ negative_prompt,
1297
+ max_embeddings_multiples,
1298
+ no_boseos_middle,
1299
+ skip_parsing,
1300
+ skip_weighting,
1301
+ do_classifier_free_guidance,
1302
+ num_images_per_prompt,
1303
+ )
1304
+
1305
+ if not isinstance(image, paddle.Tensor):
1306
+ image = image.resize((width, height))
1307
+ image = preprocess_image(image)
1308
+
1309
+ if not isinstance(mask_image, paddle.Tensor):
1310
+ mask_image = mask_image.resize((width, height))
1311
+ mask_image = preprocess_mask(mask_image)
1312
+
1313
+ # 5. set timesteps
1314
+ scheduler.set_timesteps(num_inference_steps)
1315
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
1316
+ latent_timestep = timesteps[:1].tile([batch_size * num_images_per_prompt])
1317
+
1318
+ # 6. Prepare latent variables
1319
+ # encode the init image into latents and scale the latents
1320
+ latents, init_latents_orig, noise = self.prepare_latents_inpaint(
1321
+ image, latent_timestep, num_images_per_prompt, text_embeddings.dtype, scheduler
1322
+ )
1323
+
1324
+ # 7. Prepare mask latent
1325
+ mask = mask_image.cast(latents.dtype)
1326
+ mask = paddle.concat([mask] * batch_size * num_images_per_prompt)
1327
+
1328
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1329
+ extra_step_kwargs = self.prepare_extra_step_kwargs(eta, scheduler)
1330
+
1331
+ # 9. Denoising loop
1332
+ num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order
1333
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1334
+ for i, t in enumerate(timesteps):
1335
+ # expand the latents if we are doing classifier free guidance
1336
+ latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
1337
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
1338
+
1339
+ # predict the noise residual
1340
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
1341
+
1342
+ # perform guidance
1343
+ if do_classifier_free_guidance:
1344
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1345
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1346
+
1347
+ # compute the previous noisy sample x_t -> x_t-1
1348
+ latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1349
+ # masking
1350
+ init_latents_proper = scheduler.add_noise(init_latents_orig, noise, t)
1351
+
1352
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
1353
+
1354
+ # call the callback, if provided
1355
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
1356
+ progress_bar.update()
1357
+ if callback is not None and i % callback_steps == 0:
1358
+ callback(progress_bar.n, progress_bar.total, progress_bar)
1359
+
1360
+ # 10. Post-processing
1361
+ image = self.decode_latents(latents)
1362
+
1363
+ # 11. Run safety checker
1364
+ image, has_nsfw_concept = self.run_safety_checker(image, text_embeddings.dtype)
1365
+
1366
+ # 12. Convert to PIL
1367
+ if output_type == "pil":
1368
+ image = self.numpy_to_pil(image, argument=argument)
1369
+
1370
+ if not return_dict:
1371
+ return (image, has_nsfw_concept)
1372
+
1373
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
1374
+
1375
+ @staticmethod
1376
+ def numpy_to_pil(images, **kwargs):
1377
+ """
1378
+ Convert a numpy image or a batch of images to a PIL image.
1379
+ """
1380
+ if images.ndim == 3:
1381
+ images = images[None, ...]
1382
+ images = (images * 255).round().astype("uint8")
1383
+ pil_images = []
1384
+ argument = kwargs.pop("argument", None)
1385
+ for image in images:
1386
+ image = PIL.Image.fromarray(image)
1387
+ if argument is not None:
1388
+ image.argument = argument
1389
+ pil_images.append(image)
1390
+
1391
+ return pil_images
1392
+ pipeline = StableDiffusionPipelineAllinOne.from_pretrained(BASE_MODEL_NAME, safety_checker=None)
1393
+
1394
+ if LORA_WEIGHTS_PATH is not None:
1395
+ pipeline.unet.load_attn_procs(LORA_WEIGHTS_PATH, from_hf_hub=True)
1396
+
1397
+ support_scheduler = [
1398
+ "DPMSolver",
1399
+ "EulerDiscrete",
1400
+ "EulerAncestralDiscrete",
1401
+ "PNDM",
1402
+ "DDIM",
1403
+ "LMSDiscrete",
1404
+ "HeunDiscrete",
1405
+ "KDPM2AncestralDiscrete",
1406
+ "KDPM2Discrete"
1407
+ ]
1408
+
1409
+ # generate images
1410
+ def infer(prompt, negative, scale, height, width, num_inference_steps, scheduler_name):
1411
+ scheduler = pipeline.create_scheduler(scheduler_name)
1412
+
1413
+ images = pipeline(
1414
+ prompt=prompt, negative_prompt=negative, guidance_scale=scale, height=height, width=width, num_inference_steps=num_inference_steps, scheduler=scheduler,
1415
+ ).images
1416
+ return images
1417
+
1418
+
1419
+ css = """
1420
+ .gradio-container {
1421
+ font-family: 'IBM Plex Sans', sans-serif;
1422
+ }
1423
+ .gr-button {
1424
+ color: white;
1425
+ border-color: black;
1426
+ background: black;
1427
+ }
1428
+ input[type='range'] {
1429
+ accent-color: black;
1430
+ }
1431
+ .dark input[type='range'] {
1432
+ accent-color: #dfdfdf;
1433
+ }
1434
+ .container {
1435
+ max-width: 730px;
1436
+ margin: auto;
1437
+ padding-top: 1.5rem;
1438
+ }
1439
+ #gallery {
1440
+ min-height: 22rem;
1441
+ margin-bottom: 15px;
1442
+ margin-left: auto;
1443
+ margin-right: auto;
1444
+ border-bottom-right-radius: .5rem !important;
1445
+ border-bottom-left-radius: .5rem !important;
1446
+ }
1447
+ #gallery>div>.h-full {
1448
+ min-height: 20rem;
1449
+ }
1450
+ .details:hover {
1451
+ text-decoration: underline;
1452
+ }
1453
+ .gr-button {
1454
+ white-space: nowrap;
1455
+ }
1456
+ .gr-button:focus {
1457
+ border-color: rgb(147 197 253 / var(--tw-border-opacity));
1458
+ outline: none;
1459
+ box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
1460
+ --tw-border-opacity: 1;
1461
+ --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
1462
+ --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
1463
+ --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
1464
+ --tw-ring-opacity: .5;
1465
+ }
1466
+ #advanced-btn {
1467
+ font-size: .7rem !important;
1468
+ line-height: 19px;
1469
+ margin-top: 12px;
1470
+ margin-bottom: 12px;
1471
+ padding: 2px 8px;
1472
+ border-radius: 14px !important;
1473
+ }
1474
+ #advanced-options {
1475
+ display: none;
1476
+ margin-bottom: 20px;
1477
+ }
1478
+ .footer {
1479
+ margin-bottom: 45px;
1480
+ margin-top: 35px;
1481
+ text-align: center;
1482
+ border-bottom: 1px solid #e5e5e5;
1483
+ }
1484
+ .footer>p {
1485
+ font-size: .8rem;
1486
+ display: inline-block;
1487
+ padding: 0 10px;
1488
+ transform: translateY(10px);
1489
+ background: white;
1490
+ }
1491
+ .dark .footer {
1492
+ border-color: #303030;
1493
+ }
1494
+ .dark .footer>p {
1495
+ background: #0b0f19;
1496
+ }
1497
+ .acknowledgments h4{
1498
+ margin: 1.25em 0 .25em 0;
1499
+ font-weight: bold;
1500
+ font-size: 115%;
1501
+ }
1502
+ .animate-spin {
1503
+ animation: spin 1s linear infinite;
1504
+ }
1505
+ @keyframes spin {
1506
+ from {
1507
+ transform: rotate(0deg);
1508
+ }
1509
+ to {
1510
+ transform: rotate(360deg);
1511
+ }
1512
+ }
1513
+ #share-btn-container {
1514
+ display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
1515
+ margin-top: 10px;
1516
+ margin-left: auto;
1517
+ }
1518
+ #share-btn {
1519
+ all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
1520
+ }
1521
+ #share-btn * {
1522
+ all: unset;
1523
+ }
1524
+ #share-btn-container div:nth-child(-n+2){
1525
+ width: auto !important;
1526
+ min-height: 0px !important;
1527
+ }
1528
+ #share-btn-container .wrap {
1529
+ display: none !important;
1530
+ }
1531
+
1532
+ .gr-form{
1533
+ flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
1534
+ }
1535
+ #prompt-container{
1536
+ gap: 0;
1537
+ }
1538
+ #prompt-text-input, #negative-prompt-text-input{padding: .45rem 0.625rem}
1539
+ #component-16{border-top-width: 1px!important;margin-top: 1em}
1540
+ .image_duplication{position: absolute; width: 100px; left: 50px}
1541
+ """
1542
+
1543
+ block = gr.Blocks(css=css)
1544
+
1545
+ with block:
1546
+ gr.HTML(
1547
+ """
1548
+ <div style="text-align: center; margin: 0 auto;">
1549
+ <div
1550
+ style="
1551
+ display: inline-flex;
1552
+ align-items: center;
1553
+ gap: 0.8rem;
1554
+ font-size: 1.75rem;
1555
+ "
1556
+ >
1557
+ <svg
1558
+ width="0.65em"
1559
+ height="0.65em"
1560
+ viewBox="0 0 115 115"
1561
+ fill="none"
1562
+ xmlns="http://www.w3.org/2000/svg"
1563
+ >
1564
+ <rect width="23" height="23" fill="white"></rect>
1565
+ <rect y="69" width="23" height="23" fill="white"></rect>
1566
+ <rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
1567
+ <rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
1568
+ <rect x="46" width="23" height="23" fill="white"></rect>
1569
+ <rect x="46" y="69" width="23" height="23" fill="white"></rect>
1570
+ <rect x="69" width="23" height="23" fill="black"></rect>
1571
+ <rect x="69" y="69" width="23" height="23" fill="black"></rect>
1572
+ <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
1573
+ <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
1574
+ <rect x="115" y="46" width="23" height="23" fill="white"></rect>
1575
+ <rect x="115" y="115" width="23" height="23" fill="white"></rect>
1576
+ <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
1577
+ <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
1578
+ <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
1579
+ <rect x="92" y="69" width="23" height="23" fill="white"></rect>
1580
+ <rect x="69" y="46" width="23" height="23" fill="white"></rect>
1581
+ <rect x="69" y="115" width="23" height="23" fill="white"></rect>
1582
+ <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
1583
+ <rect x="46" y="46" width="23" height="23" fill="black"></rect>
1584
+ <rect x="46" y="115" width="23" height="23" fill="black"></rect>
1585
+ <rect x="46" y="69" width="23" height="23" fill="black"></rect>
1586
+ <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
1587
+ <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
1588
+ <rect x="23" y="69" width="23" height="23" fill="black"></rect>
1589
+ </svg>
1590
+ <h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
1591
+ Dreambooth LoRa Demo
1592
+ </h1>
1593
+ </div>
1594
+ </div>
1595
+ """
1596
+ )
1597
+ with gr.Group():
1598
+ with gr.Box():
1599
+ with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
1600
+ with gr.Column():
1601
+ text = gr.Textbox(
1602
+ label="Enter your prompt",
1603
+ value=PROMPTS,
1604
+ show_label=False,
1605
+ max_lines=1,
1606
+ placeholder="Enter your prompt",
1607
+ elem_id="prompt-text-input",
1608
+ ).style(
1609
+ border=(True, False, True, True),
1610
+ rounded=(True, False, False, True),
1611
+ container=False,
1612
+ )
1613
+ negative = gr.Textbox(
1614
+ label="Enter your negative prompt",
1615
+ show_label=False,
1616
+ max_lines=1,
1617
+ placeholder="Enter a negative prompt",
1618
+ elem_id="negative-prompt-text-input",
1619
+ ).style(
1620
+ border=(True, False, True, True),
1621
+ rounded=(True, False, False, True),
1622
+ container=False,
1623
+ )
1624
+ btn = gr.Button("Generate image").style(
1625
+ margin=False,
1626
+ rounded=(False, True, True, False),
1627
+ full_width=False,
1628
+ )
1629
+
1630
+ gallery = gr.Gallery(
1631
+ label="Generated images", show_label=False, elem_id="gallery"
1632
+ ).style(grid=[1], height="auto")
1633
+
1634
+
1635
+ with gr.Accordion("Advanced settings", open=False):
1636
+ scheduler_name = gr.Dropdown(
1637
+ label="scheduler_name", choices=support_scheduler, value="DPMSolver"
1638
+ )
1639
+ guidance_scale = gr.Slider(
1640
+ label="Guidance Scale", minimum=1, maximum=30, value=7.5, step=0.1
1641
+ )
1642
+ height = gr.Slider(
1643
+ label="Height", minimum=256, maximum=1024, value=512, step=8
1644
+ )
1645
+ width = gr.Slider(
1646
+ label="Width", minimum=256, maximum=1024, value=512, step=0.1
1647
+ )
1648
+ num_inference_steps = gr.Slider(
1649
+ label="num_inference_steps", minimum=10, maximum=100, value=25, step=1
1650
+ )
1651
+
1652
+
1653
+ inputs = [text, negative, guidance_scale, height, width, num_inference_steps, scheduler_name]
1654
+ # ex = gr.Examples(examples=examples, fn=infer, inputs=inputs, outputs=gallery, cache_examples=False)
1655
+ # ex.dataset.headers = [""]
1656
+ negative.submit(infer, inputs=inputs, outputs=gallery)
1657
+ text.submit(infer, inputs=inputs, outputs=gallery)
1658
+ btn.click(infer, inputs=inputs, outputs=gallery)
1659
+
1660
+
1661
+ gr.HTML(
1662
+ """
1663
+ <div class="footer">
1664
+ <p>Model by <a href="https://www.paddlepaddle.org.cn/" style="text-decoration: underline;" target="_blank">PaddlePaddle</a> - Gradio Demo by 🤗 Hugging Face
1665
+ </p>
1666
+ </div>
1667
+ <div class="acknowledgments">
1668
+ <p><h4>LICENSE</h4>
1669
+ The model is licensed with a <a href="https://huggingface.co/stabilityai/stable-diffusion-2/blob/main/LICENSE-MODEL" style="text-decoration: underline;" target="_blank">CreativeML OpenRAIL++</a> license. The authors claim no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm, spread misinformation and target vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" target="_blank" style="text-decoration: underline;" target="_blank">read the license</a></p>
1670
+ <p><h4>Biases and content acknowledgment</h4>
1671
+ Despite how impressive being able to turn text into image is, beware to the fact that this model may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography and violence. The model was trained on the <a href="https://laion.ai/blog/laion-5b/" style="text-decoration: underline;" target="_blank">LAION-5B dataset</a>, which scraped non-curated image-text-pairs from the internet (the exception being the removal of illegal content) and is meant for research purposes. You can read more in the <a href="https://huggingface.co/CompVis/stable-diffusion-v1-4" style="text-decoration: underline;" target="_blank">model card</a></p>
1672
+ </div>
1673
+ """
1674
+ )
1675
+
1676
+ block.launch(server_name="0.0.0.0", server_port=8221)
1677
+
env.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ############################################################################################################################
2
+ # 修改下面的参数
3
+ # (1)BASE_MODEL_NAME 代表你训练的基础模型
4
+ BASE_MODEL_NAME = "runwayml/stable-diffusion-v1-5"
5
+
6
+ # 是否开启lora
7
+ # (2)LORA_WEIGHTS_PATH 代码你上传到huggingface后的lora权重。
8
+ # LORA_WEIGHTS_PATH = None 表示不适应lora
9
+ LORA_WEIGHTS_PATH = "xianbao/demo_test"
10
+
11
+ # (3)PROMPTS 需要展示的prompt文本
12
+ PROMPTS = "A photo of sks dog in a bucket"
13
+ ############################################################################################################################
ppdiffusers/__init__.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # flake8: noqa
16
+
17
+ from .configuration_utils import ConfigMixin
18
+ from .fastdeploy_utils import FastDeployRuntimeModel
19
+ from .ppnlp_patch_utils import *
20
+ from .utils import (
21
+ OptionalDependencyNotAvailable,
22
+ is_fastdeploy_available,
23
+ is_inflect_available,
24
+ is_k_diffusion_available,
25
+ is_librosa_available,
26
+ is_onnx_available,
27
+ is_paddle_available,
28
+ is_paddlenlp_available,
29
+ is_scipy_available,
30
+ is_unidecode_available,
31
+ logging,
32
+ )
33
+ from .version import VERSION as __version__
34
+
35
+ try:
36
+ if not is_paddle_available():
37
+ raise OptionalDependencyNotAvailable()
38
+ except OptionalDependencyNotAvailable:
39
+ from .utils.dummy_paddle_objects import * # noqa F403
40
+ else:
41
+ from .initializer import *
42
+ from .modeling_utils import ModelMixin
43
+ from .models import (
44
+ AutoencoderKL,
45
+ PriorTransformer,
46
+ Transformer2DModel,
47
+ UNet1DModel,
48
+ UNet2DConditionModel,
49
+ UNet2DModel,
50
+ VQModel,
51
+ )
52
+ from .optimization import (
53
+ get_constant_schedule,
54
+ get_constant_schedule_with_warmup,
55
+ get_cosine_schedule_with_warmup,
56
+ get_cosine_with_hard_restarts_schedule_with_warmup,
57
+ get_linear_schedule_with_warmup,
58
+ get_polynomial_decay_schedule_with_warmup,
59
+ get_scheduler,
60
+ )
61
+ from .pipeline_utils import DiffusionPipeline
62
+ from .pipelines import (
63
+ DanceDiffusionPipeline,
64
+ DDIMPipeline,
65
+ DDPMPipeline,
66
+ KarrasVePipeline,
67
+ LDMPipeline,
68
+ LDMSuperResolutionPipeline,
69
+ PNDMPipeline,
70
+ RePaintPipeline,
71
+ ScoreSdeVePipeline,
72
+ )
73
+ from .schedulers import (
74
+ DDIMScheduler,
75
+ DDPMScheduler,
76
+ DPMSolverMultistepScheduler,
77
+ DPMSolverSinglestepScheduler,
78
+ EulerAncestralDiscreteScheduler,
79
+ EulerDiscreteScheduler,
80
+ HeunDiscreteScheduler,
81
+ IPNDMScheduler,
82
+ KarrasVeScheduler,
83
+ KDPM2AncestralDiscreteScheduler,
84
+ KDPM2DiscreteScheduler,
85
+ PNDMScheduler,
86
+ RePaintScheduler,
87
+ SchedulerMixin,
88
+ ScoreSdeVeScheduler,
89
+ UnCLIPScheduler,
90
+ VQDiffusionScheduler,
91
+ )
92
+ from .schedulers.preconfig import PreconfigEulerAncestralDiscreteScheduler
93
+ from .training_utils import EMAModel
94
+
95
+ try:
96
+ if not (is_paddle_available() and is_scipy_available()):
97
+ raise OptionalDependencyNotAvailable()
98
+ except OptionalDependencyNotAvailable:
99
+ from .utils.dummy_paddle_and_scipy_objects import * # noqa F403
100
+ else:
101
+ from .schedulers import LMSDiscreteScheduler
102
+ from .schedulers.preconfig import PreconfigLMSDiscreteScheduler
103
+
104
+ try:
105
+ if not (is_paddle_available() and is_paddlenlp_available()):
106
+ raise OptionalDependencyNotAvailable()
107
+ except OptionalDependencyNotAvailable:
108
+ from .utils.dummy_paddle_and_paddlenlp_objects import * # noqa F403
109
+ else:
110
+ from .pipelines import (
111
+ AltDiffusionImg2ImgPipeline,
112
+ AltDiffusionPipeline,
113
+ CycleDiffusionPipeline,
114
+ LDMBertModel,
115
+ LDMTextToImagePipeline,
116
+ PaintByExamplePipeline,
117
+ StableDiffusionDepth2ImgPipeline,
118
+ StableDiffusionImageVariationPipeline,
119
+ StableDiffusionImg2ImgPipeline,
120
+ StableDiffusionInpaintPipeline,
121
+ StableDiffusionInpaintPipelineLegacy,
122
+ StableDiffusionMegaPipeline,
123
+ StableDiffusionPipeline,
124
+ StableDiffusionPipelineAllinOne,
125
+ StableDiffusionPipelineSafe,
126
+ StableDiffusionUpscalePipeline,
127
+ UnCLIPPipeline,
128
+ VersatileDiffusionDualGuidedPipeline,
129
+ VersatileDiffusionImageVariationPipeline,
130
+ VersatileDiffusionPipeline,
131
+ VersatileDiffusionTextToImagePipeline,
132
+ VQDiffusionPipeline,
133
+ )
134
+
135
+ try:
136
+ if not (is_paddle_available() and is_paddlenlp_available() and is_k_diffusion_available()):
137
+ raise OptionalDependencyNotAvailable()
138
+ except OptionalDependencyNotAvailable:
139
+ from .utils.dummy_paddle_and_paddlenlp_and_k_diffusion_objects import * # noqa F403
140
+ else:
141
+ from .pipelines import StableDiffusionKDiffusionPipeline
142
+
143
+ try:
144
+ if not (is_paddle_available() and is_paddlenlp_available() and is_fastdeploy_available()):
145
+ raise OptionalDependencyNotAvailable()
146
+ except OptionalDependencyNotAvailable:
147
+ from .utils.dummy_paddle_and_paddlenlp_and_fastdeploy_objects import * # noqa F403
148
+ else:
149
+ from .pipelines import (
150
+ FastDeployStableDiffusionImg2ImgPipeline,
151
+ FastDeployStableDiffusionInpaintPipeline,
152
+ FastDeployStableDiffusionInpaintPipelineLegacy,
153
+ FastDeployStableDiffusionMegaPipeline,
154
+ FastDeployStableDiffusionPipeline,
155
+ )
156
+ try:
157
+ if not (is_paddle_available() and is_librosa_available()):
158
+ raise OptionalDependencyNotAvailable()
159
+ except OptionalDependencyNotAvailable:
160
+ from .utils.dummy_paddle_and_librosa_objects import * # noqa F403
161
+ else:
162
+ from .pipelines import AudioDiffusionPipeline, Mel
ppdiffusers/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (4.28 kB). View file
 
ppdiffusers/__pycache__/configuration_utils.cpython-37.pyc ADDED
Binary file (20.7 kB). View file
 
ppdiffusers/__pycache__/download_utils.cpython-37.pyc ADDED
Binary file (818 Bytes). View file
 
ppdiffusers/__pycache__/fastdeploy_utils.cpython-37.pyc ADDED
Binary file (8.18 kB). View file
 
ppdiffusers/__pycache__/initializer.cpython-37.pyc ADDED
Binary file (8.69 kB). View file
 
ppdiffusers/__pycache__/loaders.cpython-37.pyc ADDED
Binary file (7.47 kB). View file
 
ppdiffusers/__pycache__/modeling_utils.cpython-37.pyc ADDED
Binary file (19.8 kB). View file
 
ppdiffusers/__pycache__/optimization.cpython-37.pyc ADDED
Binary file (10.7 kB). View file
 
ppdiffusers/__pycache__/pipeline_utils.cpython-37.pyc ADDED
Binary file (22.3 kB). View file
 
ppdiffusers/__pycache__/ppnlp_patch_utils.cpython-37.pyc ADDED
Binary file (15.6 kB). View file
 
ppdiffusers/__pycache__/training_utils.cpython-37.pyc ADDED
Binary file (4.01 kB). View file
 
ppdiffusers/__pycache__/version.cpython-37.pyc ADDED
Binary file (141 Bytes). View file
 
ppdiffusers/commands/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from abc import ABC, abstractmethod
17
+ from argparse import ArgumentParser
18
+
19
+
20
+ class BasePPDiffusersCLICommand(ABC):
21
+ @staticmethod
22
+ @abstractmethod
23
+ def register_subcommand(parser: ArgumentParser):
24
+ raise NotImplementedError()
25
+
26
+ @abstractmethod
27
+ def run(self):
28
+ raise NotImplementedError()
ppdiffusers/commands/env.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import platform
17
+ from argparse import ArgumentParser
18
+
19
+ from .. import __version__ as version
20
+ from ..utils import is_paddle_available, is_paddlenlp_available
21
+ from . import BasePPDiffusersCLICommand
22
+
23
+
24
+ def info_command_factory(_):
25
+ return EnvironmentCommand()
26
+
27
+
28
+ class EnvironmentCommand(BasePPDiffusersCLICommand):
29
+ @staticmethod
30
+ def register_subcommand(parser: ArgumentParser):
31
+ download_parser = parser.add_parser("env")
32
+ download_parser.set_defaults(func=info_command_factory)
33
+
34
+ def run(self):
35
+
36
+ pd_version = "not installed"
37
+ pd_cuda_available = "NA"
38
+ if is_paddle_available():
39
+ import paddle
40
+
41
+ pd_version = paddle.__version__
42
+ pd_cuda_available = paddle.device.is_compiled_with_cuda()
43
+
44
+ paddlenlp_version = "not installed"
45
+ if is_paddlenlp_available:
46
+ import paddlenlp
47
+
48
+ paddlenlp_version = paddlenlp.__version__
49
+
50
+ info = {
51
+ "`ppdiffusers` version": version,
52
+ "Platform": platform.platform(),
53
+ "Python version": platform.python_version(),
54
+ "Paddle version (GPU?)": f"{pd_version} ({pd_cuda_available})",
55
+ "PaddleNLP version": paddlenlp_version,
56
+ "Using GPU in script?": "<fill in>",
57
+ "Using distributed or parallel set-up in script?": "<fill in>",
58
+ }
59
+
60
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
61
+ print(self.format_dict(info))
62
+
63
+ return info
64
+
65
+ @staticmethod
66
+ def format_dict(d):
67
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
ppdiffusers/commands/ppdiffusers_cli.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .env import EnvironmentCommand
19
+
20
+
21
+ def main():
22
+ parser = ArgumentParser("PPDiffusers CLI tool", usage="ppdiffusers-cli <command> [<args>]")
23
+ commands_parser = parser.add_subparsers(help="ppdiffusers-cli command helpers")
24
+
25
+ # Register commands
26
+ EnvironmentCommand.register_subcommand(commands_parser)
27
+
28
+ # Let's go
29
+ args = parser.parse_args()
30
+
31
+ if not hasattr(args, "func"):
32
+ parser.print_help()
33
+ exit(1)
34
+
35
+ # Run
36
+ service = args.func(args)
37
+ service.run()
38
+
39
+
40
+ if __name__ == "__main__":
41
+ main()
ppdiffusers/configuration_utils.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ ConfigMixin base class and utilities."""
17
+ import functools
18
+ import importlib
19
+ import inspect
20
+ import json
21
+ import os
22
+ import re
23
+ import tempfile
24
+ from collections import OrderedDict
25
+ from typing import Any, Dict, Optional, Tuple, Union
26
+
27
+ import numpy as np
28
+ from huggingface_hub import (
29
+ create_repo,
30
+ get_hf_file_metadata,
31
+ hf_hub_download,
32
+ hf_hub_url,
33
+ repo_type_and_id_from_hf_id,
34
+ upload_folder,
35
+ )
36
+ from huggingface_hub.utils import EntryNotFoundError
37
+ from requests import HTTPError
38
+
39
+ from .download_utils import ppdiffusers_bos_download
40
+ from .utils import (
41
+ DOWNLOAD_SERVER,
42
+ HF_CACHE,
43
+ PPDIFFUSERS_CACHE,
44
+ DummyObject,
45
+ deprecate,
46
+ logging,
47
+ )
48
+ from .version import VERSION as __version__
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
53
+
54
+
55
+ class FrozenDict(OrderedDict):
56
+ def __init__(self, *args, **kwargs):
57
+ super().__init__(*args, **kwargs)
58
+
59
+ for key, value in self.items():
60
+ setattr(self, key, value)
61
+
62
+ self.__frozen = True
63
+
64
+ def __delitem__(self, *args, **kwargs):
65
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
66
+
67
+ def setdefault(self, *args, **kwargs):
68
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
69
+
70
+ def pop(self, *args, **kwargs):
71
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
72
+
73
+ def update(self, *args, **kwargs):
74
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
75
+
76
+ def __setattr__(self, name, value):
77
+ if hasattr(self, "__frozen") and self.__frozen:
78
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
79
+ super().__setattr__(name, value)
80
+
81
+ def __setitem__(self, name, value):
82
+ if hasattr(self, "__frozen") and self.__frozen:
83
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
84
+ super().__setitem__(name, value)
85
+
86
+
87
+ class ConfigMixin:
88
+ r"""
89
+ Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
90
+ methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
91
+ - [`~ConfigMixin.from_config`]
92
+ - [`~ConfigMixin.save_config`]
93
+
94
+ Class attributes:
95
+ - **config_name** (`str`) -- A filename under which the config should stored when calling
96
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
97
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
98
+ overridden by subclass).
99
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
100
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
101
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
102
+ subclass).
103
+ """
104
+ config_name = None
105
+ ignore_for_config = []
106
+ has_compatibles = False
107
+ _deprecated_kwargs = []
108
+
109
+ def register_to_config(self, **kwargs):
110
+ if self.config_name is None:
111
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
112
+
113
+ # Special case for `kwargs` used in deprecation warning added to schedulers
114
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
115
+ # or solve in a more general way.
116
+ kwargs.pop("kwargs", None)
117
+ for key, value in kwargs.items():
118
+ try:
119
+ setattr(self, key, value)
120
+ except AttributeError as err:
121
+ logger.error(f"Can't set {key} with value {value} for {self}")
122
+ raise err
123
+
124
+ if not hasattr(self, "_internal_dict"):
125
+ internal_dict = kwargs
126
+ else:
127
+ previous_dict = dict(self._internal_dict)
128
+ internal_dict = {**self._internal_dict, **kwargs}
129
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
130
+
131
+ self._internal_dict = FrozenDict(internal_dict)
132
+
133
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
134
+ """
135
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
136
+ [`~ConfigMixin.from_config`] class method.
137
+
138
+ Args:
139
+ save_directory (`str` or `os.PathLike`):
140
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
141
+ """
142
+ if os.path.isfile(save_directory):
143
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
144
+
145
+ os.makedirs(save_directory, exist_ok=True)
146
+
147
+ # If we save using the predefined names, we can load using `from_config`
148
+ output_config_file = os.path.join(save_directory, self.config_name)
149
+
150
+ self.to_json_file(output_config_file)
151
+ logger.info(f"Configuration saved in {output_config_file}")
152
+
153
+ def save_to_hf_hub(
154
+ self,
155
+ repo_id: str,
156
+ private: Optional[bool] = None,
157
+ subfolder: Optional[str] = None,
158
+ commit_message: Optional[str] = None,
159
+ revision: Optional[str] = None,
160
+ create_pr: bool = False,
161
+ ):
162
+ """
163
+ Uploads all elements of this config to a new HuggingFace Hub repository.
164
+ Args:
165
+ repo_id (str): Repository name for your model/tokenizer in the Hub.
166
+ private (bool, optional): Whether the model/tokenizer is set to private
167
+ subfolder (str, optional): Push to a subfolder of the repo instead of the root
168
+ commit_message (str, optional): The summary / title / first line of the generated commit. Defaults to: f"Upload {path_in_repo} with huggingface_hub"
169
+ revision (str, optional): The git revision to commit from. Defaults to the head of the "main" branch.
170
+ create_pr (boolean, optional): Whether or not to create a Pull Request with that commit. Defaults to False.
171
+ If revision is not set, PR is opened against the "main" branch. If revision is set and is a branch, PR is opened against this branch.
172
+ If revision is set and is not a branch name (example: a commit oid), an RevisionNotFoundError is returned by the server.
173
+
174
+ Returns: The url of the commit of your model in the given repository.
175
+ """
176
+ repo_url = create_repo(repo_id, private=private, exist_ok=True)
177
+
178
+ # Infer complete repo_id from repo_url
179
+ # Can be different from the input `repo_id` if repo_owner was implicit
180
+ _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
181
+
182
+ repo_id = f"{repo_owner}/{repo_name}"
183
+
184
+ # Check if README file already exist in repo
185
+ try:
186
+ get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
187
+ has_readme = True
188
+ except EntryNotFoundError:
189
+ has_readme = False
190
+
191
+ with tempfile.TemporaryDirectory() as root_dir:
192
+ if subfolder is not None:
193
+ save_dir = os.path.join(root_dir, subfolder)
194
+ else:
195
+ save_dir = root_dir
196
+ # save config
197
+ self.save_config(save_dir)
198
+ # Add readme if does not exist
199
+ logger.info("README.md not found, adding the default README.md")
200
+ if not has_readme:
201
+ with open(os.path.join(root_dir, "README.md"), "w") as f:
202
+ f.write(f"---\nlibrary_name: ppdiffusers\n---\n# {repo_id}")
203
+
204
+ # Upload model and return
205
+ logger.info(f"Pushing to the {repo_id}. This might take a while")
206
+ return upload_folder(
207
+ repo_id=repo_id,
208
+ repo_type="model",
209
+ folder_path=root_dir,
210
+ commit_message=commit_message,
211
+ revision=revision,
212
+ create_pr=create_pr,
213
+ )
214
+
215
+ @classmethod
216
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
217
+ r"""
218
+ Instantiate a Python class from a config dictionary
219
+
220
+ Parameters:
221
+ config (`Dict[str, Any]`):
222
+ A config dictionary from which the Python class will be instantiated. Make sure to only load
223
+ configuration files of compatible classes.
224
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
225
+ Whether kwargs that are not consumed by the Python class should be returned or not.
226
+
227
+ kwargs (remaining dictionary of keyword arguments, *optional*):
228
+ Can be used to update the configuration object (after it being loaded) and initiate the Python class.
229
+ `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
230
+ overwrite same named arguments of `config`.
231
+
232
+ Examples:
233
+
234
+ ```python
235
+ >>> from ppdiffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
236
+
237
+ >>> # Download scheduler from BOS and cache.
238
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
239
+
240
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
241
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
242
+
243
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
244
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
245
+ ```
246
+ """
247
+ # <===== TO BE REMOVED WITH DEPRECATION
248
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
249
+ if "pretrained_model_name_or_path" in kwargs:
250
+ config = kwargs.pop("pretrained_model_name_or_path")
251
+
252
+ if config is None:
253
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
254
+ # ======>
255
+
256
+ if not isinstance(config, dict):
257
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
258
+ if "Scheduler" in cls.__name__:
259
+ deprecation_message += (
260
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
261
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
262
+ " be removed in v1.0.0."
263
+ )
264
+ elif "Model" in cls.__name__:
265
+ deprecation_message += (
266
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
267
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
268
+ " instead. This functionality will be removed in v1.0.0."
269
+ )
270
+ deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
271
+ config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
272
+
273
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
274
+
275
+ # Allow dtype to be specified on initialization
276
+ if "dtype" in unused_kwargs:
277
+ # (TODO junnyu, donot use dtype)
278
+ unused_kwargs.pop("dtype")
279
+ # init_dict["dtype"] = unused_kwargs.pop("dtype")
280
+
281
+ # add possible deprecated kwargs
282
+ for deprecated_kwarg in cls._deprecated_kwargs:
283
+ if deprecated_kwarg in unused_kwargs:
284
+ init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
285
+
286
+ # Return model and optionally state and/or unused_kwargs
287
+ model = cls(**init_dict)
288
+
289
+ # make sure to also save config parameters that might be used for compatible classes
290
+ model.register_to_config(**hidden_dict)
291
+
292
+ # add hidden kwargs of compatible classes to unused_kwargs
293
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
294
+
295
+ if return_unused_kwargs:
296
+ return (model, unused_kwargs)
297
+ else:
298
+ return model
299
+
300
+ @classmethod
301
+ def get_config_dict(cls, *args, **kwargs):
302
+ deprecation_message = (
303
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
304
+ " removed in version v1.0.0"
305
+ )
306
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
307
+ return cls.load_config(*args, **kwargs)
308
+
309
+ @classmethod
310
+ def load_config(
311
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
312
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
313
+ r"""
314
+ Instantiate a Python class from a config dictionary
315
+
316
+ Parameters:
317
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
318
+ Can be either:
319
+
320
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
321
+ organization name, like `google/ddpm-celebahq-256`.
322
+ - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
323
+ `./my_model_directory/`.
324
+
325
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
326
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
327
+ standard cache should not be used.
328
+ output_loading_info(`bool`, *optional*, defaults to `False`):
329
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
330
+ subfolder (`str`, *optional*, defaults to `""`):
331
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
332
+ huggingface.co or downloaded locally), you can specify the folder name here.
333
+ from_hf_hub (bool, *optional*):
334
+ Whether to load from Hugging Face Hub. Defaults to False
335
+ """
336
+ from_hf_hub = kwargs.pop("from_hf_hub", False)
337
+ if from_hf_hub:
338
+ cache_dir = kwargs.pop("cache_dir", HF_CACHE)
339
+ else:
340
+ cache_dir = kwargs.pop("cache_dir", PPDIFFUSERS_CACHE)
341
+ subfolder = kwargs.pop("subfolder", None)
342
+
343
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
344
+
345
+ if cls.config_name is None:
346
+ raise ValueError(
347
+ "`self.config_name` is not defined. Note that one should not load a config from "
348
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
349
+ )
350
+
351
+ if os.path.isfile(pretrained_model_name_or_path):
352
+ config_file = pretrained_model_name_or_path
353
+ elif os.path.isdir(pretrained_model_name_or_path):
354
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
355
+ # Load from a Paddle checkpoint
356
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
357
+ elif subfolder is not None and os.path.isfile(
358
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
359
+ ):
360
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
361
+ else:
362
+ raise EnvironmentError(
363
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
364
+ )
365
+ elif from_hf_hub:
366
+ config_file = hf_hub_download(
367
+ repo_id=pretrained_model_name_or_path,
368
+ filename=cls.config_name,
369
+ cache_dir=cache_dir,
370
+ subfolder=subfolder,
371
+ library_name="PPDiffusers",
372
+ library_version=__version__,
373
+ )
374
+ else:
375
+ try:
376
+ config_file = ppdiffusers_bos_download(
377
+ pretrained_model_name_or_path,
378
+ filename=cls.config_name,
379
+ subfolder=subfolder,
380
+ cache_dir=cache_dir,
381
+ )
382
+ except HTTPError as err:
383
+ raise EnvironmentError(
384
+ "There was a specific connection error when trying to load"
385
+ f" {pretrained_model_name_or_path}:\n{err}"
386
+ )
387
+ except ValueError:
388
+ raise EnvironmentError(
389
+ f"We couldn't connect to '{DOWNLOAD_SERVER}' to load this model, couldn't find it"
390
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
391
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
392
+ " run the library in offline mode at"
393
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
394
+ )
395
+ except EnvironmentError:
396
+ raise EnvironmentError(
397
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
398
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
399
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
400
+ f"containing a {cls.config_name} file"
401
+ )
402
+
403
+ try:
404
+ # Load config dict
405
+ config_dict = cls._dict_from_json_file(config_file)
406
+ except (json.JSONDecodeError, UnicodeDecodeError):
407
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
408
+
409
+ if return_unused_kwargs:
410
+ return config_dict, kwargs
411
+
412
+ return config_dict
413
+
414
+ @staticmethod
415
+ def _get_init_keys(cls):
416
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
417
+
418
+ @classmethod
419
+ def extract_init_dict(cls, config_dict, **kwargs):
420
+ # 0. Copy origin config dict
421
+ original_dict = {k: v for k, v in config_dict.items()}
422
+
423
+ # 1. Retrieve expected config attributes from __init__ signature
424
+ expected_keys = cls._get_init_keys(cls)
425
+ expected_keys.remove("self")
426
+ # remove general kwargs if present in dict
427
+ if "kwargs" in expected_keys:
428
+ expected_keys.remove("kwargs")
429
+
430
+ # 2. Remove attributes that cannot be expected from expected config attributes
431
+ # remove keys to be ignored
432
+ if len(cls.ignore_for_config) > 0:
433
+ expected_keys = expected_keys - set(cls.ignore_for_config)
434
+
435
+ # load ppdiffusers library to import compatible and original scheduler
436
+ ppdiffusers_library = importlib.import_module(__name__.split(".")[0])
437
+
438
+ if cls.has_compatibles:
439
+ compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
440
+ else:
441
+ compatible_classes = []
442
+
443
+ expected_keys_comp_cls = set()
444
+ for c in compatible_classes:
445
+ expected_keys_c = cls._get_init_keys(c)
446
+ expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
447
+ expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
448
+ config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
449
+
450
+ # remove attributes from orig class that cannot be expected
451
+ orig_cls_name = config_dict.pop("_class_name", cls.__name__)
452
+ if orig_cls_name != cls.__name__ and hasattr(ppdiffusers_library, orig_cls_name):
453
+ orig_cls = getattr(ppdiffusers_library, orig_cls_name)
454
+ unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
455
+ config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
456
+
457
+ # remove private attributes
458
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
459
+
460
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
461
+ init_dict = {}
462
+ for key in expected_keys:
463
+ # if config param is passed to kwarg and is present in config dict
464
+ # it should overwrite existing config dict key
465
+ if key in kwargs and key in config_dict:
466
+ config_dict[key] = kwargs.pop(key)
467
+
468
+ if key in kwargs:
469
+ # overwrite key
470
+ init_dict[key] = kwargs.pop(key)
471
+ elif key in config_dict:
472
+ # use value from config dict
473
+ init_dict[key] = config_dict.pop(key)
474
+
475
+ # 4. Give nice warning if unexpected values have been passed
476
+ if len(config_dict) > 0:
477
+ logger.warning(
478
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
479
+ "but are not expected and will be ignored. Please verify your "
480
+ f"{cls.config_name} configuration file."
481
+ )
482
+
483
+ # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
484
+ passed_keys = set(init_dict.keys())
485
+ if len(expected_keys - passed_keys) > 0:
486
+ logger.info(
487
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
488
+ )
489
+
490
+ # 6. Define unused keyword arguments
491
+ unused_kwargs = {**config_dict, **kwargs}
492
+
493
+ # 7. Define "hidden" config parameters that were saved for compatible classes
494
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
495
+
496
+ return init_dict, unused_kwargs, hidden_config_dict
497
+
498
+ @classmethod
499
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
500
+ with open(json_file, "r", encoding="utf-8") as reader:
501
+ text = reader.read()
502
+ return json.loads(text)
503
+
504
+ def __repr__(self):
505
+ return f"{self.__class__.__name__} {self.to_json_string()}"
506
+
507
+ @property
508
+ def config(self) -> Dict[str, Any]:
509
+ """
510
+ Returns the config of the class as a frozen dictionary
511
+
512
+ Returns:
513
+ `Dict[str, Any]`: Config of the class.
514
+ """
515
+ return self._internal_dict
516
+
517
+ def to_json_string(self) -> str:
518
+ """
519
+ Serializes this instance to a JSON string.
520
+
521
+ Returns:
522
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
523
+ """
524
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
525
+ config_dict["_class_name"] = self.__class__.__name__
526
+ config_dict["_ppdiffusers_version"] = __version__
527
+
528
+ def to_json_saveable(value):
529
+ if isinstance(value, np.ndarray):
530
+ value = value.tolist()
531
+ return value
532
+
533
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
534
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
535
+
536
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
537
+ """
538
+ Save this instance to a JSON file.
539
+
540
+ Args:
541
+ json_file_path (`str` or `os.PathLike`):
542
+ Path to the JSON file in which this configuration instance's parameters will be saved.
543
+ """
544
+ with open(json_file_path, "w", encoding="utf-8") as writer:
545
+ writer.write(self.to_json_string())
546
+
547
+
548
+ def register_to_config(init):
549
+ r"""
550
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
551
+ automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
552
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
553
+
554
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
555
+ """
556
+
557
+ @functools.wraps(init)
558
+ def inner_init(self, *args, **kwargs):
559
+ # Ignore private kwargs in the init.
560
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
561
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
562
+
563
+ if not isinstance(self, ConfigMixin):
564
+ raise RuntimeError(
565
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
566
+ "not inherit from `ConfigMixin`."
567
+ )
568
+
569
+ ignore = getattr(self, "ignore_for_config", [])
570
+ # Get positional arguments aligned with kwargs
571
+ new_kwargs = {}
572
+ signature = inspect.signature(init)
573
+ parameters = {
574
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
575
+ }
576
+ for arg, name in zip(args, parameters.keys()):
577
+ new_kwargs[name] = arg
578
+
579
+ # Then add all kwargs
580
+ new_kwargs.update(
581
+ {
582
+ k: init_kwargs.get(k, default)
583
+ for k, default in parameters.items()
584
+ if k not in ignore and k not in new_kwargs
585
+ }
586
+ )
587
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
588
+ getattr(self, "register_to_config")(**new_kwargs)
589
+ init(self, *args, **init_kwargs)
590
+
591
+ return inner_init
ppdiffusers/download_utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from paddlenlp.utils.downloader import get_path_from_url_with_filelock
19
+ from paddlenlp.utils.log import logger
20
+
21
+ from .utils import DOWNLOAD_SERVER, PPDIFFUSERS_CACHE
22
+
23
+
24
+ def ppdiffusers_bos_download(pretrained_model_name_or_path, filename=None, subfolder=None, cache_dir=None):
25
+ if cache_dir is None:
26
+ cache_dir = PPDIFFUSERS_CACHE
27
+ cache_dir = (
28
+ pretrained_model_name_or_path
29
+ if os.path.isdir(pretrained_model_name_or_path)
30
+ else os.path.join(cache_dir, pretrained_model_name_or_path)
31
+ )
32
+ url = DOWNLOAD_SERVER + "/" + pretrained_model_name_or_path
33
+ if subfolder is not None:
34
+ url = url + "/" + subfolder
35
+ cache_dir = os.path.join(cache_dir, subfolder)
36
+ if filename is not None:
37
+ url = url + "/" + filename
38
+
39
+ file_path = os.path.join(cache_dir, filename)
40
+ if os.path.exists(file_path):
41
+ logger.info("Already cached %s" % file_path)
42
+ else:
43
+ file_path = get_path_from_url_with_filelock(url, cache_dir)
44
+ return file_path
ppdiffusers/experimental/README.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # 🧨 PPDiffusers Experimental
2
+
3
+ 为了使得**PPDiffusers库**能够有更多的应用场景,我们在这里添加了一些**实验性的代码**。
4
+
5
+ 目前我们支持了以下场景:
6
+ * Reinforcement learning via an implementation of the [PPDiffuser](https://arxiv.org/abs/2205.09991) model.
ppdiffusers/experimental/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # flake8: noqa
16
+
17
+ from .rl import ValueGuidedRLPipeline
ppdiffusers/experimental/rl/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # flake8: noqa
16
+
17
+ from .value_guided_sampling import ValueGuidedRLPipeline
ppdiffusers/experimental/rl/value_guided_sampling.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import numpy as np
17
+ import paddle
18
+
19
+ from ...models.unet_1d import UNet1DModel
20
+ from ...pipeline_utils import DiffusionPipeline
21
+ from ...utils.dummy_paddle_objects import DDPMScheduler
22
+
23
+
24
+ class ValueGuidedRLPipeline(DiffusionPipeline):
25
+ r"""
26
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
27
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
28
+ Pipeline for sampling actions from a diffusion model trained to predict sequences of states.
29
+ Original implementation inspired by this repository: https://github.com/jannerm/diffuser.
30
+
31
+ Parameters:
32
+ value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories base on reward.
33
+ unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
34
+ scheduler ([`SchedulerMixin`]):
35
+ A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
36
+ application is [`DDPMScheduler`].
37
+ env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ value_function: UNet1DModel,
43
+ unet: UNet1DModel,
44
+ scheduler: DDPMScheduler,
45
+ env,
46
+ ):
47
+ super().__init__()
48
+ self.value_function = value_function
49
+ self.unet = unet
50
+ self.scheduler = scheduler
51
+ self.env = env
52
+ self.data = env.get_dataset()
53
+ self.means = dict()
54
+ for key in self.data.keys():
55
+ try:
56
+ self.means[key] = self.data[key].mean()
57
+ except Exception:
58
+ pass
59
+ self.stds = dict()
60
+ for key in self.data.keys():
61
+ try:
62
+ self.stds[key] = self.data[key].std()
63
+ except Exception:
64
+ pass
65
+ self.state_dim = env.observation_space.shape[0]
66
+ self.action_dim = env.action_space.shape[0]
67
+
68
+ def normalize(self, x_in, key):
69
+ return (x_in - self.means[key]) / self.stds[key]
70
+
71
+ def de_normalize(self, x_in, key):
72
+ return x_in * self.stds[key] + self.means[key]
73
+
74
+ def to_paddle(self, x_in):
75
+ if type(x_in) is dict:
76
+ return {k: self.to_paddle(v) for k, v in x_in.items()}
77
+ elif paddle.is_tensor(x_in):
78
+ return x_in
79
+ return paddle.to_tensor(x_in)
80
+
81
+ def reset_x0(self, x_in, cond, act_dim):
82
+ for key, val in cond.items():
83
+ x_in[:, key, act_dim:] = val.clone()
84
+ return x_in
85
+
86
+ def run_diffusion(self, x, conditions, n_guide_steps, scale):
87
+ batch_size = x.shape[0]
88
+ y = None
89
+ for i in self.progress_bar(self.scheduler.timesteps):
90
+ # create batch of timesteps to pass into model
91
+ timesteps = paddle.full((batch_size,), i, dtype="int64")
92
+ for _ in range(n_guide_steps):
93
+ with paddle.set_grad_enabled(True):
94
+ x.stop_gradient = False
95
+ # permute to match dimension for pre-trained models
96
+ y = self.value_function(x.transpose([0, 2, 1]), timesteps).sample
97
+ grad = paddle.autograd.grad([y.sum()], [x])[0]
98
+
99
+ posterior_variance = self.scheduler._get_variance(i)
100
+ model_std = paddle.exp(0.5 * posterior_variance)
101
+ grad = model_std * grad
102
+
103
+ grad[timesteps < 2] = 0
104
+ x = x.detach()
105
+ x = x + scale * grad
106
+ x = self.reset_x0(x, conditions, self.action_dim)
107
+ prev_x = self.unet(x.transpose([0, 2, 1]), timesteps).sample.transpose([0, 2, 1])
108
+ # TODO: verify deprecation of this kwarg
109
+ x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
110
+
111
+ # apply conditions to the trajectory (set the initial state)
112
+ x = self.reset_x0(x, conditions, self.action_dim)
113
+ x = self.to_paddle(x)
114
+ return x, y
115
+
116
+ def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
117
+ # normalize the observations and create batch dimension
118
+ obs = self.normalize(obs, "observations")
119
+ obs = obs[None].repeat(batch_size, axis=0)
120
+
121
+ conditions = {0: self.to_paddle(obs)}
122
+ shape = [batch_size, planning_horizon, self.state_dim + self.action_dim]
123
+
124
+ # generate initial noise and apply our conditions (to make the trajectories start at current state)
125
+ x1 = paddle.randn(shape)
126
+ x = self.reset_x0(x1, conditions, self.action_dim)
127
+ x = self.to_paddle(x)
128
+
129
+ # run the diffusion process
130
+ x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
131
+
132
+ # sort output trajectories by value
133
+ sorted_idx = paddle.argsort(y, 0, descending=True).squeeze()
134
+ sorted_values = x[sorted_idx]
135
+ actions = sorted_values[:, :, : self.action_dim]
136
+ actions = actions.detach().numpy()
137
+ denorm_actions = self.de_normalize(actions, key="actions")
138
+
139
+ # select the action with the highest value
140
+ if y is not None:
141
+ selected_index = 0
142
+ else:
143
+ # if we didn't run value guiding, select a random action
144
+ selected_index = np.random.randint(0, batch_size)
145
+ denorm_actions = denorm_actions[selected_index, 0]
146
+ return denorm_actions
ppdiffusers/fastdeploy_utils.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ import shutil
19
+ from pathlib import Path
20
+ from typing import Optional, Union
21
+
22
+ import numpy as np
23
+
24
+ from .download_utils import ppdiffusers_bos_download
25
+ from .utils import (
26
+ FASTDEPLOY_MODEL_NAME,
27
+ FASTDEPLOY_WEIGHTS_NAME,
28
+ is_fastdeploy_available,
29
+ is_paddle_available,
30
+ logging,
31
+ )
32
+
33
+ if is_paddle_available():
34
+ import paddle
35
+
36
+
37
+ if is_fastdeploy_available():
38
+ import fastdeploy as fd
39
+
40
+ def fdtensor2pdtensor(fdtensor: fd.C.FDTensor):
41
+ dltensor = fdtensor.to_dlpack()
42
+ pdtensor = paddle.utils.dlpack.from_dlpack(dltensor)
43
+ return pdtensor
44
+
45
+ def pdtensor2fdtensor(pdtensor: paddle.Tensor, name: str = "", share_with_raw_ptr=False):
46
+ if not share_with_raw_ptr:
47
+ dltensor = paddle.utils.dlpack.to_dlpack(pdtensor)
48
+ return fd.C.FDTensor.from_dlpack(name, dltensor)
49
+ else:
50
+ return fd.C.FDTensor.from_external_data(
51
+ name,
52
+ pdtensor.data_ptr(),
53
+ pdtensor.shape,
54
+ pdtensor.dtype.name,
55
+ str(pdtensor.place),
56
+ int(pdtensor.place.gpu_device_id()),
57
+ )
58
+
59
+
60
+ logger = logging.get_logger(__name__)
61
+
62
+
63
+ class FastDeployRuntimeModel:
64
+ def __init__(self, model=None, **kwargs):
65
+ logger.info("`ppdiffusers.FastDeployRuntimeModel` is experimental and might change in the future.")
66
+ self.model = model
67
+ self.model_save_dir = kwargs.get("model_save_dir", None)
68
+ self.latest_model_name = kwargs.get("latest_model_name", "inference.pdmodel")
69
+ self.latest_params_name = kwargs.get("latest_params_name", "inference.pdiparams")
70
+
71
+ def zero_copy_infer(self, prebinded_inputs: dict, prebinded_outputs: dict, share_with_raw_ptr=True, **kwargs):
72
+ """
73
+ Execute inference without copying data from cpu to gpu.
74
+
75
+ Arguments:
76
+ kwargs (`dict(name, paddle.Tensor)`):
77
+ An input map from name to tensor.
78
+ Return:
79
+ List of output tensor.
80
+ """
81
+ for inputs_name, inputs_tensor in prebinded_inputs.items():
82
+ input_fdtensor = pdtensor2fdtensor(inputs_tensor, inputs_name, share_with_raw_ptr=share_with_raw_ptr)
83
+ self.model.bind_input_tensor(inputs_name, input_fdtensor)
84
+
85
+ for outputs_name, outputs_tensor in prebinded_outputs.items():
86
+ output_fdtensor = pdtensor2fdtensor(outputs_tensor, outputs_name, share_with_raw_ptr=share_with_raw_ptr)
87
+ self.model.bind_output_tensor(outputs_name, output_fdtensor)
88
+
89
+ self.model.zero_copy_infer()
90
+
91
+ def __call__(self, **kwargs):
92
+ inputs = {k: np.array(v) for k, v in kwargs.items()}
93
+ return self.model.infer(inputs)
94
+
95
+ @staticmethod
96
+ def load_model(
97
+ model_path: Union[str, Path],
98
+ params_path: Union[str, Path],
99
+ runtime_options: Optional["fd.RuntimeOption"] = None,
100
+ ):
101
+ """
102
+ Loads an FastDeploy Inference Model with fastdeploy.RuntimeOption
103
+
104
+ Arguments:
105
+ model_path (`str` or `Path`):
106
+ Model path from which to load
107
+ params_path (`str` or `Path`):
108
+ Params path from which to load
109
+ runtime_options (fd.RuntimeOption, *optional*):
110
+ The RuntimeOption of fastdeploy to initialize the fastdeploy runtime. Default setting
111
+ the device to cpu and the backend to paddle inference
112
+ """
113
+ option = runtime_options
114
+ if option is None or not isinstance(runtime_options, fd.RuntimeOption):
115
+ logger.info("No fastdeploy.RuntimeOption specified, using CPU device and paddle inference backend.")
116
+ option = fd.RuntimeOption()
117
+ option.use_paddle_backend()
118
+ option.use_cpu()
119
+ option.set_model_path(model_path, params_path)
120
+ return fd.Runtime(option)
121
+
122
+ def _save_pretrained(
123
+ self,
124
+ save_directory: Union[str, Path],
125
+ model_file_name: Optional[str] = None,
126
+ params_file_name: Optional[str] = None,
127
+ **kwargs
128
+ ):
129
+ """
130
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
131
+ [`~FastDeployRuntimeModel.from_pretrained`] class method. It will always save the
132
+ latest_model_name.
133
+
134
+ Arguments:
135
+ save_directory (`str` or `Path`):
136
+ Directory where to save the model file.
137
+ model_file_name(`str`, *optional*):
138
+ Overwrites the default model file name from `"inference.pdmodel"` to `model_file_name`. This allows you to save the
139
+ model with a different name.
140
+ params_file_name(`str`, *optional*):
141
+ Overwrites the default model file name from `"inference.pdiparams"` to `params_file_name`. This allows you to save the
142
+ model with a different name.
143
+ """
144
+
145
+ model_file_name = model_file_name if model_file_name is not None else FASTDEPLOY_MODEL_NAME
146
+ params_file_name = params_file_name if params_file_name is not None else FASTDEPLOY_WEIGHTS_NAME
147
+
148
+ src_model_path = self.model_save_dir.joinpath(self.latest_model_name)
149
+ dst_model_path = Path(save_directory).joinpath(model_file_name)
150
+
151
+ src_params_path = self.model_save_dir.joinpath(self.latest_params_name)
152
+ dst_params_path = Path(save_directory).joinpath(params_file_name)
153
+ try:
154
+ shutil.copyfile(src_model_path, dst_model_path)
155
+ shutil.copyfile(src_params_path, dst_params_path)
156
+ except shutil.SameFileError:
157
+ pass
158
+
159
+ def save_pretrained(
160
+ self,
161
+ save_directory: Union[str, os.PathLike],
162
+ **kwargs,
163
+ ):
164
+ """
165
+ Save a model to a directory, so that it can be re-loaded using the [`~FastDeployRuntimeModel.from_pretrained`] class
166
+ method.:
167
+
168
+ Arguments:
169
+ save_directory (`str` or `os.PathLike`):
170
+ Directory to which to save. Will be created if it doesn't exist.
171
+ """
172
+ if os.path.isfile(save_directory):
173
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
174
+ return
175
+
176
+ os.makedirs(save_directory, exist_ok=True)
177
+
178
+ # saving model weights/files
179
+ self._save_pretrained(save_directory, **kwargs)
180
+
181
+ @classmethod
182
+ def _from_pretrained(
183
+ cls,
184
+ pretrained_model_name_or_path: Union[str, Path],
185
+ cache_dir: Optional[str] = None,
186
+ model_file_name: Optional[str] = None,
187
+ params_file_name: Optional[str] = None,
188
+ runtime_options: Optional["fd.RuntimeOption"] = None,
189
+ **kwargs,
190
+ ):
191
+ """
192
+ Load a model from a directory or the BOS.
193
+
194
+ Arguments:
195
+ pretrained_model_name_or_path (`str` or `Path`):
196
+ Directory from which to load
197
+ cache_dir (`Union[str, Path]`, *optional*):
198
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
199
+ standard cache should not be used.
200
+ model_file_name (`str`):
201
+ Overwrites the default model file name from `"inference.pdmodel"` to `file_name`. This allows you to load
202
+ different model files from the same repository or directory.
203
+ params_file_name (`str`):
204
+ Overwrites the default params file name from `"inference.pdiparams"` to `file_name`. This allows you to load
205
+ different model files from the same repository or directory.
206
+ runtime_options (`fastdeploy.RuntimeOption`, *optional*):
207
+ The RuntimeOption of fastdeploy.
208
+ kwargs (`Dict`, *optional*):
209
+ kwargs will be passed to the model during initialization
210
+ """
211
+ model_file_name = model_file_name if model_file_name is not None else FASTDEPLOY_MODEL_NAME
212
+ params_file_name = params_file_name if params_file_name is not None else FASTDEPLOY_WEIGHTS_NAME
213
+ # load model from local directory
214
+ if os.path.isdir(pretrained_model_name_or_path):
215
+ model = FastDeployRuntimeModel.load_model(
216
+ os.path.join(pretrained_model_name_or_path, model_file_name),
217
+ os.path.join(pretrained_model_name_or_path, params_file_name),
218
+ runtime_options=runtime_options,
219
+ )
220
+ kwargs["model_save_dir"] = Path(pretrained_model_name_or_path)
221
+ # load model from hub
222
+ else:
223
+ # download model
224
+ model_cache_path = ppdiffusers_bos_download(
225
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
226
+ filename=model_file_name,
227
+ cache_dir=cache_dir,
228
+ )
229
+ # download params
230
+ params_cache_path = ppdiffusers_bos_download(
231
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
232
+ filename=params_file_name,
233
+ cache_dir=cache_dir,
234
+ )
235
+ kwargs["model_save_dir"] = Path(model_cache_path).parent
236
+ kwargs["latest_model_name"] = Path(model_cache_path).name
237
+ kwargs["latest_params_name"] = Path(params_cache_path).name
238
+ model = FastDeployRuntimeModel.load_model(
239
+ model_cache_path, params_cache_path, runtime_options=runtime_options
240
+ )
241
+ return cls(model=model, **kwargs)
242
+
243
+ @classmethod
244
+ def from_pretrained(
245
+ cls,
246
+ pretrained_model_name_or_path: Union[str, Path],
247
+ cache_dir: Optional[str] = None,
248
+ model_file_name: Optional[str] = None,
249
+ params_file_name: Optional[str] = None,
250
+ runtime_options: Optional["fd.RuntimeOption"] = None,
251
+ **model_kwargs,
252
+ ):
253
+ return cls._from_pretrained(
254
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
255
+ cache_dir=cache_dir,
256
+ model_file_name=model_file_name,
257
+ params_file_name=params_file_name,
258
+ runtime_options=runtime_options,
259
+ **model_kwargs,
260
+ )
ppdiffusers/initializer.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This code is based on https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
17
+ Ths copyright of pytorch/pytorch is a BSD-style license, as found in the LICENSE file.
18
+ """
19
+
20
+ import math
21
+
22
+ import numpy as np
23
+ import paddle
24
+ import paddle.nn as nn
25
+
26
+ __all__ = [
27
+ "uniform_",
28
+ "normal_",
29
+ "constant_",
30
+ "ones_",
31
+ "zeros_",
32
+ "xavier_uniform_",
33
+ "xavier_normal_",
34
+ "kaiming_uniform_",
35
+ "kaiming_normal_",
36
+ "linear_init_",
37
+ "conv_init_",
38
+ "reset_initialized_parameter",
39
+ ]
40
+
41
+
42
+ def _no_grad_uniform_(tensor, a, b):
43
+ with paddle.no_grad():
44
+ tensor.set_value(paddle.uniform(shape=tensor.shape, dtype=tensor.dtype, min=a, max=b))
45
+ return tensor
46
+
47
+
48
+ def _no_grad_normal_(tensor, mean=0.0, std=1.0):
49
+ with paddle.no_grad():
50
+ tensor.set_value(paddle.normal(mean=mean, std=std, shape=tensor.shape))
51
+ return tensor
52
+
53
+
54
+ def _no_grad_fill_(tensor, value=0.0):
55
+ with paddle.no_grad():
56
+ tensor.set_value(paddle.full_like(tensor, value, dtype=tensor.dtype))
57
+ return tensor
58
+
59
+
60
+ def uniform_(tensor, a, b):
61
+ """
62
+ Modified tensor inspace using uniform_
63
+ Args:
64
+ tensor (paddle.Tensor): paddle Tensor
65
+ a (float|int): min value.
66
+ b (float|int): max value.
67
+ Return:
68
+ tensor
69
+ """
70
+ return _no_grad_uniform_(tensor, a, b)
71
+
72
+
73
+ def normal_(tensor, mean=0.0, std=1.0):
74
+ """
75
+ Modified tensor inspace using normal_
76
+ Args:
77
+ tensor (paddle.Tensor): paddle Tensor
78
+ mean (float|int): mean value.
79
+ std (float|int): std value.
80
+ Return:
81
+ tensor
82
+ """
83
+ return _no_grad_normal_(tensor, mean, std)
84
+
85
+
86
+ def constant_(tensor, value=0.0):
87
+ """
88
+ Modified tensor inspace using constant_
89
+ Args:
90
+ tensor (paddle.Tensor): paddle Tensor
91
+ value (float|int): value to fill tensor.
92
+ Return:
93
+ tensor
94
+ """
95
+ return _no_grad_fill_(tensor, value)
96
+
97
+
98
+ def ones_(tensor):
99
+ """
100
+ Modified tensor inspace using ones_
101
+ Args:
102
+ tensor (paddle.Tensor): paddle Tensor
103
+ Return:
104
+ tensor
105
+ """
106
+ return _no_grad_fill_(tensor, 1)
107
+
108
+
109
+ def zeros_(tensor):
110
+ """
111
+ Modified tensor inspace using zeros_
112
+ Args:
113
+ tensor (paddle.Tensor): paddle Tensor
114
+ Return:
115
+ tensor
116
+ """
117
+ return _no_grad_fill_(tensor, 0)
118
+
119
+
120
+ def vector_(tensor, vector):
121
+ with paddle.no_grad():
122
+ tensor.set_value(paddle.to_tensor(vector, dtype=tensor.dtype))
123
+ return tensor
124
+
125
+
126
+ def _calculate_fan_in_and_fan_out(tensor, reverse=False):
127
+ """
128
+ Calculate (fan_in, _fan_out) for tensor
129
+ Args:
130
+ tensor (Tensor): paddle.Tensor
131
+ reverse (bool: False): tensor data format order, False by default as [fout, fin, ...]. e.g. : conv.weight [cout, cin, kh, kw] is False; linear.weight [cin, cout] is True
132
+ Return:
133
+ Tuple[fan_in, fan_out]
134
+ """
135
+ if tensor.ndim < 2:
136
+ raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
137
+
138
+ if reverse:
139
+ num_input_fmaps, num_output_fmaps = tensor.shape[0], tensor.shape[1]
140
+ else:
141
+ num_input_fmaps, num_output_fmaps = tensor.shape[1], tensor.shape[0]
142
+
143
+ receptive_field_size = 1
144
+ if tensor.ndim > 2:
145
+ receptive_field_size = np.prod(tensor.shape[2:])
146
+
147
+ fan_in = num_input_fmaps * receptive_field_size
148
+ fan_out = num_output_fmaps * receptive_field_size
149
+
150
+ return fan_in, fan_out
151
+
152
+
153
+ def xavier_uniform_(tensor, gain=1.0, reverse=False):
154
+ """
155
+ Modified tensor inspace using xavier_uniform_
156
+ Args:
157
+ tensor (paddle.Tensor): paddle Tensor
158
+ gain (float): super parameter, 1. default.
159
+ reverse (bool): reverse (bool: False): tensor data format order, False by default as [fout, fin, ...].
160
+ Return:
161
+ tensor
162
+ """
163
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
164
+ std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
165
+ k = math.sqrt(3.0) * std
166
+ return _no_grad_uniform_(tensor, -k, k)
167
+
168
+
169
+ def xavier_normal_(tensor, gain=1.0, reverse=False):
170
+ """
171
+ Modified tensor inspace using xavier_normal_
172
+ Args:
173
+ tensor (paddle.Tensor): paddle Tensor
174
+ gain (float): super parameter, 1. default.
175
+ reverse (bool): reverse (bool: False): tensor data format order, False by default as [fout, fin, ...].
176
+ Return:
177
+ tensor
178
+ """
179
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
180
+ std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
181
+ return _no_grad_normal_(tensor, 0, std)
182
+
183
+
184
+ # reference: https://pytorch.org/docs/stable/_modules/torch/nn/init.html
185
+ def _calculate_correct_fan(tensor, mode, reverse=False):
186
+ mode = mode.lower()
187
+ valid_modes = ["fan_in", "fan_out"]
188
+ if mode not in valid_modes:
189
+ raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
190
+
191
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse)
192
+
193
+ return fan_in if mode == "fan_in" else fan_out
194
+
195
+
196
+ def _calculate_gain(nonlinearity, param=None):
197
+ linear_fns = ["linear", "conv1d", "conv2d", "conv3d", "conv_transpose1d", "conv_transpose2d", "conv_transpose3d"]
198
+ if nonlinearity in linear_fns or nonlinearity == "sigmoid":
199
+ return 1
200
+ elif nonlinearity == "tanh":
201
+ return 5.0 / 3
202
+ elif nonlinearity == "relu":
203
+ return math.sqrt(2.0)
204
+ elif nonlinearity == "leaky_relu":
205
+ if param is None:
206
+ negative_slope = 0.01
207
+ elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
208
+ # True/False are instances of int, hence check above
209
+ negative_slope = param
210
+ else:
211
+ raise ValueError("negative_slope {} not a valid number".format(param))
212
+ return math.sqrt(2.0 / (1 + negative_slope**2))
213
+ elif nonlinearity == "selu":
214
+ return 3.0 / 4
215
+ else:
216
+ raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
217
+
218
+
219
+ def kaiming_uniform_(tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", reverse=False):
220
+ """
221
+ Modified tensor inspace using kaiming_uniform method
222
+ Args:
223
+ tensor (paddle.Tensor): paddle Tensor
224
+ mode (str): ['fan_in', 'fan_out'], 'fin_in' defalut
225
+ nonlinearity (str): nonlinearity method name
226
+ reverse (bool): reverse (bool: False): tensor data format order, False by default as [fout, fin, ...].
227
+ Return:
228
+ tensor
229
+ """
230
+ fan = _calculate_correct_fan(tensor, mode, reverse)
231
+ gain = _calculate_gain(nonlinearity, a)
232
+ std = gain / math.sqrt(fan)
233
+ k = math.sqrt(3.0) * std
234
+ return _no_grad_uniform_(tensor, -k, k)
235
+
236
+
237
+ def kaiming_normal_(tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", reverse=False):
238
+ """
239
+ Modified tensor inspace using kaiming_normal_
240
+ Args:
241
+ tensor (paddle.Tensor): paddle Tensor
242
+ mode (str): ['fan_in', 'fan_out'], 'fin_in' defalut
243
+ nonlinearity (str): nonlinearity method name
244
+ reverse (bool): reverse (bool: False): tensor data format order, False by default as [fout, fin, ...].
245
+ Return:
246
+ tensor
247
+ """
248
+ fan = _calculate_correct_fan(tensor, mode, reverse)
249
+ gain = _calculate_gain(nonlinearity, a)
250
+ std = gain / math.sqrt(fan)
251
+ return _no_grad_normal_(tensor, 0, std)
252
+
253
+
254
+ def linear_init_(module):
255
+ bound = 1 / math.sqrt(module.weight.shape[0])
256
+ uniform_(module.weight, -bound, bound)
257
+ uniform_(module.bias, -bound, bound)
258
+
259
+
260
+ def conv_init_(module):
261
+ bound = 1 / np.sqrt(np.prod(module.weight.shape[1:]))
262
+ uniform_(module.weight, -bound, bound)
263
+ if module.bias is not None:
264
+ uniform_(module.bias, -bound, bound)
265
+
266
+
267
+ def bias_init_with_prob(prior_prob=0.01):
268
+ """initialize conv/fc bias value according to a given probability value."""
269
+ bias_init = float(-np.log((1 - prior_prob) / prior_prob))
270
+ return bias_init
271
+
272
+
273
+ @paddle.no_grad()
274
+ def reset_initialized_parameter(model, include_self=True):
275
+ """
276
+ Reset initialized parameter using following method for [conv, linear, embedding, bn]
277
+ Args:
278
+ model (paddle.Layer): paddle Layer
279
+ include_self (bool: False): include_self for Layer.named_sublayers method. Indicate whether including itself
280
+ Return:
281
+ None
282
+ """
283
+ for _, m in model.named_sublayers(include_self=include_self):
284
+ if isinstance(m, nn.Conv2D):
285
+ k = float(m._groups) / (m._in_channels * m._kernel_size[0] * m._kernel_size[1])
286
+ k = math.sqrt(k)
287
+ _no_grad_uniform_(m.weight, -k, k)
288
+ if hasattr(m, "bias") and getattr(m, "bias") is not None:
289
+ _no_grad_uniform_(m.bias, -k, k)
290
+
291
+ elif isinstance(m, nn.Linear):
292
+ k = math.sqrt(1.0 / m.weight.shape[0])
293
+ _no_grad_uniform_(m.weight, -k, k)
294
+ if hasattr(m, "bias") and getattr(m, "bias") is not None:
295
+ _no_grad_uniform_(m.bias, -k, k)
296
+
297
+ elif isinstance(m, nn.Embedding):
298
+ _no_grad_normal_(m.weight, mean=0.0, std=1.0)
299
+
300
+ elif isinstance(m, (nn.BatchNorm2D, nn.LayerNorm)):
301
+ _no_grad_fill_(m.weight, 1.0)
302
+ if hasattr(m, "bias") and getattr(m, "bias") is not None:
303
+ _no_grad_fill_(m.bias, 0)
ppdiffusers/loaders.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import os
16
+ from collections import defaultdict
17
+ from typing import Callable, Dict, Union
18
+
19
+ import paddle
20
+ import paddle.nn as nn
21
+
22
+ from .modeling_utils import _get_model_file, load_dict
23
+ from .models.cross_attention import LoRACrossAttnProcessor
24
+ from .utils import HF_CACHE, PPDIFFUSERS_CACHE, logging
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ LORA_WEIGHT_NAME = "paddle_lora_weights.pdparams"
30
+
31
+
32
+ class AttnProcsLayers(nn.Layer):
33
+ def __init__(self, state_dict: Dict[str, paddle.Tensor]):
34
+ super().__init__()
35
+ self.layers = nn.LayerList(state_dict.values())
36
+ self.mapping = {k: v for k, v in enumerate(state_dict.keys())}
37
+ self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
38
+
39
+ # we add a hook to state_dict() and load_state_dict() so that the
40
+ # naming fits with `unet.attn_processors`
41
+ def map_to(state_dict, *args, **kwargs):
42
+ new_state_dict = {}
43
+ for key, value in state_dict.items():
44
+ num = int(key.split(".")[1]) # 0 is always "layers"
45
+ new_key = key.replace(f"layers.{num}", self.mapping[num])
46
+ new_state_dict[new_key] = value
47
+
48
+ return new_state_dict
49
+
50
+ def map_from(module, state_dict, *args, **kwargs):
51
+ all_keys = list(state_dict.keys())
52
+ for key in all_keys:
53
+ replace_key = key.split(".processor")[0] + ".processor"
54
+ new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
55
+ state_dict[new_key] = state_dict[key]
56
+ del state_dict[key]
57
+
58
+ self.register_state_dict_hook(map_to)
59
+ self.register_load_state_dict_pre_hook(map_from, with_module=True)
60
+
61
+
62
+ class UNet2DConditionLoadersMixin:
63
+ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, paddle.Tensor]], **kwargs):
64
+ r"""
65
+ Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
66
+ defined in
67
+ [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
68
+ and be a `paddle.nn.Layer` class.
69
+ <Tip warning={true}>
70
+ This function is experimental and might change in the future
71
+ </Tip>
72
+ Parameters:
73
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
74
+ Can be either:
75
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
76
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
77
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
78
+ `./my_model_directory/`.
79
+ - A [paddle state
80
+ dict].
81
+ from_hf_hub (bool, optional): whether to load from Huggingface Hub.
82
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
83
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
84
+ standard cache should not be used.
85
+ subfolder (`str`, *optional*, defaults to `None`):
86
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
87
+ huggingface.co or downloaded locally), you can specify the folder name here.
88
+ """
89
+
90
+ from_hf_hub = kwargs.pop("from_hf_hub", False)
91
+ if from_hf_hub:
92
+ cache_dir = kwargs.pop("cache_dir", HF_CACHE)
93
+ else:
94
+ cache_dir = kwargs.pop("cache_dir", PPDIFFUSERS_CACHE)
95
+ subfolder = kwargs.pop("subfolder", None)
96
+ weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
97
+
98
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
99
+ model_file = _get_model_file(
100
+ pretrained_model_name_or_path_or_dict,
101
+ weights_name=weight_name,
102
+ cache_dir=cache_dir,
103
+ subfolder=subfolder,
104
+ from_hf_hub=from_hf_hub,
105
+ )
106
+ state_dict = load_dict(model_file, map_location="cpu")
107
+ else:
108
+ state_dict = pretrained_model_name_or_path_or_dict
109
+
110
+ # fill attn processors
111
+ attn_processors = {}
112
+
113
+ is_lora = all("lora" in k for k in state_dict.keys())
114
+
115
+ if is_lora:
116
+ lora_grouped_dict = defaultdict(dict)
117
+ for key, value in state_dict.items():
118
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
119
+ lora_grouped_dict[attn_processor_key][sub_key] = value
120
+
121
+ for key, value_dict in lora_grouped_dict.items():
122
+ rank = value_dict["to_k_lora.down.weight"].shape[1] # 0 -> 1, torch vs paddle nn.Linear
123
+ cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[0] # 1 -> 0, torch vs paddle nn.Linear
124
+ hidden_size = value_dict["to_k_lora.up.weight"].shape[1] # 0 -> 1, torch vs paddle nn.Linear
125
+
126
+ attn_processors[key] = LoRACrossAttnProcessor(
127
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
128
+ )
129
+ attn_processors[key].load_dict(value_dict)
130
+
131
+ else:
132
+ raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
133
+
134
+ # set correct dtype & device
135
+ attn_processors = {k: v.to(dtype=self.dtype) for k, v in attn_processors.items()}
136
+
137
+ # set layers
138
+ self.set_attn_processor(attn_processors)
139
+
140
+ def save_attn_procs(
141
+ self,
142
+ save_directory: Union[str, os.PathLike],
143
+ is_main_process: bool = True,
144
+ weights_name: str = LORA_WEIGHT_NAME,
145
+ save_function: Callable = None,
146
+ ):
147
+ r"""
148
+ Save an attention procesor to a directory, so that it can be re-loaded using the
149
+ `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method.
150
+ Arguments:
151
+ save_directory (`str` or `os.PathLike`):
152
+ Directory to which to save. Will be created if it doesn't exist.
153
+ is_main_process (`bool`, *optional*, defaults to `True`):
154
+ Whether the process calling this is the main process or not. Useful when in distributed training like
155
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
156
+ the main process to avoid race conditions.
157
+ weights_name (`str`, *optional*, defaults to `LORA_WEIGHT_NAME`):
158
+ The name of weights.
159
+ save_function (`Callable`):
160
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
161
+ need to replace `torch.save` by another method. Can be configured with the environment variable
162
+ `DIFFUSERS_SAVE_MODE`.
163
+ """
164
+ if os.path.isfile(save_directory):
165
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
166
+ return
167
+
168
+ if save_function is None:
169
+ save_function = paddle.save
170
+
171
+ os.makedirs(save_directory, exist_ok=True)
172
+
173
+ model_to_save = AttnProcsLayers(self.attn_processors)
174
+
175
+ # Save the model
176
+ state_dict = model_to_save.state_dict()
177
+
178
+ # Clean the folder from a previous save
179
+ for filename in os.listdir(save_directory):
180
+ full_filename = os.path.join(save_directory, filename)
181
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
182
+ # in distributed settings to avoid race conditions.
183
+ weights_no_suffix = weights_name.replace(".pdparams", "")
184
+ if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
185
+ os.remove(full_filename)
186
+
187
+ # Save the model
188
+ save_function(state_dict, os.path.join(save_directory, weights_name))
189
+
190
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
ppdiffusers/modeling_paddle_pytorch_utils.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch - Paddle general utilities."""
16
+ import re
17
+
18
+ from .utils import logging
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+
23
+ def rename_key(key):
24
+ regex = r"\w+[.]\d+"
25
+ pats = re.findall(regex, key)
26
+ for pat in pats:
27
+ key = key.replace(pat, "_".join(pat.split(".")))
28
+ return key
29
+
30
+
31
+ #####################
32
+ # PyTorch => Paddle #
33
+ #####################
34
+
35
+
36
+ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_paddle_state_dict):
37
+ """Rename PT weight names to corresponding Paddle weight names and reshape tensor if necessary"""
38
+
39
+ # conv norm or layer norm
40
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
41
+ if (
42
+ any("norm" in str_ for str_ in pt_tuple_key)
43
+ and (pt_tuple_key[-1] in ["bias", "beta"])
44
+ and (pt_tuple_key[:-1] + ("bias",) in random_paddle_state_dict)
45
+ ):
46
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
47
+ return renamed_pt_tuple_key, pt_tensor
48
+ elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("bias",) in random_paddle_state_dict:
49
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
50
+ return renamed_pt_tuple_key, pt_tensor
51
+
52
+ # embedding
53
+ if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("weight",) in random_paddle_state_dict:
54
+ pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
55
+ return renamed_pt_tuple_key, pt_tensor
56
+
57
+ # conv layer
58
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
59
+ if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
60
+ return renamed_pt_tuple_key, pt_tensor
61
+
62
+ # linear layer
63
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
64
+ if pt_tuple_key[-1] == "weight":
65
+ pt_tensor = pt_tensor.t()
66
+ return renamed_pt_tuple_key, pt_tensor
67
+
68
+ # old PyTorch layer norm weight
69
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
70
+ if pt_tuple_key[-1] == "gamma":
71
+ return renamed_pt_tuple_key, pt_tensor
72
+
73
+ # old PyTorch layer norm bias
74
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
75
+ if pt_tuple_key[-1] == "beta":
76
+ return renamed_pt_tuple_key, pt_tensor
77
+
78
+ return pt_tuple_key, pt_tensor
79
+
80
+
81
+ def convert_pytorch_state_dict_to_paddle(pt_state_dict, paddle_model):
82
+ # Step 1: Convert pytorch tensor to numpy
83
+ pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
84
+
85
+ random_paddle_state_dict = paddle_model.state_dict
86
+ paddle_state_dict = {}
87
+
88
+ # Need to change some parameters name to match Paddle names
89
+ for pt_key, pt_tensor in pt_state_dict.items():
90
+ renamed_pt_key = rename_key(pt_key)
91
+ pt_tuple_key = tuple(renamed_pt_key.split("."))
92
+
93
+ # Correctly rename weight parameters
94
+ paddle_key, paddle_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_paddle_state_dict)
95
+
96
+ if paddle_key in random_paddle_state_dict:
97
+ if list(paddle_tensor.shape) != list(random_paddle_state_dict[paddle_key].shape):
98
+ raise ValueError(
99
+ f"Paddle checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
100
+ f"{random_paddle_state_dict[paddle_key].shape}, but is {paddle_tensor.shape}."
101
+ )
102
+
103
+ # also add unexpected weight so that warning is thrown
104
+ paddle_state_dict[paddle_key] = paddle_tensor.numpy()
105
+
106
+ return paddle_state_dict
ppdiffusers/modeling_utils.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ import tempfile
19
+ from functools import partial
20
+ from typing import Callable, Optional, Union
21
+
22
+ import paddle
23
+ import paddle.nn as nn
24
+ from huggingface_hub import (
25
+ create_repo,
26
+ get_hf_file_metadata,
27
+ hf_hub_download,
28
+ hf_hub_url,
29
+ repo_type_and_id_from_hf_id,
30
+ upload_folder,
31
+ )
32
+ from huggingface_hub.utils import EntryNotFoundError
33
+ from requests import HTTPError
34
+
35
+ from .download_utils import ppdiffusers_bos_download
36
+ from .utils import (
37
+ CONFIG_NAME,
38
+ DOWNLOAD_SERVER,
39
+ HF_CACHE,
40
+ PPDIFFUSERS_CACHE,
41
+ WEIGHTS_NAME,
42
+ logging,
43
+ )
44
+ from .version import VERSION as __version__
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+
49
+ def unfreeze_params(params):
50
+ for param in params:
51
+ param.stop_gradient = False
52
+
53
+
54
+ def freeze_params(params):
55
+ for param in params:
56
+ param.stop_gradient = True
57
+
58
+
59
+ # device
60
+ def get_parameter_device(parameter: nn.Layer):
61
+ try:
62
+ return next(parameter.named_parameters())[1].place
63
+ except StopIteration:
64
+ return paddle.get_device()
65
+
66
+
67
+ def get_parameter_dtype(parameter: nn.Layer):
68
+ try:
69
+ return next(parameter.named_parameters())[1].dtype
70
+ except StopIteration:
71
+ return paddle.get_default_dtype()
72
+
73
+
74
+ def load_dict(checkpoint_file: Union[str, os.PathLike], map_location: str = "cpu"):
75
+ """
76
+ Reads a Paddle checkpoint file, returning properly formatted errors if they arise.
77
+ """
78
+ try:
79
+ if map_location == "cpu":
80
+ with paddle.device_scope("cpu"):
81
+ state_dict = paddle.load(checkpoint_file)
82
+ else:
83
+ state_dict = paddle.load(checkpoint_file)
84
+ return state_dict
85
+ except Exception as e:
86
+ try:
87
+ with open(checkpoint_file) as f:
88
+ if f.read().startswith("version"):
89
+ raise OSError(
90
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
91
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
92
+ "you cloned."
93
+ )
94
+ else:
95
+ raise ValueError(
96
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
97
+ "model. Make sure you have saved the model properly."
98
+ ) from e
99
+ except (UnicodeDecodeError, ValueError):
100
+ raise OSError(
101
+ f"Unable to load weights from Paddle checkpoint file for '{checkpoint_file}' "
102
+ f"at '{checkpoint_file}'. "
103
+ "If you tried to load a Paddle model from a TF 2.0 checkpoint, please set from_tf=True."
104
+ )
105
+
106
+
107
+ class ModelMixin(nn.Layer):
108
+ r"""
109
+ Base class for all models.
110
+
111
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
112
+ and saving models.
113
+
114
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
115
+ [`~modeling_utils.ModelMixin.save_pretrained`].
116
+ """
117
+ config_name = CONFIG_NAME
118
+ _automatically_saved_args = ["_ppdiffusers_version", "_class_name", "_name_or_path"]
119
+ _supports_gradient_checkpointing = False
120
+
121
+ def __init__(self):
122
+ super().__init__()
123
+
124
+ @property
125
+ def is_gradient_checkpointing(self) -> bool:
126
+ """
127
+ Whether gradient checkpointing is activated for this model or not.
128
+
129
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
130
+ activations".
131
+ """
132
+ return any(
133
+ hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing
134
+ for m in self.sublayers(include_self=True)
135
+ )
136
+
137
+ def enable_gradient_checkpointing(self):
138
+ """
139
+ Activates gradient checkpointing for the current model.
140
+
141
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
142
+ activations".
143
+ """
144
+ if not self._supports_gradient_checkpointing:
145
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
146
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
147
+
148
+ def disable_gradient_checkpointing(self):
149
+ """
150
+ Deactivates gradient checkpointing for the current model.
151
+
152
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
153
+ activations".
154
+ """
155
+ if self._supports_gradient_checkpointing:
156
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
157
+
158
+ def save_pretrained(
159
+ self,
160
+ save_directory: Union[str, os.PathLike],
161
+ is_main_process: bool = True,
162
+ save_function: Callable = paddle.save,
163
+ ):
164
+ """
165
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
166
+ `[`~modeling_utils.ModelMixin.from_pretrained`]` class method.
167
+
168
+ Arguments:
169
+ save_directory (`str` or `os.PathLike`):
170
+ Directory to which to save. Will be created if it doesn't exist.
171
+ is_main_process (`bool`, *optional*, defaults to `True`):
172
+ Whether the process calling this is the main process or not. Useful when in distributed training like
173
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
174
+ the main process to avoid race conditions.
175
+ save_function (`Callable`):
176
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
177
+ need to replace `paddle.save` by another method.
178
+ """
179
+ if os.path.isfile(save_directory):
180
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
181
+ return
182
+
183
+ os.makedirs(save_directory, exist_ok=True)
184
+
185
+ model_to_save = self
186
+
187
+ # Attach architecture to the config
188
+ # Save the config
189
+ if is_main_process:
190
+ model_to_save.save_config(save_directory)
191
+
192
+ # Save the model
193
+ state_dict = model_to_save.state_dict()
194
+
195
+ # Clean the folder from a previous save
196
+ for filename in os.listdir(save_directory):
197
+ full_filename = os.path.join(save_directory, filename)
198
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
199
+ # in distributed settings to avoid race conditions.
200
+ if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
201
+ os.remove(full_filename)
202
+
203
+ # Save the model
204
+ save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
205
+
206
+ logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
207
+
208
+ def save_to_hf_hub(
209
+ self,
210
+ repo_id: str,
211
+ private: Optional[bool] = None,
212
+ subfolder: Optional[str] = None,
213
+ commit_message: Optional[str] = None,
214
+ revision: Optional[str] = None,
215
+ create_pr: bool = False,
216
+ ):
217
+ """
218
+ Uploads all elements of this model to a new HuggingFace Hub repository.
219
+ Args:
220
+ repo_id (str): Repository name for your model/tokenizer in the Hub.
221
+ private (bool, optional): Whether the model/tokenizer is set to private
222
+ subfolder (str, optional): Push to a subfolder of the repo instead of the root
223
+ commit_message (str, optional) — The summary / title / first line of the generated commit. Defaults to: f"Upload {path_in_repo} with huggingface_hub"
224
+ revision (str, optional) — The git revision to commit from. Defaults to the head of the "main" branch.
225
+ create_pr (boolean, optional) — Whether or not to create a Pull Request with that commit. Defaults to False.
226
+ If revision is not set, PR is opened against the "main" branch. If revision is set and is a branch, PR is opened against this branch.
227
+ If revision is set and is not a branch name (example: a commit oid), an RevisionNotFoundError is returned by the server.
228
+
229
+ Returns: The url of the commit of your model in the given repository.
230
+ """
231
+ repo_url = create_repo(repo_id, private=private, exist_ok=True)
232
+
233
+ # Infer complete repo_id from repo_url
234
+ # Can be different from the input `repo_id` if repo_owner was implicit
235
+ _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
236
+
237
+ repo_id = f"{repo_owner}/{repo_name}"
238
+
239
+ # Check if README file already exist in repo
240
+ try:
241
+ get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
242
+ has_readme = True
243
+ except EntryNotFoundError:
244
+ has_readme = False
245
+
246
+ with tempfile.TemporaryDirectory() as root_dir:
247
+ if subfolder is not None:
248
+ save_dir = os.path.join(root_dir, subfolder)
249
+ else:
250
+ save_dir = root_dir
251
+ # save model
252
+ self.save_pretrained(save_dir)
253
+ # Add readme if does not exist
254
+ logger.info("README.md not found, adding the default README.md")
255
+ if not has_readme:
256
+ with open(os.path.join(root_dir, "README.md"), "w") as f:
257
+ f.write(f"---\nlibrary_name: ppdiffusers\n---\n# {repo_id}")
258
+
259
+ # Upload model and return
260
+ logger.info(f"Pushing to the {repo_id}. This might take a while")
261
+ return upload_folder(
262
+ repo_id=repo_id,
263
+ repo_type="model",
264
+ folder_path=root_dir,
265
+ commit_message=commit_message,
266
+ revision=revision,
267
+ create_pr=create_pr,
268
+ )
269
+
270
+ @classmethod
271
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
272
+ r"""
273
+ Instantiate a pretrained paddle model from a pre-trained model configuration.
274
+
275
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
276
+ the model, you should first set it back in training mode with `model.train()`.
277
+
278
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
279
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
280
+ task.
281
+
282
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
283
+ weights are discarded.
284
+
285
+ Parameters:
286
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
287
+ Can be either:
288
+
289
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
290
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
291
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
292
+ `./my_model_directory/`.
293
+
294
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
295
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
296
+ standard cache should not be used.
297
+ paddle_dtype (`str` or `paddle.dtype`, *optional*):
298
+ Override the default `paddle.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
299
+ will be automatically derived from the model's weights.
300
+ output_loading_info(`bool`, *optional*, defaults to `False`):
301
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
302
+ subfolder (`str`, *optional*, defaults to `""`):
303
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
304
+ huggingface.co or downloaded locally), you can specify the folder name here.
305
+ from_hf_hub (bool, *optional*):
306
+ Whether to load from Hugging Face Hub. Defaults to False
307
+ """
308
+ from_hf_hub = kwargs.pop("from_hf_hub", False)
309
+ if from_hf_hub:
310
+ cache_dir = kwargs.pop("cache_dir", HF_CACHE)
311
+ else:
312
+ cache_dir = kwargs.pop("cache_dir", PPDIFFUSERS_CACHE)
313
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
314
+ output_loading_info = kwargs.pop("output_loading_info", False)
315
+ paddle_dtype = kwargs.pop("paddle_dtype", None)
316
+ subfolder = kwargs.pop("subfolder", None)
317
+ ignore_keys = kwargs.pop("ignore_keys", [])
318
+
319
+ # Load config if we don't provide a configuration
320
+ config_path = pretrained_model_name_or_path
321
+
322
+ model_file = None
323
+ if model_file is None:
324
+ model_file = _get_model_file(
325
+ pretrained_model_name_or_path,
326
+ weights_name=WEIGHTS_NAME,
327
+ cache_dir=cache_dir,
328
+ subfolder=subfolder,
329
+ from_hf_hub=from_hf_hub,
330
+ )
331
+
332
+ config, unused_kwargs = cls.load_config(
333
+ config_path,
334
+ cache_dir=cache_dir,
335
+ return_unused_kwargs=True,
336
+ subfolder=subfolder,
337
+ from_hf_hub=from_hf_hub,
338
+ **kwargs,
339
+ )
340
+ model = cls.from_config(config, **unused_kwargs)
341
+
342
+ state_dict = load_dict(model_file, map_location="cpu")
343
+
344
+ keys = list(state_dict.keys())
345
+ for k in keys:
346
+ for ik in ignore_keys:
347
+ if k.startswith(ik):
348
+ logger.warning("Deleting key {} from state_dict.".format(k))
349
+ del state_dict[k]
350
+
351
+ dtype = set(v.dtype for v in state_dict.values())
352
+
353
+ if len(dtype) > 1 and paddle.float32 not in dtype:
354
+ raise ValueError(
355
+ f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
356
+ f" make sure that {model_file} weights have only one dtype."
357
+ )
358
+ elif len(dtype) > 1 and paddle.float32 in dtype:
359
+ dtype = paddle.float32
360
+ else:
361
+ dtype = dtype.pop()
362
+
363
+ # move model to correct dtype
364
+ model = model.to(dtype=dtype)
365
+
366
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
367
+ model,
368
+ state_dict,
369
+ model_file,
370
+ pretrained_model_name_or_path,
371
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
372
+ )
373
+
374
+ loading_info = {
375
+ "missing_keys": missing_keys,
376
+ "unexpected_keys": unexpected_keys,
377
+ "mismatched_keys": mismatched_keys,
378
+ "error_msgs": error_msgs,
379
+ }
380
+
381
+ if paddle_dtype is not None and not isinstance(paddle_dtype, paddle.dtype):
382
+ raise ValueError(
383
+ f"{paddle_dtype} needs to be of type `paddle.dtype`, e.g. `paddle.float16`, but is {type(paddle_dtype)}."
384
+ )
385
+ elif paddle_dtype is not None:
386
+ model = model.to(dtype=paddle_dtype)
387
+
388
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
389
+
390
+ # Set model in evaluation mode to deactivate DropOut modules by default
391
+ model.eval()
392
+ if output_loading_info:
393
+ return model, loading_info
394
+
395
+ return model
396
+
397
+ @classmethod
398
+ def _load_pretrained_model(
399
+ cls,
400
+ model,
401
+ state_dict,
402
+ resolved_archive_file,
403
+ pretrained_model_name_or_path,
404
+ ignore_mismatched_sizes=False,
405
+ ):
406
+ # Retrieve missing & unexpected_keys
407
+ model_state_dict = model.state_dict()
408
+ loaded_keys = [k for k in state_dict.keys()]
409
+
410
+ expected_keys = list(model_state_dict.keys())
411
+
412
+ original_loaded_keys = loaded_keys
413
+
414
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
415
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
416
+
417
+ # Make sure we are able to load base models as well as derived models (with heads)
418
+ model_to_load = model
419
+
420
+ def _find_mismatched_keys(
421
+ state_dict,
422
+ model_state_dict,
423
+ loaded_keys,
424
+ ignore_mismatched_sizes,
425
+ ):
426
+ mismatched_keys = []
427
+ if ignore_mismatched_sizes:
428
+ for checkpoint_key in loaded_keys:
429
+ model_key = checkpoint_key
430
+
431
+ if model_key in model_state_dict and list(state_dict[checkpoint_key].shape) != list(
432
+ model_state_dict[model_key].shape
433
+ ):
434
+ mismatched_keys.append(
435
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
436
+ )
437
+ del state_dict[checkpoint_key]
438
+ return mismatched_keys
439
+
440
+ if state_dict is not None:
441
+ # Whole checkpoint
442
+ mismatched_keys = _find_mismatched_keys(
443
+ state_dict,
444
+ model_state_dict,
445
+ original_loaded_keys,
446
+ ignore_mismatched_sizes,
447
+ )
448
+ error_msgs = ""
449
+ model_to_load.load_dict(state_dict)
450
+
451
+ if len(error_msgs) > 0:
452
+ error_msg = "\n\t".join(error_msgs)
453
+ if "size mismatch" in error_msg:
454
+ error_msg += (
455
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
456
+ )
457
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
458
+
459
+ if len(unexpected_keys) > 0:
460
+ logger.warning(
461
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
462
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
463
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
464
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
465
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
466
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
467
+ " identical (initializing a BertForSequenceClassification model from a"
468
+ " BertForSequenceClassification model)."
469
+ )
470
+ else:
471
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
472
+ if len(missing_keys) > 0:
473
+ logger.warning(
474
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
475
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
476
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
477
+ )
478
+ elif len(mismatched_keys) == 0:
479
+ logger.info(
480
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
481
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
482
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
483
+ " without further training."
484
+ )
485
+ if len(mismatched_keys) > 0:
486
+ mismatched_warning = "\n".join(
487
+ [
488
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
489
+ for key, shape1, shape2 in mismatched_keys
490
+ ]
491
+ )
492
+ logger.warning(
493
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
494
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
495
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
496
+ " able to use it for predictions and inference."
497
+ )
498
+
499
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
500
+
501
+ @property
502
+ def device(self):
503
+ """
504
+ `paddle.place`: The device on which the module is (assuming that all the module parameters are on the same
505
+ device).
506
+ """
507
+ return get_parameter_device(self)
508
+
509
+ @property
510
+ def dtype(self) -> paddle.dtype:
511
+ """
512
+ `paddle.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
513
+ """
514
+ return get_parameter_dtype(self)
515
+
516
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
517
+ """
518
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
519
+
520
+ Args:
521
+ only_trainable (`bool`, *optional*, defaults to `False`):
522
+ Whether or not to return only the number of trainable parameters
523
+
524
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
525
+ Whether or not to return only the number of non-embeddings parameters
526
+
527
+ Returns:
528
+ `int`: The number of parameters.
529
+ """
530
+
531
+ if exclude_embeddings:
532
+ embedding_param_names = [
533
+ f"{name}.weight"
534
+ for name, module_type in self.named_sublayers(include_self=True)
535
+ if isinstance(module_type, nn.Embedding)
536
+ ]
537
+ non_embedding_parameters = [
538
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
539
+ ]
540
+ return sum(p.numel() for p in non_embedding_parameters if not p.stop_gradient or not only_trainable)
541
+ else:
542
+ return sum(p.numel() for p in self.parameters() if not p.stop_gradient or not only_trainable)
543
+
544
+
545
+ def unwrap_model(model: nn.Layer) -> nn.Layer:
546
+ """
547
+ Recursively unwraps a model from potential containers (as used in distributed training).
548
+
549
+ Args:
550
+ model (`nn.Layer`): The model to unwrap.
551
+ """
552
+ # since there could be multiple levels of wrapping, unwrap recursively
553
+ if hasattr(model, "_layers"):
554
+ return unwrap_model(model._layers)
555
+ else:
556
+ return model
557
+
558
+
559
+ def _get_model_file(
560
+ pretrained_model_name_or_path,
561
+ *,
562
+ weights_name,
563
+ subfolder,
564
+ cache_dir,
565
+ from_hf_hub,
566
+ ):
567
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
568
+ if os.path.isdir(pretrained_model_name_or_path):
569
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
570
+ # Load from a PyTorch checkpoint
571
+ model_file = os.path.join(pretrained_model_name_or_path, weights_name)
572
+ elif subfolder is not None and os.path.isfile(
573
+ os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
574
+ ):
575
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
576
+ else:
577
+ raise EnvironmentError(
578
+ f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
579
+ )
580
+ return model_file
581
+ elif from_hf_hub:
582
+ model_file = hf_hub_download(
583
+ repo_id=pretrained_model_name_or_path,
584
+ filename=weights_name,
585
+ cache_dir=cache_dir,
586
+ subfolder=subfolder,
587
+ library_name="PPDiffusers",
588
+ library_version=__version__,
589
+ )
590
+ return model_file
591
+ else:
592
+ try:
593
+ # Load from URL or cache if already cached
594
+ model_file = ppdiffusers_bos_download(
595
+ pretrained_model_name_or_path,
596
+ filename=weights_name,
597
+ subfolder=subfolder,
598
+ cache_dir=cache_dir,
599
+ )
600
+ except HTTPError as err:
601
+ raise EnvironmentError(
602
+ "There was a specific connection error when trying to load" f" {pretrained_model_name_or_path}:\n{err}"
603
+ )
604
+ except ValueError:
605
+ raise EnvironmentError(
606
+ f"We couldn't connect to '{DOWNLOAD_SERVER}' to load this model, couldn't find it"
607
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
608
+ f" directory containing a file named {weights_name} or"
609
+ " \nCheckout your internet connection or see how to run the library in"
610
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
611
+ )
612
+ except EnvironmentError:
613
+ raise EnvironmentError(
614
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
615
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
616
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
617
+ f"containing a file named {weights_name}"
618
+ )
619
+ return model_file
ppdiffusers/models/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # flake8: noqa
16
+
17
+ from ..utils import is_paddle_available
18
+
19
+ if is_paddle_available():
20
+ from .attention import Transformer2DModel
21
+ from .prior_transformer import PriorTransformer
22
+ from .unet_1d import UNet1DModel
23
+ from .unet_2d import UNet2DModel
24
+ from .unet_2d_condition import UNet2DConditionModel
25
+ from .vae import AutoencoderKL, VQModel
ppdiffusers/models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (511 Bytes). View file
 
ppdiffusers/models/__pycache__/attention.cpython-37.pyc ADDED
Binary file (22.5 kB). View file
 
ppdiffusers/models/__pycache__/cross_attention.cpython-37.pyc ADDED
Binary file (10.5 kB). View file
 
ppdiffusers/models/__pycache__/embeddings.cpython-37.pyc ADDED
Binary file (5.68 kB). View file
 
ppdiffusers/models/__pycache__/prior_transformer.cpython-37.pyc ADDED
Binary file (7.11 kB). View file
 
ppdiffusers/models/__pycache__/resnet.cpython-37.pyc ADDED
Binary file (19.6 kB). View file
 
ppdiffusers/models/__pycache__/unet_1d.cpython-37.pyc ADDED
Binary file (7.22 kB). View file
 
ppdiffusers/models/__pycache__/unet_1d_blocks.cpython-37.pyc ADDED
Binary file (17.4 kB). View file
 
ppdiffusers/models/__pycache__/unet_2d.cpython-37.pyc ADDED
Binary file (8.18 kB). View file
 
ppdiffusers/models/__pycache__/unet_2d_blocks.cpython-37.pyc ADDED
Binary file (36.7 kB). View file
 
ppdiffusers/models/__pycache__/unet_2d_condition.cpython-37.pyc ADDED
Binary file (15.7 kB). View file
 
ppdiffusers/models/__pycache__/vae.cpython-37.pyc ADDED
Binary file (16.9 kB). View file
 
ppdiffusers/models/attention.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import Optional
18
+
19
+ import paddle
20
+ import paddle.nn.functional as F
21
+ from paddle import nn
22
+
23
+ from ..configuration_utils import ConfigMixin, register_to_config
24
+ from ..modeling_utils import ModelMixin
25
+ from ..models.embeddings import ImagePositionalEmbeddings
26
+ from ..utils import BaseOutput
27
+ from .cross_attention import CrossAttention
28
+
29
+
30
+ @dataclass
31
+ class Transformer2DModelOutput(BaseOutput):
32
+ """
33
+ Args:
34
+ sample (`paddle.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
35
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
36
+ for the unnoised latent pixels.
37
+ """
38
+
39
+ sample: paddle.Tensor
40
+
41
+
42
+ class Transformer2DModel(ModelMixin, ConfigMixin):
43
+ """
44
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
45
+ embeddings) inputs.
46
+
47
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
48
+ transformer action. Finally, reshape to image.
49
+
50
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
51
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
52
+ classes of unnoised image.
53
+
54
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
55
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
56
+
57
+ Parameters:
58
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
59
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
60
+ in_channels (`int`, *optional*):
61
+ Pass if the input is continuous. The number of channels in the input and output.
62
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
63
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
64
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
65
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
66
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
67
+ `ImagePositionalEmbeddings`.
68
+ num_vector_embeds (`int`, *optional*):
69
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
70
+ Includes the class for the masked latent pixel.
71
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
72
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
73
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
74
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
75
+ up to but not more than steps than `num_embeds_ada_norm`.
76
+ attention_bias (`bool`, *optional*):
77
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
78
+ """
79
+
80
+ @register_to_config
81
+ def __init__(
82
+ self,
83
+ num_attention_heads: int = 16,
84
+ attention_head_dim: int = 88,
85
+ in_channels: Optional[int] = None,
86
+ num_layers: int = 1,
87
+ dropout: float = 0.0,
88
+ norm_num_groups: int = 32,
89
+ cross_attention_dim: Optional[int] = None,
90
+ attention_bias: bool = False,
91
+ sample_size: Optional[int] = None,
92
+ num_vector_embeds: Optional[int] = None,
93
+ activation_fn: str = "geglu",
94
+ num_embeds_ada_norm: Optional[int] = None,
95
+ use_linear_projection: bool = False,
96
+ only_cross_attention: bool = False,
97
+ upcast_attention: bool = False,
98
+ ):
99
+ super().__init__()
100
+ self.use_linear_projection = use_linear_projection
101
+ self.num_attention_heads = num_attention_heads
102
+ self.attention_head_dim = attention_head_dim
103
+ self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
104
+
105
+ # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
106
+ # Define whether input is continuous or discrete depending on configuration
107
+ self.is_input_continuous = in_channels is not None
108
+ self.is_input_vectorized = num_vector_embeds is not None
109
+
110
+ if self.is_input_continuous and self.is_input_vectorized:
111
+ raise ValueError(
112
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
113
+ " sure that either `in_channels` or `num_vector_embeds` is None."
114
+ )
115
+ elif not self.is_input_continuous and not self.is_input_vectorized:
116
+ raise ValueError(
117
+ f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
118
+ " sure that either `in_channels` or `num_vector_embeds` is not None."
119
+ )
120
+
121
+ # 2. Define input layers
122
+ if self.is_input_continuous:
123
+ self.in_channels = in_channels
124
+
125
+ self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, epsilon=1e-6)
126
+ if use_linear_projection:
127
+ self.proj_in = nn.Linear(in_channels, inner_dim)
128
+ else:
129
+ self.proj_in = nn.Conv2D(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
130
+ elif self.is_input_vectorized:
131
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
132
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
133
+
134
+ self.height = sample_size
135
+ self.width = sample_size
136
+ self.num_vector_embeds = num_vector_embeds
137
+ self.num_latent_pixels = self.height * self.width
138
+
139
+ self.latent_image_embedding = ImagePositionalEmbeddings(
140
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
141
+ )
142
+
143
+ # 3. Define transformers blocks
144
+ self.transformer_blocks = nn.LayerList(
145
+ [
146
+ BasicTransformerBlock(
147
+ inner_dim,
148
+ num_attention_heads,
149
+ attention_head_dim,
150
+ dropout=dropout,
151
+ cross_attention_dim=cross_attention_dim,
152
+ activation_fn=activation_fn,
153
+ num_embeds_ada_norm=num_embeds_ada_norm,
154
+ attention_bias=attention_bias,
155
+ only_cross_attention=only_cross_attention,
156
+ upcast_attention=upcast_attention,
157
+ )
158
+ for d in range(num_layers)
159
+ ]
160
+ )
161
+
162
+ # 4. Define output layers
163
+ if self.is_input_continuous:
164
+ if use_linear_projection:
165
+ self.proj_out = nn.Linear(in_channels, inner_dim)
166
+ else:
167
+ self.proj_out = nn.Conv2D(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
168
+ elif self.is_input_vectorized:
169
+ self.norm_out = nn.LayerNorm(inner_dim)
170
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
171
+
172
+ def forward(
173
+ self,
174
+ hidden_states,
175
+ encoder_hidden_states=None,
176
+ timestep=None,
177
+ cross_attention_kwargs=None,
178
+ return_dict: bool = True,
179
+ ):
180
+ """
181
+ Args:
182
+ hidden_states ( When discrete, `paddle.Tensor` of shape `(batch size, num latent pixels)`.
183
+ When continous, `paddle.Tensor` of shape `(batch size, channel, height, width)`): Input
184
+ hidden_states
185
+ encoder_hidden_states ( `paddle.Tensor` of shape `(batch size, encoder_hidden_states)`, *optional*):
186
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
187
+ self-attention.
188
+ timestep ( `paddle.Tensor`, *optional*):
189
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
190
+ return_dict (`bool`, *optional*, defaults to `True`):
191
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
192
+
193
+ Returns:
194
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
195
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
196
+ tensor.
197
+ """
198
+ # 1. Input
199
+ if self.is_input_continuous:
200
+ _, _, height, width = hidden_states.shape
201
+ residual = hidden_states
202
+ hidden_states = self.norm(hidden_states)
203
+ if not self.use_linear_projection:
204
+ hidden_states = self.proj_in(hidden_states)
205
+ hidden_states = hidden_states.transpose([0, 2, 3, 1]).flatten(1, 2)
206
+ if self.use_linear_projection:
207
+ hidden_states = self.proj_in(hidden_states)
208
+ elif self.is_input_vectorized:
209
+ hidden_states = self.latent_image_embedding(hidden_states.cast("int64"))
210
+
211
+ # 2. Blocks
212
+ for block in self.transformer_blocks:
213
+ hidden_states = block(
214
+ hidden_states,
215
+ encoder_hidden_states=encoder_hidden_states,
216
+ timestep=timestep,
217
+ cross_attention_kwargs=cross_attention_kwargs,
218
+ )
219
+
220
+ # 3. Output
221
+ if self.is_input_continuous:
222
+ if self.use_linear_projection:
223
+ hidden_states = self.proj_out(hidden_states)
224
+ hidden_states = hidden_states.reshape([-1, height, width, self.inner_dim]).transpose([0, 3, 1, 2])
225
+ if not self.use_linear_projection:
226
+ hidden_states = self.proj_out(hidden_states)
227
+ output = hidden_states + residual
228
+ elif self.is_input_vectorized:
229
+ hidden_states = self.norm_out(hidden_states)
230
+ logits = self.out(hidden_states)
231
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
232
+ logits = logits.transpose([0, 2, 1])
233
+
234
+ # log(p(x_0))
235
+ output = F.log_softmax(logits.cast("float64"), axis=1).cast("float32")
236
+
237
+ if not return_dict:
238
+ return (output,)
239
+
240
+ return Transformer2DModelOutput(sample=output)
241
+
242
+
243
+ class AttentionBlock(nn.Layer):
244
+ """
245
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
246
+ to the N-d case.
247
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
248
+ Uses three q, k, v linear layers to compute attention.
249
+
250
+ Parameters:
251
+ channels (`int`): The number of channels in the input and output.
252
+ num_head_channels (`int`, *optional*):
253
+ The number of channels in each head. If None, then `num_heads` = 1.
254
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
255
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
256
+ eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
257
+ """
258
+
259
+ def __init__(
260
+ self,
261
+ channels: int,
262
+ num_head_channels: Optional[int] = None,
263
+ norm_num_groups: int = 32,
264
+ rescale_output_factor: float = 1.0,
265
+ eps: float = 1e-5,
266
+ ):
267
+ super().__init__()
268
+ self.channels = channels
269
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
270
+ self.head_dim = self.channels // self.num_heads
271
+ self.scale = 1 / math.sqrt(self.channels / self.num_heads)
272
+
273
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, epsilon=eps)
274
+
275
+ # define q,k,v as linear layers
276
+ self.query = nn.Linear(channels, channels)
277
+ self.key = nn.Linear(channels, channels)
278
+ self.value = nn.Linear(channels, channels)
279
+
280
+ self.rescale_output_factor = rescale_output_factor
281
+ self.proj_attn = nn.Linear(channels, channels)
282
+
283
+ def reshape_heads_to_batch_dim(self, tensor):
284
+ tensor = tensor.reshape([0, 0, self.num_heads, self.head_dim])
285
+ tensor = tensor.transpose([0, 2, 1, 3])
286
+ return tensor
287
+
288
+ def reshape_batch_dim_to_heads(self, tensor):
289
+ tensor = tensor.transpose([0, 2, 1, 3])
290
+ tensor = tensor.reshape([0, 0, tensor.shape[2] * tensor.shape[3]])
291
+ return tensor
292
+
293
+ def forward(self, hidden_states):
294
+ residual = hidden_states
295
+ batch, channel, height, width = hidden_states.shape
296
+
297
+ # norm
298
+ hidden_states = self.group_norm(hidden_states)
299
+
300
+ hidden_states = hidden_states.reshape([batch, channel, height * width]).transpose([0, 2, 1])
301
+
302
+ # proj to q, k, v
303
+ query_proj = self.query(hidden_states)
304
+ key_proj = self.key(hidden_states)
305
+ value_proj = self.value(hidden_states)
306
+
307
+ query_proj = self.reshape_heads_to_batch_dim(query_proj)
308
+ key_proj = self.reshape_heads_to_batch_dim(key_proj)
309
+ value_proj = self.reshape_heads_to_batch_dim(value_proj)
310
+
311
+ # get scores
312
+ attention_scores = paddle.matmul(query_proj, key_proj, transpose_y=True) * self.scale
313
+ attention_probs = F.softmax(attention_scores.cast("float32"), axis=-1).cast(attention_scores.dtype)
314
+
315
+ # compute attention output
316
+ hidden_states = paddle.matmul(attention_probs, value_proj)
317
+
318
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
319
+
320
+ # compute next hidden_states
321
+ hidden_states = self.proj_attn(hidden_states)
322
+ hidden_states = hidden_states.transpose([0, 2, 1]).reshape([batch, channel, height, width])
323
+
324
+ # res connect and rescale
325
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
326
+ return hidden_states
327
+
328
+
329
+ class BasicTransformerBlock(nn.Layer):
330
+ r"""
331
+ A basic Transformer block.
332
+
333
+ Parameters:
334
+ dim (`int`): The number of channels in the input and output.
335
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
336
+ attention_head_dim (`int`): The number of channels in each head.
337
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
338
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
339
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
340
+ num_embeds_ada_norm (:
341
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
342
+ attention_bias (:
343
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
344
+ """
345
+
346
+ def __init__(
347
+ self,
348
+ dim: int,
349
+ num_attention_heads: int,
350
+ attention_head_dim: int,
351
+ dropout=0.0,
352
+ cross_attention_dim: Optional[int] = None,
353
+ activation_fn: str = "geglu",
354
+ num_embeds_ada_norm: Optional[int] = None,
355
+ attention_bias: bool = False,
356
+ only_cross_attention: bool = False,
357
+ upcast_attention: bool = False,
358
+ ):
359
+ super().__init__()
360
+ self.only_cross_attention = only_cross_attention
361
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
362
+
363
+ # 1. Self-Attn
364
+ self.attn1 = CrossAttention(
365
+ query_dim=dim,
366
+ heads=num_attention_heads,
367
+ dim_head=attention_head_dim,
368
+ dropout=dropout,
369
+ bias=attention_bias,
370
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
371
+ upcast_attention=upcast_attention,
372
+ )
373
+
374
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
375
+
376
+ # 2. Cross-Attn
377
+ if cross_attention_dim is not None:
378
+ self.attn2 = CrossAttention(
379
+ query_dim=dim,
380
+ cross_attention_dim=cross_attention_dim,
381
+ heads=num_attention_heads,
382
+ dim_head=attention_head_dim,
383
+ dropout=dropout,
384
+ bias=attention_bias,
385
+ upcast_attention=upcast_attention,
386
+ ) # is self-attn if encoder_hidden_states is none
387
+ else:
388
+ self.attn2 = None
389
+
390
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
391
+
392
+ if cross_attention_dim is not None:
393
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
394
+ else:
395
+ self.norm2 = None
396
+
397
+ # 3. Feed-forward
398
+ self.norm3 = nn.LayerNorm(dim)
399
+
400
+ def forward(
401
+ self,
402
+ hidden_states,
403
+ encoder_hidden_states=None,
404
+ timestep=None,
405
+ attention_mask=None,
406
+ cross_attention_kwargs=None,
407
+ ):
408
+ # 1. Self-Attention
409
+ norm_hidden_states = (
410
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
411
+ )
412
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
413
+ attn_output = self.attn1(
414
+ norm_hidden_states,
415
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
416
+ attention_mask=attention_mask,
417
+ **cross_attention_kwargs,
418
+ )
419
+ hidden_states = attn_output + hidden_states
420
+
421
+ if self.attn2 is not None:
422
+ # 2. Cross-Attention
423
+ norm_hidden_states = (
424
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
425
+ )
426
+ attn_output = self.attn2(
427
+ norm_hidden_states,
428
+ encoder_hidden_states=encoder_hidden_states,
429
+ attention_mask=attention_mask,
430
+ **cross_attention_kwargs,
431
+ )
432
+ hidden_states = attn_output + hidden_states
433
+
434
+ # 3. Feed-forward
435
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
436
+
437
+ return hidden_states
438
+
439
+
440
+ class FeedForward(nn.Layer):
441
+ r"""
442
+ A feed-forward layer.
443
+
444
+ Parameters:
445
+ dim (`int`): The number of channels in the input.
446
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
447
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
448
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
449
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
450
+ """
451
+
452
+ def __init__(
453
+ self,
454
+ dim: int,
455
+ dim_out: Optional[int] = None,
456
+ mult: int = 4,
457
+ dropout: float = 0.0,
458
+ activation_fn: str = "geglu",
459
+ ):
460
+ super().__init__()
461
+ inner_dim = int(dim * mult)
462
+ dim_out = dim_out if dim_out is not None else dim
463
+
464
+ if activation_fn == "gelu":
465
+ act_fn = GELU(dim, inner_dim)
466
+ elif activation_fn == "geglu":
467
+ act_fn = GEGLU(dim, inner_dim)
468
+ elif activation_fn == "geglu-approximate":
469
+ act_fn = ApproximateGELU(dim, inner_dim)
470
+
471
+ self.net = nn.LayerList([])
472
+ # project in
473
+ self.net.append(act_fn)
474
+ # project dropout
475
+ self.net.append(nn.Dropout(dropout))
476
+ # project out
477
+ self.net.append(nn.Linear(inner_dim, dim_out))
478
+
479
+ def forward(self, hidden_states):
480
+ for module in self.net:
481
+ hidden_states = module(hidden_states)
482
+ return hidden_states
483
+
484
+
485
+ class GELU(nn.Layer):
486
+ r"""
487
+ GELU activation function
488
+ """
489
+
490
+ def __init__(self, dim_in: int, dim_out: int):
491
+ super().__init__()
492
+ self.proj = nn.Linear(dim_in, dim_out)
493
+
494
+ def forward(self, hidden_states):
495
+ hidden_states = self.proj(hidden_states)
496
+ hidden_states = F.gelu(hidden_states)
497
+ return hidden_states
498
+
499
+
500
+ # feedforward
501
+ class GEGLU(nn.Layer):
502
+ r"""
503
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
504
+
505
+ Parameters:
506
+ dim_in (`int`): The number of channels in the input.
507
+ dim_out (`int`): The number of channels in the output.
508
+ """
509
+
510
+ def __init__(self, dim_in: int, dim_out: int):
511
+ super().__init__()
512
+ self.proj = nn.Linear(dim_in, dim_out * 2)
513
+
514
+ def forward(self, hidden_states):
515
+ hidden_states, gate = self.proj(hidden_states).chunk(2, axis=-1)
516
+ return hidden_states * F.gelu(gate)
517
+
518
+
519
+ class ApproximateGELU(nn.Layer):
520
+ """
521
+ The approximate form of Gaussian Error Linear Unit (GELU)
522
+
523
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
524
+ """
525
+
526
+ def __init__(self, dim_in: int, dim_out: int):
527
+ super().__init__()
528
+ self.proj = nn.Linear(dim_in, dim_out)
529
+
530
+ def forward(self, x):
531
+ x = self.proj(x)
532
+ return x * F.sigmoid(1.702 * x)
533
+
534
+
535
+ class AdaLayerNorm(nn.Layer):
536
+ """
537
+ Norm layer modified to incorporate timestep embeddings.
538
+ """
539
+
540
+ def __init__(self, embedding_dim, num_embeddings):
541
+ super().__init__()
542
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
543
+ self.silu = nn.Silu()
544
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
545
+ self.norm = nn.LayerNorm(embedding_dim) # elementwise_affine=False
546
+
547
+ def forward(self, x, timestep):
548
+ emb = self.linear(self.silu(self.emb(timestep)))
549
+ scale, shift = paddle.chunk(emb, 2, axis=-1)
550
+ x = self.norm(x) * (1 + scale) + shift
551
+ return x
552
+
553
+
554
+ class DualTransformer2DModel(nn.Layer):
555
+ """
556
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
557
+ Parameters:
558
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
559
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
560
+ in_channels (`int`, *optional*):
561
+ Pass if the input is continuous. The number of channels in the input and output.
562
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
563
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
564
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
565
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
566
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
567
+ `ImagePositionalEmbeddings`.
568
+ num_vector_embeds (`int`, *optional*):
569
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
570
+ Includes the class for the masked latent pixel.
571
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
572
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
573
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
574
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
575
+ up to but not more than steps than `num_embeds_ada_norm`.
576
+ attention_bias (`bool`, *optional*):
577
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
578
+ """
579
+
580
+ def __init__(
581
+ self,
582
+ num_attention_heads: int = 16,
583
+ attention_head_dim: int = 88,
584
+ in_channels: Optional[int] = None,
585
+ num_layers: int = 1,
586
+ dropout: float = 0.0,
587
+ norm_num_groups: int = 32,
588
+ cross_attention_dim: Optional[int] = None,
589
+ attention_bias: bool = False,
590
+ sample_size: Optional[int] = None,
591
+ num_vector_embeds: Optional[int] = None,
592
+ activation_fn: str = "geglu",
593
+ num_embeds_ada_norm: Optional[int] = None,
594
+ ):
595
+ super().__init__()
596
+ self.transformers = nn.LayerList(
597
+ [
598
+ Transformer2DModel(
599
+ num_attention_heads=num_attention_heads,
600
+ attention_head_dim=attention_head_dim,
601
+ in_channels=in_channels,
602
+ num_layers=num_layers,
603
+ dropout=dropout,
604
+ norm_num_groups=norm_num_groups,
605
+ cross_attention_dim=cross_attention_dim,
606
+ attention_bias=attention_bias,
607
+ sample_size=sample_size,
608
+ num_vector_embeds=num_vector_embeds,
609
+ activation_fn=activation_fn,
610
+ num_embeds_ada_norm=num_embeds_ada_norm,
611
+ )
612
+ for _ in range(2)
613
+ ]
614
+ )
615
+
616
+ # Variables that can be set by a pipeline:
617
+
618
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
619
+ self.mix_ratio = 0.5
620
+
621
+ # The shape of `encoder_hidden_states` is expected to be
622
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
623
+ self.condition_lengths = [77, 257]
624
+
625
+ # Which transformer to use to encode which condition.
626
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
627
+ self.transformer_index_for_condition = [1, 0]
628
+
629
+ def forward(
630
+ self,
631
+ hidden_states,
632
+ encoder_hidden_states,
633
+ timestep=None,
634
+ attention_mask=None,
635
+ cross_attention_kwargs=None,
636
+ return_dict: bool = True,
637
+ ):
638
+ """
639
+ Args:
640
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
641
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
642
+ hidden_states
643
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
644
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
645
+ self-attention.
646
+ timestep ( `torch.long`, *optional*):
647
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
648
+ attention_mask (`torch.FloatTensor`, *optional*):
649
+ Optional attention mask to be applied in CrossAttention
650
+ return_dict (`bool`, *optional*, defaults to `True`):
651
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
652
+
653
+ Returns:
654
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
655
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
656
+ tensor.
657
+ """
658
+ input_states = hidden_states
659
+
660
+ encoded_states = []
661
+ tokens_start = 0
662
+ # attention_mask is not used yet
663
+ for i in range(2):
664
+ # for each of the two transformers, pass the corresponding condition tokens
665
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
666
+ transformer_index = self.transformer_index_for_condition[i]
667
+ encoded_state = self.transformers[transformer_index](
668
+ input_states,
669
+ encoder_hidden_states=condition_state,
670
+ timestep=timestep,
671
+ cross_attention_kwargs=cross_attention_kwargs,
672
+ return_dict=False,
673
+ )[0]
674
+ encoded_states.append(encoded_state - input_states)
675
+ tokens_start += self.condition_lengths[i]
676
+
677
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
678
+ output_states = output_states + input_states
679
+
680
+ if not return_dict:
681
+ return (output_states,)
682
+
683
+ return Transformer2DModelOutput(sample=output_states)
ppdiffusers/models/cross_attention.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Union
15
+
16
+ import paddle
17
+ import paddle.nn as nn
18
+ import paddle.nn.functional as F
19
+
20
+ from ..initializer import normal_, zeros_
21
+
22
+
23
+ class CrossAttention(nn.Layer):
24
+ r"""
25
+ A cross attention layer.
26
+
27
+ Parameters:
28
+ query_dim (`int`): The number of channels in the query.
29
+ cross_attention_dim (`int`, *optional*):
30
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
31
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
32
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
33
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
34
+ bias (`bool`, *optional*, defaults to False):
35
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ query_dim: int,
41
+ cross_attention_dim: Optional[int] = None,
42
+ heads: int = 8,
43
+ dim_head: int = 64,
44
+ dropout: float = 0.0,
45
+ bias=False,
46
+ upcast_attention: bool = False,
47
+ upcast_softmax: bool = False,
48
+ added_kv_proj_dim: Optional[int] = None,
49
+ norm_num_groups: Optional[int] = None,
50
+ processor: Optional["AttnProcessor"] = None,
51
+ ):
52
+ super().__init__()
53
+ inner_dim = dim_head * heads
54
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
55
+ self.upcast_attention = upcast_attention
56
+ self.upcast_softmax = upcast_softmax
57
+
58
+ self.scale = dim_head**-0.5
59
+ self.num_heads = heads
60
+ self.head_dim = inner_dim // heads
61
+ # for slice_size > 0 the attention score computation
62
+ # is split across the batch axis to save memory
63
+ # You can set slice_size with `set_attention_slice`
64
+ self.sliceable_head_dim = heads
65
+
66
+ self.added_kv_proj_dim = added_kv_proj_dim
67
+
68
+ if norm_num_groups is not None:
69
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, epsilon=1e-5)
70
+ else:
71
+ self.group_norm = None
72
+
73
+ self.to_q = nn.Linear(query_dim, inner_dim, bias_attr=bias)
74
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)
75
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)
76
+
77
+ if self.added_kv_proj_dim is not None:
78
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
79
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
80
+
81
+ self.to_out = nn.LayerList([])
82
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
83
+ self.to_out.append(nn.Dropout(dropout))
84
+
85
+ # set attention processor
86
+ processor = processor if processor is not None else CrossAttnProcessor()
87
+ self.set_processor(processor)
88
+
89
+ def set_attention_slice(self, slice_size):
90
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
91
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
92
+
93
+ if slice_size is not None and self.added_kv_proj_dim is not None:
94
+ processor = SlicedAttnAddedKVProcessor(slice_size)
95
+ elif slice_size is not None:
96
+ processor = SlicedAttnProcessor(slice_size)
97
+ elif self.added_kv_proj_dim is not None:
98
+ processor = CrossAttnAddedKVProcessor()
99
+ else:
100
+ processor = CrossAttnProcessor()
101
+
102
+ self.set_processor(processor)
103
+
104
+ def set_processor(self, processor: "AttnProcessor"):
105
+ self.processor = processor
106
+
107
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
108
+ # The `CrossAttention` class can call different attention processors / attention functions
109
+ # here we simply pass along all tensors to the selected processor class
110
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
111
+ return self.processor(
112
+ self,
113
+ hidden_states,
114
+ encoder_hidden_states=encoder_hidden_states,
115
+ attention_mask=attention_mask,
116
+ **cross_attention_kwargs,
117
+ )
118
+
119
+ def batch_to_head_dim(self, tensor):
120
+ tensor = tensor.transpose([0, 2, 1, 3])
121
+ tensor = tensor.reshape([0, 0, tensor.shape[2] * tensor.shape[3]])
122
+ return tensor
123
+
124
+ def head_to_batch_dim(self, tensor):
125
+ tensor = tensor.reshape([0, 0, self.num_heads, self.head_dim])
126
+ tensor = tensor.transpose([0, 2, 1, 3])
127
+ return tensor
128
+
129
+ def get_attention_scores(self, query, key, attention_mask=None):
130
+ if self.upcast_attention:
131
+ query = query.cast("float32")
132
+ key = key.cast("float32")
133
+
134
+ attention_scores = paddle.matmul(query, key, transpose_y=True) * self.scale
135
+
136
+ if attention_mask is not None:
137
+ attention_scores = attention_scores + attention_mask
138
+
139
+ if self.upcast_softmax:
140
+ attention_scores = attention_scores.cast("float32")
141
+
142
+ attention_probs = F.softmax(attention_scores, axis=-1)
143
+ if self.upcast_softmax:
144
+ attention_probs = attention_probs.cast(query.dtype)
145
+
146
+ return attention_probs
147
+
148
+ def prepare_attention_mask(self, attention_mask, target_length):
149
+ if attention_mask is None:
150
+ return attention_mask
151
+
152
+ if attention_mask.shape[-1] != target_length:
153
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0, data_format="NCL")
154
+ attention_mask = attention_mask.repeat_interleave(self.num_heads, axis=0)
155
+ return attention_mask
156
+
157
+
158
+ class CrossAttnProcessor:
159
+ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
160
+ batch_size, sequence_length, _ = hidden_states.shape
161
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
162
+ attention_mask = (
163
+ attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
164
+ if attention_mask is not None
165
+ else None
166
+ )
167
+
168
+ query = attn.to_q(hidden_states)
169
+ query = attn.head_to_batch_dim(query)
170
+
171
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
172
+ key = attn.to_k(encoder_hidden_states)
173
+ value = attn.to_v(encoder_hidden_states)
174
+ key = attn.head_to_batch_dim(key)
175
+ value = attn.head_to_batch_dim(value)
176
+
177
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
178
+ hidden_states = paddle.matmul(attention_probs, value)
179
+ hidden_states = attn.batch_to_head_dim(hidden_states)
180
+
181
+ # linear proj
182
+ hidden_states = attn.to_out[0](hidden_states)
183
+ # dropout
184
+ hidden_states = attn.to_out[1](hidden_states)
185
+
186
+ return hidden_states
187
+
188
+
189
+ class LoRALinearLayer(nn.Layer):
190
+ def __init__(self, in_features, out_features, rank=4):
191
+ super().__init__()
192
+
193
+ if rank > min(in_features, out_features):
194
+ raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
195
+
196
+ self.down = nn.Linear(in_features, rank, bias_attr=False)
197
+ self.up = nn.Linear(rank, out_features, bias_attr=False)
198
+ self.scale = 1.0
199
+
200
+ normal_(self.down.weight, std=1 / rank)
201
+ zeros_(self.up.weight)
202
+
203
+ def forward(self, hidden_states):
204
+ orig_dtype = hidden_states.dtype
205
+ dtype = self.down.weight.dtype
206
+
207
+ down_hidden_states = self.down(hidden_states.cast(dtype))
208
+ up_hidden_states = self.up(down_hidden_states)
209
+
210
+ return up_hidden_states.cast(orig_dtype)
211
+
212
+
213
+ class LoRACrossAttnProcessor(nn.Layer):
214
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
215
+ super().__init__()
216
+
217
+ self.hidden_size = hidden_size
218
+ self.cross_attention_dim = cross_attention_dim
219
+ self.rank = rank
220
+
221
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
222
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
223
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
224
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
225
+
226
+ def __call__(
227
+ self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
228
+ ):
229
+ batch_size, sequence_length, _ = hidden_states.shape
230
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
231
+ attention_mask = (
232
+ attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
233
+ if attention_mask is not None
234
+ else None
235
+ )
236
+
237
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
238
+ query = attn.head_to_batch_dim(query)
239
+
240
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
241
+
242
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
243
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
244
+
245
+ key = attn.head_to_batch_dim(key)
246
+ value = attn.head_to_batch_dim(value)
247
+
248
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
249
+ hidden_states = paddle.matmul(attention_probs, value)
250
+ hidden_states = attn.batch_to_head_dim(hidden_states)
251
+
252
+ # linear proj
253
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
254
+ # dropout
255
+ hidden_states = attn.to_out[1](hidden_states)
256
+
257
+ return hidden_states
258
+
259
+
260
+ class CrossAttnAddedKVProcessor:
261
+ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
262
+ residual = hidden_states
263
+ hidden_states = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1]).transpose(
264
+ [0, 2, 1]
265
+ )
266
+ batch_size, sequence_length, _ = hidden_states.shape
267
+ encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])
268
+
269
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
270
+ attention_mask = (
271
+ attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
272
+ if attention_mask is not None
273
+ else None
274
+ )
275
+
276
+ hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])
277
+
278
+ query = attn.to_q(hidden_states)
279
+ query = attn.head_to_batch_dim(query)
280
+
281
+ key = attn.to_k(hidden_states)
282
+ value = attn.to_v(hidden_states)
283
+ key = attn.head_to_batch_dim(key)
284
+ value = attn.head_to_batch_dim(value)
285
+
286
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
287
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
288
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
289
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
290
+
291
+ key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2)
292
+ value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2)
293
+
294
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
295
+ hidden_states = paddle.matmul(attention_probs, value)
296
+ hidden_states = attn.batch_to_head_dim(hidden_states)
297
+
298
+ # linear proj
299
+ hidden_states = attn.to_out[0](hidden_states)
300
+ # dropout
301
+ hidden_states = attn.to_out[1](hidden_states)
302
+
303
+ hidden_states = hidden_states.transpose([0, 2, 1]).reshape(residual.shape)
304
+ hidden_states = hidden_states + residual
305
+
306
+ return hidden_states
307
+
308
+
309
+ class SlicedAttnProcessor:
310
+ def __init__(self, slice_size):
311
+ self.slice_size = slice_size
312
+
313
+ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
314
+ batch_size, sequence_length, _ = hidden_states.shape
315
+
316
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
317
+
318
+ query = attn.to_q(hidden_states)
319
+ query = attn.head_to_batch_dim(query)
320
+
321
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
322
+ key = attn.to_k(encoder_hidden_states)
323
+ value = attn.to_v(encoder_hidden_states)
324
+ key = attn.head_to_batch_dim(key)
325
+ value = attn.head_to_batch_dim(value)
326
+
327
+ query = query.flatten(0, 1)
328
+ key = key.flatten(0, 1)
329
+ value = value.flatten(0, 1)
330
+
331
+ batch_size_attention = query.shape[0]
332
+ hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)
333
+
334
+ for i in range(hidden_states.shape[0] // self.slice_size):
335
+ start_idx = i * self.slice_size
336
+ end_idx = (i + 1) * self.slice_size
337
+
338
+ query_slice = query[start_idx:end_idx]
339
+ key_slice = key[start_idx:end_idx]
340
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
341
+
342
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
343
+
344
+ attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])
345
+
346
+ hidden_states[start_idx:end_idx] = attn_slice
347
+
348
+ # reshape back to [bs, num_heads, seqlen, head_dim]
349
+ hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
350
+ # reshape hidden_states
351
+ hidden_states = attn.batch_to_head_dim(hidden_states)
352
+
353
+ # linear proj
354
+ hidden_states = attn.to_out[0](hidden_states)
355
+ # dropout
356
+ hidden_states = attn.to_out[1](hidden_states)
357
+
358
+ return hidden_states
359
+
360
+
361
+ class SlicedAttnAddedKVProcessor:
362
+ def __init__(self, slice_size):
363
+ self.slice_size = slice_size
364
+
365
+ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
366
+ residual = hidden_states
367
+ hidden_states = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1]).transpose(
368
+ [0, 2, 1]
369
+ )
370
+ encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])
371
+
372
+ batch_size, sequence_length, _ = hidden_states.shape
373
+
374
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
375
+
376
+ hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])
377
+
378
+ query = attn.to_q(hidden_states)
379
+ query = attn.head_to_batch_dim(query)
380
+
381
+ key = attn.to_k(hidden_states)
382
+ value = attn.to_v(hidden_states)
383
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
384
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
385
+
386
+ key = attn.head_to_batch_dim(key)
387
+ value = attn.head_to_batch_dim(value)
388
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
389
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
390
+
391
+ key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2)
392
+ value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2)
393
+
394
+ query = query.flatten(0, 1)
395
+ key = key.flatten(0, 1)
396
+ value = value.flatten(0, 1)
397
+
398
+ batch_size_attention = query.shape[0]
399
+ hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)
400
+ for i in range(hidden_states.shape[0] // self.slice_size):
401
+ start_idx = i * self.slice_size
402
+ end_idx = (i + 1) * self.slice_size
403
+
404
+ query_slice = query[start_idx:end_idx]
405
+ key_slice = key[start_idx:end_idx]
406
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
407
+
408
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
409
+
410
+ attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])
411
+
412
+ hidden_states[start_idx:end_idx] = attn_slice
413
+
414
+ # reshape back to [bs, num_heads, seqlen, head_dim]
415
+ hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
416
+ # reshape hidden_states
417
+ hidden_states = attn.batch_to_head_dim(hidden_states)
418
+
419
+ # linear proj
420
+ hidden_states = attn.to_out[0](hidden_states)
421
+ # dropout
422
+ hidden_states = attn.to_out[1](hidden_states)
423
+
424
+ hidden_states = hidden_states.transpose([0, 2, 1]).reshape(residual.shape)
425
+ hidden_states = hidden_states + residual
426
+
427
+ return hidden_states
428
+
429
+
430
+ AttnProcessor = Union[
431
+ CrossAttnProcessor,
432
+ SlicedAttnProcessor,
433
+ CrossAttnAddedKVProcessor,
434
+ SlicedAttnAddedKVProcessor,
435
+ ]
ppdiffusers/models/ema.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import paddle
17
+ from paddle import nn
18
+
19
+
20
+ class LitEma(nn.Layer):
21
+ """
22
+ Exponential Moving Average (EMA) of model updates
23
+
24
+ Parameters:
25
+ model: The model architecture for apply EMA.
26
+ decay: The exponential decay. Default 0.9999.
27
+ use_num_updates: Whether to use number of updates when computing
28
+ averages.
29
+ """
30
+
31
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
32
+ super().__init__()
33
+ if decay < 0.0 or decay > 1.0:
34
+ raise ValueError("Decay must be between 0 and 1")
35
+
36
+ self.m_name2s_name = {}
37
+ self.register_buffer("decay", paddle.to_tensor(decay, dtype=paddle.float32))
38
+ self.register_buffer(
39
+ "num_updates",
40
+ paddle.to_tensor(0, dtype=paddle.int64) if use_num_upates else paddle.to_tensor(-1, dtype=paddle.int64),
41
+ )
42
+
43
+ for name, p in model.named_parameters():
44
+ if not p.stop_gradient:
45
+ # remove as '.'-character is not allowed in buffers
46
+ s_name = name.replace(".", "")
47
+ self.m_name2s_name.update({name: s_name})
48
+ self.register_buffer(s_name, p.clone().detach())
49
+
50
+ self.collected_params = []
51
+
52
+ def forward(self, model):
53
+ decay = self.decay
54
+
55
+ if self.num_updates >= 0:
56
+ self.num_updates += 1
57
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
58
+
59
+ one_minus_decay = 1.0 - decay
60
+
61
+ with paddle.no_grad():
62
+ m_param = dict(model.named_parameters())
63
+ shadow_params = dict(self.named_buffers())
64
+
65
+ for key in m_param:
66
+ if not m_param[key].stop_gradient:
67
+ sname = self.m_name2s_name[key]
68
+ shadow_params[sname].scale_(decay)
69
+ shadow_params[sname].add_(m_param[key] * one_minus_decay)
70
+ else:
71
+ assert key not in self.m_name2s_name
72
+
73
+ def copy_to(self, model):
74
+ m_param = dict(model.named_parameters())
75
+ shadow_params = dict(self.named_buffers())
76
+ for key in m_param:
77
+ if not m_param[key].stop_gradient:
78
+ m_param[key].copy_(shadow_params[self.m_name2s_name[key]], True)
79
+ else:
80
+ assert key not in self.m_name2s_name
81
+
82
+ def store(self, parameters):
83
+ """
84
+ Save the current parameters for restoring later.
85
+ Args:
86
+ parameters: Iterable of `EagerParamBase`; the parameters to be
87
+ temporarily stored.
88
+ """
89
+ self.collected_params = [param.clone() for param in parameters]
90
+
91
+ def restore(self, parameters):
92
+ """
93
+ Restore the parameters stored with the `store` method.
94
+ Useful to validate the model with EMA parameters without affecting the
95
+ original optimization process. Store the parameters before the
96
+ `copy_to` method. After validation (or model saving), use this to
97
+ restore the former parameters.
98
+ Args:
99
+ parameters: Iterable of `EagerParamBase`; the parameters to be
100
+ updated with the stored parameters.
101
+ """
102
+ for c_param, param in zip(self.collected_params, parameters):
103
+ param.copy_(c_param, True)
ppdiffusers/models/embeddings.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import math
16
+
17
+ import numpy as np
18
+ import paddle
19
+ from paddle import nn
20
+
21
+
22
+ def get_timestep_embedding(
23
+ timesteps: paddle.Tensor,
24
+ embedding_dim: int,
25
+ flip_sin_to_cos: bool = False,
26
+ downscale_freq_shift: float = 1,
27
+ scale: float = 1,
28
+ max_period: int = 10000,
29
+ ):
30
+ """
31
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
32
+
33
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
34
+ These may be fractional.
35
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
36
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
37
+ """
38
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
39
+
40
+ half_dim = embedding_dim // 2
41
+ exponent = -math.log(max_period) * paddle.arange(start=0, end=half_dim, dtype="float32")
42
+ exponent = exponent / (half_dim - downscale_freq_shift)
43
+
44
+ emb = paddle.exp(exponent)
45
+ emb = timesteps[:, None].cast("float32") * emb[None, :]
46
+
47
+ # scale embeddings
48
+ emb = scale * emb
49
+
50
+ # concat sine and cosine embeddings
51
+ emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
52
+
53
+ # flip sine and cosine embeddings
54
+ if flip_sin_to_cos:
55
+ emb = paddle.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
56
+
57
+ # zero pad
58
+ if embedding_dim % 2 == 1:
59
+ emb = paddle.concat(emb, paddle.zeros([emb.shape[0], 1]), axis=-1)
60
+ return emb
61
+
62
+
63
+ class TimestepEmbedding(nn.Layer):
64
+ def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
65
+ super().__init__()
66
+
67
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
68
+ self.act = None
69
+ if act_fn == "silu":
70
+ self.act = nn.Silu()
71
+ elif act_fn == "mish":
72
+ self.act = nn.Mish()
73
+
74
+ if out_dim is not None:
75
+ time_embed_dim_out = out_dim
76
+ else:
77
+ time_embed_dim_out = time_embed_dim
78
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
79
+
80
+ def forward(self, sample):
81
+ sample = self.linear_1(sample)
82
+
83
+ if self.act is not None:
84
+ sample = self.act(sample)
85
+
86
+ sample = self.linear_2(sample)
87
+ return sample
88
+
89
+
90
+ class Timesteps(nn.Layer):
91
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
92
+ super().__init__()
93
+ self.num_channels = num_channels
94
+ self.flip_sin_to_cos = flip_sin_to_cos
95
+ self.downscale_freq_shift = downscale_freq_shift
96
+
97
+ def forward(self, timesteps):
98
+ t_emb = get_timestep_embedding(
99
+ timesteps,
100
+ self.num_channels,
101
+ flip_sin_to_cos=self.flip_sin_to_cos,
102
+ downscale_freq_shift=self.downscale_freq_shift,
103
+ )
104
+ return t_emb
105
+
106
+
107
+ class GaussianFourierProjection(nn.Layer):
108
+ """Gaussian Fourier embeddings for noise levels."""
109
+
110
+ def __init__(
111
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
112
+ ):
113
+ super().__init__()
114
+ self.register_buffer("weight", paddle.randn((embedding_size,)) * scale)
115
+ self.log = log
116
+ self.flip_sin_to_cos = flip_sin_to_cos
117
+
118
+ if set_W_to_weight:
119
+ # to delete later
120
+ self.register_buffer("W", paddle.randn((embedding_size,)) * scale)
121
+
122
+ self.weight = self.W
123
+
124
+ def forward(self, x):
125
+ if self.log:
126
+ x = paddle.log(x.cast(self.weight.dtype))
127
+
128
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
129
+
130
+ if self.flip_sin_to_cos:
131
+ out = paddle.concat([paddle.cos(x_proj), paddle.sin(x_proj)], axis=-1)
132
+ else:
133
+ out = paddle.concat([paddle.sin(x_proj), paddle.cos(x_proj)], axis=-1)
134
+ return out
135
+
136
+
137
+ class ImagePositionalEmbeddings(nn.Layer):
138
+ """
139
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
140
+ height and width of the latent space.
141
+
142
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
143
+
144
+ For VQ-diffusion:
145
+
146
+ Output vector embeddings are used as input for the transformer.
147
+
148
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
149
+
150
+ Args:
151
+ num_embed (`int`):
152
+ Number of embeddings for the latent pixels embeddings.
153
+ height (`int`):
154
+ Height of the latent image i.e. the number of height embeddings.
155
+ width (`int`):
156
+ Width of the latent image i.e. the number of width embeddings.
157
+ embed_dim (`int`):
158
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ num_embed: int,
164
+ height: int,
165
+ width: int,
166
+ embed_dim: int,
167
+ ):
168
+ super().__init__()
169
+
170
+ self.height = height
171
+ self.width = width
172
+ self.num_embed = num_embed
173
+ self.embed_dim = embed_dim
174
+
175
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
176
+ self.height_emb = nn.Embedding(self.height, embed_dim)
177
+ self.width_emb = nn.Embedding(self.width, embed_dim)
178
+
179
+ def forward(self, index):
180
+ emb = self.emb(index)
181
+
182
+ height_emb = self.height_emb(paddle.arange(self.height).reshape([1, self.height]))
183
+
184
+ # 1 x H x D -> 1 x H x 1 x D
185
+ height_emb = height_emb.unsqueeze(2)
186
+
187
+ width_emb = self.width_emb(paddle.arange(self.width).reshape([1, self.width]))
188
+
189
+ # 1 x W x D -> 1 x 1 x W x D
190
+ width_emb = width_emb.unsqueeze(1)
191
+
192
+ pos_emb = height_emb + width_emb
193
+
194
+ # 1 x H x W x D -> 1 x L xD
195
+ pos_emb = pos_emb.reshape([1, self.height * self.width, -1])
196
+
197
+ emb = emb + pos_emb[:, : emb.shape[1], :]
198
+
199
+ return emb
ppdiffusers/models/prior_transformer.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Union
17
+
18
+ import paddle
19
+ import paddle.nn as nn
20
+ import paddle.nn.functional as F
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from ..modeling_utils import ModelMixin
24
+ from ..utils import BaseOutput
25
+ from .attention import BasicTransformerBlock
26
+ from .embeddings import TimestepEmbedding, Timesteps
27
+
28
+ NEG_INF = -1e4
29
+
30
+
31
+ @dataclass
32
+ class PriorTransformerOutput(BaseOutput):
33
+ """
34
+ Args:
35
+ predicted_image_embedding (`paddle.Tensor` of shape `(batch_size, embedding_dim)`):
36
+ The predicted CLIP image embedding conditioned on the CLIP text embedding input.
37
+ """
38
+
39
+ predicted_image_embedding: paddle.Tensor
40
+
41
+
42
+ class PriorTransformer(ModelMixin, ConfigMixin):
43
+ """
44
+ The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
45
+ transformer predicts the image embeddings through a denoising diffusion process.
46
+
47
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
48
+ implements for all the models (such as downloading or saving, etc.)
49
+
50
+ For more details, see the original paper: https://arxiv.org/abs/2204.06125
51
+
52
+ Parameters:
53
+ num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
54
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
55
+ num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
56
+ embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
57
+ image embeddings and text embeddings are both the same dimension.
58
+ num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the
59
+ length of the prompt after it has been tokenized.
60
+ additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
61
+ projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
62
+ additional_embeddings`.
63
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
64
+
65
+ """
66
+
67
+ @register_to_config
68
+ def __init__(
69
+ self,
70
+ num_attention_heads: int = 32,
71
+ attention_head_dim: int = 64,
72
+ num_layers: int = 20,
73
+ embedding_dim: int = 768,
74
+ num_embeddings=77,
75
+ additional_embeddings=4,
76
+ dropout: float = 0.0,
77
+ ):
78
+ super().__init__()
79
+ self.num_attention_heads = num_attention_heads
80
+ self.attention_head_dim = attention_head_dim
81
+ inner_dim = num_attention_heads * attention_head_dim
82
+ self.additional_embeddings = additional_embeddings
83
+
84
+ self.time_proj = Timesteps(inner_dim, True, 0)
85
+ self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
86
+
87
+ self.proj_in = nn.Linear(embedding_dim, inner_dim)
88
+
89
+ self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
90
+ self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
91
+
92
+ self.positional_embedding = self.create_parameter(
93
+ (1, num_embeddings + additional_embeddings, inner_dim),
94
+ dtype=paddle.get_default_dtype(),
95
+ default_initializer=nn.initializer.Constant(0.0),
96
+ )
97
+
98
+ self.prd_embedding = self.create_parameter(
99
+ (1, 1, inner_dim), dtype=paddle.get_default_dtype(), default_initializer=nn.initializer.Constant(0.0)
100
+ )
101
+
102
+ self.transformer_blocks = nn.LayerList(
103
+ [
104
+ BasicTransformerBlock(
105
+ inner_dim,
106
+ num_attention_heads,
107
+ attention_head_dim,
108
+ dropout=dropout,
109
+ activation_fn="gelu",
110
+ attention_bias=True,
111
+ )
112
+ for d in range(num_layers)
113
+ ]
114
+ )
115
+
116
+ self.norm_out = nn.LayerNorm(inner_dim)
117
+ self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
118
+
119
+ causal_attention_mask = paddle.triu(
120
+ paddle.full([num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], NEG_INF), 1
121
+ )
122
+ causal_attention_mask = causal_attention_mask.unsqueeze(0)
123
+ self.register_buffer("causal_attention_mask", causal_attention_mask, persistable=False)
124
+
125
+ self.clip_mean = self.create_parameter(
126
+ (1, embedding_dim), dtype=paddle.get_default_dtype(), default_initializer=nn.initializer.Constant(0.0)
127
+ )
128
+ self.clip_std = self.create_parameter(
129
+ (1, embedding_dim), dtype=paddle.get_default_dtype(), default_initializer=nn.initializer.Constant(0.0)
130
+ )
131
+
132
+ def forward(
133
+ self,
134
+ hidden_states,
135
+ timestep: Union[paddle.Tensor, float, int],
136
+ proj_embedding: paddle.Tensor,
137
+ encoder_hidden_states: paddle.Tensor,
138
+ attention_mask: Optional[paddle.Tensor] = None,
139
+ return_dict: bool = True,
140
+ ):
141
+ """
142
+ Args:
143
+ hidden_states (`paddle.Tensor` of shape `(batch_size, embedding_dim)`):
144
+ x_t, the currently predicted image embeddings.
145
+ timestep (`paddle.Tensor`):
146
+ Current denoising step.
147
+ proj_embedding (`paddle.Tensor` of shape `(batch_size, embedding_dim)`):
148
+ Projected embedding vector the denoising process is conditioned on.
149
+ encoder_hidden_states (`paddle.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
150
+ Hidden states of the text embeddings the denoising process is conditioned on.
151
+ attention_mask (`paddle.Tensor` of shape `(batch_size, num_embeddings)`):
152
+ Text mask for the text embeddings.
153
+ return_dict (`bool`, *optional*, defaults to `True`):
154
+ Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
155
+ tuple.
156
+
157
+ Returns:
158
+ [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
159
+ [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
160
+ returning a tuple, the first element is the sample tensor.
161
+ """
162
+ batch_size = hidden_states.shape[0]
163
+
164
+ timesteps = timestep
165
+ if not paddle.is_tensor(timesteps):
166
+ timesteps = paddle.to_tensor([timesteps], dtype=paddle.int64)
167
+ elif paddle.is_tensor(timesteps) and len(timesteps.shape) == 0:
168
+ timesteps = timesteps[None]
169
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
170
+ timesteps = timesteps * paddle.ones((batch_size,), dtype=timesteps.dtype)
171
+
172
+ timesteps_projected = self.time_proj(timesteps)
173
+
174
+ # timesteps does not contain any weights and will always return f32 tensors
175
+ # but time_embedding might be fp16, so we need to cast here.
176
+ timesteps_projected = timesteps_projected.cast(dtype=self.dtype)
177
+ time_embeddings = self.time_embedding(timesteps_projected)
178
+
179
+ proj_embeddings = self.embedding_proj(proj_embedding)
180
+ encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
181
+ hidden_states = self.proj_in(hidden_states)
182
+ prd_embedding = self.prd_embedding.cast(hidden_states.dtype).expand([batch_size, -1, -1])
183
+ positional_embeddings = self.positional_embedding.cast(hidden_states.dtype)
184
+
185
+ hidden_states = paddle.concat(
186
+ [
187
+ encoder_hidden_states,
188
+ proj_embeddings[:, None, :],
189
+ time_embeddings[:, None, :],
190
+ hidden_states[:, None, :],
191
+ prd_embedding,
192
+ ],
193
+ axis=1,
194
+ )
195
+
196
+ hidden_states = hidden_states + positional_embeddings
197
+
198
+ if attention_mask is not None:
199
+ attention_mask = (1 - attention_mask.cast(hidden_states.dtype)) * -10000.0
200
+ attention_mask = F.pad(
201
+ attention_mask.unsqueeze(0), (0, self.additional_embeddings), value=0.0, data_format="NCL"
202
+ ).squeeze(0)
203
+ attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).cast(hidden_states.dtype)
204
+ attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, axis=0)
205
+
206
+ for block in self.transformer_blocks:
207
+ hidden_states = block(hidden_states, attention_mask=attention_mask)
208
+
209
+ hidden_states = self.norm_out(hidden_states)
210
+ hidden_states = hidden_states[:, -1]
211
+ predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
212
+
213
+ if not return_dict:
214
+ return (predicted_image_embedding,)
215
+
216
+ return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
217
+
218
+ def post_process_latents(self, prior_latents):
219
+ prior_latents = (prior_latents * self.clip_std) + self.clip_mean
220
+ return prior_latents
ppdiffusers/models/resnet.py ADDED
@@ -0,0 +1,716 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+
18
+ import paddle
19
+ import paddle.nn as nn
20
+ import paddle.nn.functional as F
21
+
22
+
23
+ class Upsample1D(nn.Layer):
24
+ """
25
+ An upsampling layer with an optional convolution.
26
+
27
+ Parameters:
28
+ channels: channels in the inputs and outputs.
29
+ use_conv: a bool determining if a convolution is applied.
30
+ use_conv_transpose:
31
+ out_channels:
32
+ """
33
+
34
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
35
+ super().__init__()
36
+ self.channels = channels
37
+ self.out_channels = out_channels or channels
38
+ self.use_conv = use_conv
39
+ self.use_conv_transpose = use_conv_transpose
40
+ self.name = name
41
+
42
+ self.conv = None
43
+ if use_conv_transpose:
44
+ self.conv = nn.Conv1DTranspose(channels, self.out_channels, 4, 2, 1)
45
+ elif use_conv:
46
+ self.conv = nn.Conv1D(self.channels, self.out_channels, 3, padding=1)
47
+
48
+ def forward(self, x):
49
+ assert x.shape[1] == self.channels
50
+ if self.use_conv_transpose:
51
+ return self.conv(x)
52
+
53
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
54
+
55
+ if self.use_conv:
56
+ x = self.conv(x)
57
+
58
+ return x
59
+
60
+
61
+ class Downsample1D(nn.Layer):
62
+ """
63
+ A downsampling layer with an optional convolution.
64
+
65
+ Parameters:
66
+ channels: channels in the inputs and outputs.
67
+ use_conv: a bool determining if a convolution is applied.
68
+ out_channels:
69
+ padding:
70
+ """
71
+
72
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
73
+ super().__init__()
74
+ self.channels = channels
75
+ self.out_channels = out_channels or channels
76
+ self.use_conv = use_conv
77
+ self.padding = padding
78
+ stride = 2
79
+ self.name = name
80
+
81
+ if use_conv:
82
+ self.conv = nn.Conv1D(self.channels, self.out_channels, 3, stride=stride, padding=padding)
83
+ else:
84
+ assert self.channels == self.out_channels
85
+ self.conv = nn.AvgPool1D(kernel_size=stride, stride=stride)
86
+
87
+ def forward(self, x):
88
+ assert x.shape[1] == self.channels
89
+ return self.conv(x)
90
+
91
+
92
+ class Upsample2D(nn.Layer):
93
+ """
94
+ An upsampling layer with an optional convolution.
95
+
96
+ Parameters:
97
+ channels: channels in the inputs and outputs.
98
+ use_conv: a bool determining if a convolution is applied.
99
+ use_conv_transpose:
100
+ out_channels:
101
+ """
102
+
103
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
104
+ super().__init__()
105
+ self.channels = channels
106
+ self.out_channels = out_channels or channels
107
+ self.use_conv = use_conv
108
+ self.use_conv_transpose = use_conv_transpose
109
+ self.name = name
110
+
111
+ conv = None
112
+ if use_conv_transpose:
113
+ conv = nn.Conv2DTranspose(channels, self.out_channels, 4, 2, 1)
114
+ elif use_conv:
115
+ conv = nn.Conv2D(self.channels, self.out_channels, 3, padding=1)
116
+
117
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
118
+ if name == "conv":
119
+ self.conv = conv
120
+ else:
121
+ self.Conv2d_0 = conv
122
+
123
+ def forward(self, hidden_states, output_size=None):
124
+ assert hidden_states.shape[1] == self.channels
125
+
126
+ if self.use_conv_transpose:
127
+ return self.conv(hidden_states)
128
+
129
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
130
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
131
+ # https://github.com/pytorch/pytorch/issues/86679
132
+ dtype = hidden_states.dtype
133
+ if dtype == paddle.bfloat16:
134
+ hidden_states = hidden_states.cast("float32")
135
+
136
+ # if `output_size` is passed we force the interpolation output
137
+ # size and do not make use of `scale_factor=2`
138
+ if output_size is None:
139
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
140
+ else:
141
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
142
+
143
+ # If the input is bfloat16, we cast back to bfloat16
144
+ if dtype == paddle.bfloat16:
145
+ hidden_states = hidden_states.cast(dtype)
146
+
147
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
148
+ if self.use_conv:
149
+ if self.name == "conv":
150
+ hidden_states = self.conv(hidden_states)
151
+ else:
152
+ hidden_states = self.Conv2d_0(hidden_states)
153
+
154
+ return hidden_states
155
+
156
+
157
+ class Downsample2D(nn.Layer):
158
+ """
159
+ A downsampling layer with an optional convolution.
160
+
161
+ Parameters:
162
+ channels: channels in the inputs and outputs.
163
+ use_conv: a bool determining if a convolution is applied.
164
+ out_channels:
165
+ padding:
166
+ """
167
+
168
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
169
+ super().__init__()
170
+ self.channels = channels
171
+ self.out_channels = out_channels or channels
172
+ self.use_conv = use_conv
173
+ self.padding = padding
174
+ stride = 2
175
+ self.name = name
176
+
177
+ if use_conv:
178
+ conv = nn.Conv2D(self.channels, self.out_channels, 3, stride=stride, padding=padding)
179
+ else:
180
+ assert self.channels == self.out_channels
181
+ conv = nn.AvgPool2D(kernel_size=stride, stride=stride)
182
+
183
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
184
+ if name == "conv":
185
+ self.Conv2d_0 = conv
186
+ self.conv = conv
187
+ elif name == "Conv2d_0":
188
+ self.conv = conv
189
+ else:
190
+ self.conv = conv
191
+
192
+ def forward(self, hidden_states):
193
+ assert hidden_states.shape[1] == self.channels
194
+ if self.use_conv and self.padding == 0:
195
+ pad = (0, 1, 0, 1)
196
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
197
+
198
+ assert hidden_states.shape[1] == self.channels
199
+ hidden_states = self.conv(hidden_states)
200
+
201
+ return hidden_states
202
+
203
+
204
+ class FirUpsample2D(nn.Layer):
205
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
206
+ super().__init__()
207
+ out_channels = out_channels if out_channels else channels
208
+ if use_conv:
209
+ self.Conv2d_0 = nn.Conv2D(channels, out_channels, kernel_size=3, stride=1, padding=1)
210
+ self.use_conv = use_conv
211
+ self.fir_kernel = fir_kernel
212
+ self.out_channels = out_channels
213
+
214
+ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
215
+ """Fused `upsample_2d()` followed by `Conv2d()`.
216
+
217
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
218
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
219
+ arbitrary order.
220
+
221
+ Args:
222
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
223
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
224
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
225
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
226
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
227
+ factor: Integer upsampling factor (default: 2).
228
+ gain: Scaling factor for signal magnitude (default: 1.0).
229
+
230
+ Returns:
231
+ output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
232
+ datatype as `hidden_states`.
233
+ """
234
+
235
+ assert isinstance(factor, int) and factor >= 1
236
+
237
+ # Setup filter kernel.
238
+ if kernel is None:
239
+ kernel = [1] * factor
240
+
241
+ # setup kernel
242
+ kernel = paddle.to_tensor(kernel, dtype="float32")
243
+ if kernel.ndim == 1:
244
+ kernel = paddle.outer(kernel, kernel)
245
+ kernel /= paddle.sum(kernel)
246
+
247
+ kernel = kernel * (gain * (factor**2))
248
+
249
+ if self.use_conv:
250
+ convH = weight.shape[2]
251
+ convW = weight.shape[3]
252
+ inC = weight.shape[1]
253
+
254
+ pad_value = (kernel.shape[0] - factor) - (convW - 1)
255
+
256
+ stride = (factor, factor)
257
+ # Determine data dimensions.
258
+ output_shape = (
259
+ (hidden_states.shape[2] - 1) * factor + convH,
260
+ (hidden_states.shape[3] - 1) * factor + convW,
261
+ )
262
+ output_padding = (
263
+ output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
264
+ output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
265
+ )
266
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
267
+ num_groups = hidden_states.shape[1] // inC
268
+
269
+ # Transpose weights.
270
+ weight = weight.reshape([num_groups, -1, inC, convH, convW])
271
+ weight = paddle.flip(weight, axis=[3, 4]).transpose([0, 2, 1, 3, 4])
272
+ weight = weight.reshape([num_groups * inC, -1, convH, convW])
273
+
274
+ inverse_conv = F.conv2d_transpose(
275
+ hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
276
+ )
277
+
278
+ output = upfirdn2d_native(
279
+ inverse_conv,
280
+ paddle.to_tensor(kernel),
281
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
282
+ )
283
+ else:
284
+ pad_value = kernel.shape[0] - factor
285
+ output = upfirdn2d_native(
286
+ hidden_states,
287
+ paddle.to_tensor(kernel),
288
+ up=factor,
289
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
290
+ )
291
+
292
+ return output
293
+
294
+ def forward(self, hidden_states):
295
+ if self.use_conv:
296
+ height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
297
+ height = height + self.Conv2d_0.bias.reshape([1, -1, 1, 1])
298
+ else:
299
+ height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
300
+
301
+ return height
302
+
303
+
304
+ class FirDownsample2D(nn.Layer):
305
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
306
+ super().__init__()
307
+ out_channels = out_channels if out_channels else channels
308
+ if use_conv:
309
+ self.Conv2d_0 = nn.Conv2D(channels, out_channels, kernel_size=3, stride=1, padding=1)
310
+ self.fir_kernel = fir_kernel
311
+ self.use_conv = use_conv
312
+ self.out_channels = out_channels
313
+
314
+ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
315
+ """Fused `Conv2d()` followed by `downsample_2d()`.
316
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
317
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
318
+ arbitrary order.
319
+
320
+ Args:
321
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
322
+ weight:
323
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
324
+ performed by `inChannels = x.shape[0] // numGroups`.
325
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
326
+ factor`, which corresponds to average pooling.
327
+ factor: Integer downsampling factor (default: 2).
328
+ gain: Scaling factor for signal magnitude (default: 1.0).
329
+
330
+ Returns:
331
+ output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
332
+ same datatype as `x`.
333
+ """
334
+
335
+ assert isinstance(factor, int) and factor >= 1
336
+ if kernel is None:
337
+ kernel = [1] * factor
338
+
339
+ # setup kernel
340
+ kernel = paddle.to_tensor(kernel, dtype="float32")
341
+ if kernel.ndim == 1:
342
+ kernel = paddle.outer(kernel, kernel)
343
+ kernel /= paddle.sum(kernel)
344
+
345
+ kernel = kernel * gain
346
+
347
+ if self.use_conv:
348
+ _, _, convH, convW = weight.shape
349
+ pad_value = (kernel.shape[0] - factor) + (convW - 1)
350
+ stride_value = [factor, factor]
351
+ upfirdn_input = upfirdn2d_native(
352
+ hidden_states,
353
+ paddle.to_tensor(kernel),
354
+ pad=((pad_value + 1) // 2, pad_value // 2),
355
+ )
356
+ output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
357
+ else:
358
+ pad_value = kernel.shape[0] - factor
359
+ output = upfirdn2d_native(
360
+ hidden_states,
361
+ paddle.to_tensor(kernel),
362
+ down=factor,
363
+ pad=((pad_value + 1) // 2, pad_value // 2),
364
+ )
365
+
366
+ return output
367
+
368
+ def forward(self, hidden_states):
369
+ if self.use_conv:
370
+ downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
371
+ hidden_states = downsample_input + self.Conv2d_0.bias.reshape([1, -1, 1, 1])
372
+ else:
373
+ hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
374
+
375
+ return hidden_states
376
+
377
+
378
+ class ResnetBlock2D(nn.Layer):
379
+ def __init__(
380
+ self,
381
+ *,
382
+ in_channels,
383
+ out_channels=None,
384
+ conv_shortcut=False,
385
+ dropout=0.0,
386
+ temb_channels=512,
387
+ groups=32,
388
+ groups_out=None,
389
+ pre_norm=True,
390
+ eps=1e-6,
391
+ non_linearity="swish",
392
+ time_embedding_norm="default",
393
+ kernel=None,
394
+ output_scale_factor=1.0,
395
+ use_in_shortcut=None,
396
+ up=False,
397
+ down=False,
398
+ ):
399
+ super().__init__()
400
+ self.pre_norm = pre_norm
401
+ self.pre_norm = True
402
+ self.in_channels = in_channels
403
+ out_channels = in_channels if out_channels is None else out_channels
404
+ self.out_channels = out_channels
405
+ self.use_conv_shortcut = conv_shortcut
406
+ self.time_embedding_norm = time_embedding_norm
407
+ self.up = up
408
+ self.down = down
409
+ self.output_scale_factor = output_scale_factor
410
+
411
+ if groups_out is None:
412
+ groups_out = groups
413
+
414
+ self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, epsilon=eps)
415
+
416
+ self.conv1 = nn.Conv2D(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
417
+
418
+ if temb_channels is not None:
419
+ if self.time_embedding_norm == "default":
420
+ time_emb_proj_out_channels = out_channels
421
+ elif self.time_embedding_norm == "scale_shift":
422
+ time_emb_proj_out_channels = out_channels * 2
423
+ else:
424
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
425
+
426
+ self.time_emb_proj = nn.Linear(temb_channels, time_emb_proj_out_channels)
427
+ else:
428
+ self.time_emb_proj = None
429
+
430
+ self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, epsilon=eps)
431
+ self.dropout = nn.Dropout(dropout)
432
+ self.conv2 = nn.Conv2D(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
433
+
434
+ if non_linearity == "swish":
435
+ self.nonlinearity = lambda x: F.silu(x)
436
+ elif non_linearity == "mish":
437
+ self.nonlinearity = Mish()
438
+ elif non_linearity == "silu":
439
+ self.nonlinearity = nn.Silu()
440
+
441
+ self.upsample = self.downsample = None
442
+ if self.up:
443
+ if kernel == "fir":
444
+ fir_kernel = (1, 3, 3, 1)
445
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
446
+ elif kernel == "sde_vp":
447
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
448
+ else:
449
+ self.upsample = Upsample2D(in_channels, use_conv=False)
450
+ elif self.down:
451
+ if kernel == "fir":
452
+ fir_kernel = (1, 3, 3, 1)
453
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
454
+ elif kernel == "sde_vp":
455
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
456
+ else:
457
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
458
+
459
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
460
+
461
+ self.conv_shortcut = None
462
+ if self.use_in_shortcut:
463
+ self.conv_shortcut = nn.Conv2D(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
464
+
465
+ def forward(self, input_tensor, temb):
466
+ hidden_states = input_tensor
467
+
468
+ hidden_states = self.norm1(hidden_states)
469
+ hidden_states = self.nonlinearity(hidden_states)
470
+
471
+ if self.upsample is not None:
472
+ input_tensor = self.upsample(input_tensor)
473
+ hidden_states = self.upsample(hidden_states)
474
+ elif self.downsample is not None:
475
+ input_tensor = self.downsample(input_tensor)
476
+ hidden_states = self.downsample(hidden_states)
477
+
478
+ hidden_states = self.conv1(hidden_states)
479
+
480
+ if temb is not None:
481
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
482
+
483
+ if temb is not None and self.time_embedding_norm == "default":
484
+ hidden_states = hidden_states + temb
485
+
486
+ hidden_states = self.norm2(hidden_states)
487
+
488
+ if temb is not None and self.time_embedding_norm == "scale_shift":
489
+ scale, shift = paddle.chunk(temb, 2, axis=1)
490
+ hidden_states = hidden_states * (1 + scale) + shift
491
+
492
+ hidden_states = self.nonlinearity(hidden_states)
493
+
494
+ hidden_states = self.dropout(hidden_states)
495
+ hidden_states = self.conv2(hidden_states)
496
+
497
+ if self.conv_shortcut is not None:
498
+ input_tensor = self.conv_shortcut(input_tensor)
499
+
500
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
501
+
502
+ return output_tensor
503
+
504
+
505
+ class Mish(nn.Layer):
506
+ def forward(self, hidden_states):
507
+ return hidden_states * paddle.tanh(F.softplus(hidden_states))
508
+
509
+
510
+ # unet_rl.py
511
+ def rearrange_dims(tensor):
512
+ if len(tensor.shape) == 2:
513
+ return tensor[:, :, None]
514
+ if len(tensor.shape) == 3:
515
+ return tensor[:, :, None, :]
516
+ elif len(tensor.shape) == 4:
517
+ return tensor[:, :, 0, :]
518
+ else:
519
+ raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
520
+
521
+
522
+ class Conv1dBlock(nn.Layer):
523
+ """
524
+ Conv1d --> GroupNorm --> Mish
525
+ """
526
+
527
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
528
+ super().__init__()
529
+
530
+ self.conv1d = nn.Conv1D(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
531
+ self.group_norm = nn.GroupNorm(n_groups, out_channels)
532
+ self.mish = nn.Mish()
533
+
534
+ def forward(self, x):
535
+ x = self.conv1d(x)
536
+ x = rearrange_dims(x)
537
+ x = self.group_norm(x)
538
+ x = rearrange_dims(x)
539
+ x = self.mish(x)
540
+ return x
541
+
542
+
543
+ # unet_rl.py
544
+ class ResidualTemporalBlock1D(nn.Layer):
545
+ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
546
+ super().__init__()
547
+ self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
548
+ self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
549
+
550
+ self.time_emb_act = nn.Mish()
551
+ self.time_emb = nn.Linear(embed_dim, out_channels)
552
+
553
+ self.residual_conv = (
554
+ nn.Conv1D(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
555
+ )
556
+
557
+ def forward(self, x, t):
558
+ """
559
+ Args:
560
+ x : [ batch_size x inp_channels x horizon ]
561
+ t : [ batch_size x embed_dim ]
562
+
563
+ returns:
564
+ out : [ batch_size x out_channels x horizon ]
565
+ """
566
+ t = self.time_emb_act(t)
567
+ t = self.time_emb(t)
568
+ out = self.conv_in(x) + rearrange_dims(t)
569
+ out = self.conv_out(out)
570
+ return out + self.residual_conv(x)
571
+
572
+
573
+ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
574
+ r"""Upsample2D a batch of 2D images with the given filter.
575
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
576
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
577
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
578
+ a: multiple of the upsampling factor.
579
+
580
+ Args:
581
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
582
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
583
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
584
+ factor: Integer upsampling factor (default: 2).
585
+ gain: Scaling factor for signal magnitude (default: 1.0).
586
+
587
+ Returns:
588
+ output: Tensor of the shape `[N, C, H * factor, W * factor]`
589
+ """
590
+ assert isinstance(factor, int) and factor >= 1
591
+ if kernel is None:
592
+ kernel = [1] * factor
593
+
594
+ kernel = paddle.to_tensor(kernel, dtype="float32")
595
+ if kernel.ndim == 1:
596
+ kernel = paddle.outer(kernel, kernel)
597
+ kernel /= paddle.sum(kernel)
598
+
599
+ if gain != 1:
600
+ kernel = kernel * (gain * (factor**2))
601
+ else:
602
+ kernel = kernel * (factor**2)
603
+ pad_value = kernel.shape[0] - factor
604
+ output = upfirdn2d_native(
605
+ hidden_states,
606
+ kernel,
607
+ up=factor,
608
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
609
+ )
610
+ return output
611
+
612
+
613
+ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
614
+ r"""Downsample2D a batch of 2D images with the given filter.
615
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
616
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
617
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
618
+ shape is a multiple of the downsampling factor.
619
+
620
+ Args:
621
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
622
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
623
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
624
+ factor: Integer downsampling factor (default: 2).
625
+ gain: Scaling factor for signal magnitude (default: 1.0).
626
+
627
+ Returns:
628
+ output: Tensor of the shape `[N, C, H // factor, W // factor]`
629
+ """
630
+
631
+ assert isinstance(factor, int) and factor >= 1
632
+ if kernel is None:
633
+ kernel = [1] * factor
634
+
635
+ kernel = paddle.to_tensor(kernel, dtype="float32")
636
+ if kernel.ndim == 1:
637
+ kernel = paddle.outer(kernel, kernel)
638
+ kernel /= paddle.sum(kernel)
639
+
640
+ kernel = kernel * gain
641
+ pad_value = kernel.shape[0] - factor
642
+ output = upfirdn2d_native(hidden_states, kernel, down=factor, pad=((pad_value + 1) // 2, pad_value // 2))
643
+ return output
644
+
645
+
646
+ def dummy_pad(tensor, up_x=0, up_y=0):
647
+ if up_x > 0:
648
+ tensor = paddle.concat(
649
+ [
650
+ tensor,
651
+ paddle.zeros(
652
+ [tensor.shape[0], tensor.shape[1], tensor.shape[2], tensor.shape[3], up_x, tensor.shape[5]],
653
+ dtype=tensor.dtype,
654
+ ),
655
+ ],
656
+ axis=4,
657
+ )
658
+ if up_y > 0:
659
+ tensor = paddle.concat(
660
+ [
661
+ tensor,
662
+ paddle.zeros(
663
+ [tensor.shape[0], tensor.shape[1], up_y, tensor.shape[3], tensor.shape[4], tensor.shape[5]],
664
+ dtype=tensor.dtype,
665
+ ),
666
+ ],
667
+ axis=2,
668
+ )
669
+ return tensor
670
+
671
+
672
+ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
673
+ up_x = up_y = up
674
+ down_x = down_y = down
675
+ pad_x0 = pad_y0 = pad[0]
676
+ pad_x1 = pad_y1 = pad[1]
677
+
678
+ _, channel, in_h, in_w = tensor.shape
679
+ tensor = tensor.reshape([-1, in_h, in_w, 1])
680
+
681
+ _, in_h, in_w, minor = tensor.shape
682
+ kernel_h, kernel_w = kernel.shape
683
+
684
+ out = tensor.reshape([-1, in_h, 1, in_w, 1, minor])
685
+ # (TODO, junnyu F.pad bug)
686
+ # F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
687
+ out = dummy_pad(out, up_x - 1, up_y - 1)
688
+ out = out.reshape([-1, in_h * up_y, in_w * up_x, minor])
689
+
690
+ # (TODO, junnyu F.pad bug)
691
+ # out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
692
+ out = out.unsqueeze(0)
693
+ out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0), 0, 0], data_format="NDHWC")
694
+ out = out.squeeze(0)
695
+
696
+ out = out[
697
+ :,
698
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
699
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
700
+ :,
701
+ ]
702
+
703
+ out = out.transpose([0, 3, 1, 2])
704
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
705
+ w = paddle.flip(kernel, [0, 1]).reshape([1, 1, kernel_h, kernel_w])
706
+ out = F.conv2d(out, w)
707
+ out = out.reshape(
708
+ [-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1]
709
+ )
710
+ out = out.transpose([0, 2, 3, 1])
711
+ out = out[:, ::down_y, ::down_x, :]
712
+
713
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
714
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
715
+
716
+ return out.reshape([-1, channel, out_h, out_w])