StoryVisualizationaTask
#193
by
Anyou
- opened
- README.md +0 -207
- __init__.py +0 -0
- config.yaml +63 -0
- data_script/flintstones_hdf5.py +51 -0
- data_script/pororo_hdf5.py +83 -0
- data_script/vist_hdf5.py +111 -0
- data_script/vist_img_download.py +61 -0
- datasets/flintstones.py +93 -0
- datasets/pororo.py +144 -0
- datasets/vistdii.py +94 -0
- datasets/vistsis.py +94 -0
- environment.yml +271 -0
- fid_utils.py +41 -0
- main.py +537 -0
- models/blip_override/blip.py +240 -0
- models/blip_override/med.py +955 -0
- models/blip_override/med_config.json +21 -0
- models/blip_override/vit.py +302 -0
- models/diffusers_override/attention.py +669 -0
- models/diffusers_override/unet_2d_blocks.py +1602 -0
- models/diffusers_override/unet_2d_condition.py +359 -0
- models/inception.py +314 -0
- v1-5-pruned-emaonly.ckpt → pororo_100.h5 +2 -2
- readme-storyvisualization.md +123 -0
- requirements.txt +10 -0
- run.sh +1 -0
- test.py +94 -0
- transtoyolo.py +320 -0
- v1-5-pruned-emaonly.safetensors +0 -3
- v1-5-pruned.safetensors +0 -3
README.md
DELETED
@@ -1,207 +0,0 @@
|
|
1 |
-
---
|
2 |
-
license: creativeml-openrail-m
|
3 |
-
tags:
|
4 |
-
- stable-diffusion
|
5 |
-
- stable-diffusion-diffusers
|
6 |
-
- text-to-image
|
7 |
-
inference: true
|
8 |
-
extra_gated_prompt: |-
|
9 |
-
This model is open access and available to all, with a CreativeML OpenRAIL-M license further specifying rights and usage.
|
10 |
-
The CreativeML OpenRAIL License specifies:
|
11 |
-
|
12 |
-
1. You can't use the model to deliberately produce nor share illegal or harmful outputs or content
|
13 |
-
2. CompVis claims no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in the license
|
14 |
-
3. You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M to all your users (please read the license entirely and carefully)
|
15 |
-
Please read the full license carefully here: https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
16 |
-
|
17 |
-
extra_gated_heading: Please read the LICENSE to access this model
|
18 |
-
---
|
19 |
-
|
20 |
-
# Stable Diffusion v1-5 Model Card
|
21 |
-
|
22 |
-
Stable Diffusion is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input.
|
23 |
-
For more information about how Stable Diffusion functions, please have a look at [🤗's Stable Diffusion blog](https://huggingface.co/blog/stable_diffusion).
|
24 |
-
|
25 |
-
The **Stable-Diffusion-v1-5** checkpoint was initialized with the weights of the [Stable-Diffusion-v1-2](https:/steps/huggingface.co/CompVis/stable-diffusion-v1-2)
|
26 |
-
checkpoint and subsequently fine-tuned on 595k steps at resolution 512x512 on "laion-aesthetics v2 5+" and 10% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
|
27 |
-
|
28 |
-
You can use this both with the [🧨Diffusers library](https://github.com/huggingface/diffusers) and the [RunwayML GitHub repository](https://github.com/runwayml/stable-diffusion).
|
29 |
-
|
30 |
-
### Diffusers
|
31 |
-
```py
|
32 |
-
from diffusers import StableDiffusionPipeline
|
33 |
-
import torch
|
34 |
-
|
35 |
-
model_id = "runwayml/stable-diffusion-v1-5"
|
36 |
-
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
37 |
-
pipe = pipe.to("cuda")
|
38 |
-
|
39 |
-
prompt = "a photo of an astronaut riding a horse on mars"
|
40 |
-
image = pipe(prompt).images[0]
|
41 |
-
|
42 |
-
image.save("astronaut_rides_horse.png")
|
43 |
-
```
|
44 |
-
For more detailed instructions, use-cases and examples in JAX follow the instructions [here](https://github.com/huggingface/diffusers#text-to-image-generation-with-stable-diffusion)
|
45 |
-
|
46 |
-
### Original GitHub Repository
|
47 |
-
|
48 |
-
1. Download the weights
|
49 |
-
- [v1-5-pruned-emaonly.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt) - 4.27GB, ema-only weight. uses less VRAM - suitable for inference
|
50 |
-
- [v1-5-pruned.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt) - 7.7GB, ema+non-ema weights. uses more VRAM - suitable for fine-tuning
|
51 |
-
|
52 |
-
2. Follow instructions [here](https://github.com/runwayml/stable-diffusion).
|
53 |
-
|
54 |
-
## Model Details
|
55 |
-
- **Developed by:** Robin Rombach, Patrick Esser
|
56 |
-
- **Model type:** Diffusion-based text-to-image generation model
|
57 |
-
- **Language(s):** English
|
58 |
-
- **License:** [The CreativeML OpenRAIL M license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) is an [Open RAIL M license](https://www.licenses.ai/blog/2022/8/18/naming-convention-of-responsible-ai-licenses), adapted from the work that [BigScience](https://bigscience.huggingface.co/) and [the RAIL Initiative](https://www.licenses.ai/) are jointly carrying in the area of responsible AI licensing. See also [the article about the BLOOM Open RAIL license](https://bigscience.huggingface.co/blog/the-bigscience-rail-license) on which our license is based.
|
59 |
-
- **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487).
|
60 |
-
- **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752).
|
61 |
-
- **Cite as:**
|
62 |
-
|
63 |
-
@InProceedings{Rombach_2022_CVPR,
|
64 |
-
author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
|
65 |
-
title = {High-Resolution Image Synthesis With Latent Diffusion Models},
|
66 |
-
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
67 |
-
month = {June},
|
68 |
-
year = {2022},
|
69 |
-
pages = {10684-10695}
|
70 |
-
}
|
71 |
-
|
72 |
-
# Uses
|
73 |
-
|
74 |
-
## Direct Use
|
75 |
-
The model is intended for research purposes only. Possible research areas and
|
76 |
-
tasks include
|
77 |
-
|
78 |
-
- Safe deployment of models which have the potential to generate harmful content.
|
79 |
-
- Probing and understanding the limitations and biases of generative models.
|
80 |
-
- Generation of artworks and use in design and other artistic processes.
|
81 |
-
- Applications in educational or creative tools.
|
82 |
-
- Research on generative models.
|
83 |
-
|
84 |
-
Excluded uses are described below.
|
85 |
-
|
86 |
-
### Misuse, Malicious Use, and Out-of-Scope Use
|
87 |
-
_Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_.
|
88 |
-
|
89 |
-
|
90 |
-
The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
|
91 |
-
|
92 |
-
#### Out-of-Scope Use
|
93 |
-
The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
|
94 |
-
|
95 |
-
#### Misuse and Malicious Use
|
96 |
-
Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
|
97 |
-
|
98 |
-
- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
|
99 |
-
- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
|
100 |
-
- Impersonating individuals without their consent.
|
101 |
-
- Sexual content without consent of the people who might see it.
|
102 |
-
- Mis- and disinformation
|
103 |
-
- Representations of egregious violence and gore
|
104 |
-
- Sharing of copyrighted or licensed material in violation of its terms of use.
|
105 |
-
- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
|
106 |
-
|
107 |
-
## Limitations and Bias
|
108 |
-
|
109 |
-
### Limitations
|
110 |
-
|
111 |
-
- The model does not achieve perfect photorealism
|
112 |
-
- The model cannot render legible text
|
113 |
-
- The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
|
114 |
-
- Faces and people in general may not be generated properly.
|
115 |
-
- The model was trained mainly with English captions and will not work as well in other languages.
|
116 |
-
- The autoencoding part of the model is lossy
|
117 |
-
- The model was trained on a large-scale dataset
|
118 |
-
[LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
|
119 |
-
and is not fit for product use without additional safety mechanisms and
|
120 |
-
considerations.
|
121 |
-
- No additional measures were used to deduplicate the dataset. As a result, we observe some degree of memorization for images that are duplicated in the training data.
|
122 |
-
The training data can be searched at [https://rom1504.github.io/clip-retrieval/](https://rom1504.github.io/clip-retrieval/) to possibly assist in the detection of memorized images.
|
123 |
-
|
124 |
-
### Bias
|
125 |
-
|
126 |
-
While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
|
127 |
-
Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
|
128 |
-
which consists of images that are primarily limited to English descriptions.
|
129 |
-
Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
|
130 |
-
This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
|
131 |
-
ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
|
132 |
-
|
133 |
-
### Safety Module
|
134 |
-
|
135 |
-
The intended use of this model is with the [Safety Checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) in Diffusers.
|
136 |
-
This checker works by checking model outputs against known hard-coded NSFW concepts.
|
137 |
-
The concepts are intentionally hidden to reduce the likelihood of reverse-engineering this filter.
|
138 |
-
Specifically, the checker compares the class probability of harmful concepts in the embedding space of the `CLIPTextModel` *after generation* of the images.
|
139 |
-
The concepts are passed into the model with the generated image and compared to a hand-engineered weight for each NSFW concept.
|
140 |
-
|
141 |
-
|
142 |
-
## Training
|
143 |
-
|
144 |
-
**Training Data**
|
145 |
-
The model developers used the following dataset for training the model:
|
146 |
-
|
147 |
-
- LAION-2B (en) and subsets thereof (see next section)
|
148 |
-
|
149 |
-
**Training Procedure**
|
150 |
-
Stable Diffusion v1-5 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
|
151 |
-
|
152 |
-
- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
|
153 |
-
- Text prompts are encoded through a ViT-L/14 text-encoder.
|
154 |
-
- The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
|
155 |
-
- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet.
|
156 |
-
|
157 |
-
Currently six Stable Diffusion checkpoints are provided, which were trained as follows.
|
158 |
-
- [`stable-diffusion-v1-1`](https://huggingface.co/CompVis/stable-diffusion-v1-1): 237,000 steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
|
159 |
-
194,000 steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
|
160 |
-
- [`stable-diffusion-v1-2`](https://huggingface.co/CompVis/stable-diffusion-v1-2): Resumed from `stable-diffusion-v1-1`.
|
161 |
-
515,000 steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
|
162 |
-
filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
|
163 |
-
- [`stable-diffusion-v1-3`](https://huggingface.co/CompVis/stable-diffusion-v1-3): Resumed from `stable-diffusion-v1-2` - 195,000 steps at resolution `512x512` on "laion-improved-aesthetics" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
|
164 |
-
- [`stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) Resumed from `stable-diffusion-v1-2` - 225,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
|
165 |
-
- [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) Resumed from `stable-diffusion-v1-2` - 595,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
|
166 |
-
- [`stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting) Resumed from `stable-diffusion-v1-5` - then 440,000 steps of inpainting training at resolution 512x512 on “laion-aesthetics v2 5+” and 10% dropping of the text-conditioning. For inpainting, the UNet has 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself) whose weights were zero-initialized after restoring the non-inpainting checkpoint. During training, we generate synthetic masks and in 25% mask everything.
|
167 |
-
|
168 |
-
- **Hardware:** 32 x 8 x A100 GPUs
|
169 |
-
- **Optimizer:** AdamW
|
170 |
-
- **Gradient Accumulations**: 2
|
171 |
-
- **Batch:** 32 x 8 x 2 x 4 = 2048
|
172 |
-
- **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
|
173 |
-
|
174 |
-
## Evaluation Results
|
175 |
-
Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
|
176 |
-
5.0, 6.0, 7.0, 8.0) and 50 PNDM/PLMS sampling
|
177 |
-
steps show the relative improvements of the checkpoints:
|
178 |
-
|
179 |
-
![pareto](https://huggingface.co/CompVis/stable-diffusion/resolve/main/v1-1-to-v1-5.png)
|
180 |
-
|
181 |
-
Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores.
|
182 |
-
## Environmental Impact
|
183 |
-
|
184 |
-
**Stable Diffusion v1** **Estimated Emissions**
|
185 |
-
Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.
|
186 |
-
|
187 |
-
- **Hardware Type:** A100 PCIe 40GB
|
188 |
-
- **Hours used:** 150000
|
189 |
-
- **Cloud Provider:** AWS
|
190 |
-
- **Compute Region:** US-east
|
191 |
-
- **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq.
|
192 |
-
|
193 |
-
|
194 |
-
## Citation
|
195 |
-
|
196 |
-
```bibtex
|
197 |
-
@InProceedings{Rombach_2022_CVPR,
|
198 |
-
author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
|
199 |
-
title = {High-Resolution Image Synthesis With Latent Diffusion Models},
|
200 |
-
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
201 |
-
month = {June},
|
202 |
-
year = {2022},
|
203 |
-
pages = {10684-10695}
|
204 |
-
}
|
205 |
-
```
|
206 |
-
|
207 |
-
*This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__init__.py
ADDED
Binary file (2 Bytes). View file
|
|
config.yaml
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# device
|
2 |
+
mode: sample # train sample
|
3 |
+
gpu_ids: [3] # gpu ids
|
4 |
+
batch_size: 1 # batch size each item denotes one story
|
5 |
+
num_workers: 4 # number of workers
|
6 |
+
num_cpu_cores: -1 # number of cpu cores
|
7 |
+
seed: 0 # random seed
|
8 |
+
ckpt_dir: /root/lihui/StoryVisualization/save_ckpt_epoch5_new # checkpoint directory
|
9 |
+
run_name: ARLDM # name for this run
|
10 |
+
|
11 |
+
# task
|
12 |
+
dataset: pororo # pororo flintstones vistsis vistdii
|
13 |
+
task: visualization # continuation visualization
|
14 |
+
|
15 |
+
# train
|
16 |
+
init_lr: 1e-5 # initial learning rate
|
17 |
+
warmup_epochs: 1 # warmup epochs
|
18 |
+
max_epochs: 5 #50 # max epochs
|
19 |
+
train_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for resume, none for train from scratch
|
20 |
+
freeze_clip: True #False # whether to freeze clip
|
21 |
+
freeze_blip: True #False # whether to freeze blip
|
22 |
+
freeze_resnet: True #False # whether to freeze resnet
|
23 |
+
|
24 |
+
# sample
|
25 |
+
test_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for test
|
26 |
+
calculate_fid: True # whether to calculate FID scores
|
27 |
+
scheduler: ddim # ddim pndm
|
28 |
+
guidance_scale: 6 # guidance scale
|
29 |
+
num_inference_steps: 250 # number of inference steps
|
30 |
+
sample_output_dir: /root/lihui/StoryVisualization/save_samples_128_epoch50 # output directory
|
31 |
+
|
32 |
+
pororo:
|
33 |
+
hdf5_file: /root/lihui/StoryVisualization/pororo.h5
|
34 |
+
max_length: 85
|
35 |
+
new_tokens: [ "pororo", "loopy", "eddy", "harry", "poby", "tongtong", "crong", "rody", "petty" ]
|
36 |
+
clip_embedding_tokens: 49416
|
37 |
+
blip_embedding_tokens: 30530
|
38 |
+
|
39 |
+
flintstones:
|
40 |
+
hdf5_file: /path/to/flintstones.h5
|
41 |
+
max_length: 91
|
42 |
+
new_tokens: [ "fred", "barney", "wilma", "betty", "pebbles", "dino", "slate" ]
|
43 |
+
clip_embedding_tokens: 49412
|
44 |
+
blip_embedding_tokens: 30525
|
45 |
+
|
46 |
+
vistsis:
|
47 |
+
hdf5_file: /path/to/vist.h5
|
48 |
+
max_length: 100
|
49 |
+
clip_embedding_tokens: 49408
|
50 |
+
blip_embedding_tokens: 30524
|
51 |
+
|
52 |
+
vistdii:
|
53 |
+
hdf5_file: /path/to/vist.h5
|
54 |
+
max_length: 65
|
55 |
+
clip_embedding_tokens: 49408
|
56 |
+
blip_embedding_tokens: 30524
|
57 |
+
|
58 |
+
hydra:
|
59 |
+
run:
|
60 |
+
dir: .
|
61 |
+
output_subdir: null
|
62 |
+
hydra/job_logging: disabled
|
63 |
+
hydra/hydra_logging: disabled
|
data_script/flintstones_hdf5.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import h5py
|
8 |
+
import numpy as np
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
|
12 |
+
def main(args):
|
13 |
+
splits = json.load(open(os.path.join(args.data_dir, 'train-val-test_split.json'), 'r'))
|
14 |
+
train_ids, val_ids, test_ids = splits["train"], splits["val"], splits["test"]
|
15 |
+
followings = pickle.load(open(os.path.join(args.data_dir, 'following_cache4.pkl'), 'rb'))
|
16 |
+
annotations = json.load(open(os.path.join(args.data_dir, 'flintstones_annotations_v1-0.json')))
|
17 |
+
descriptions = dict()
|
18 |
+
for sample in annotations:
|
19 |
+
descriptions[sample["globalID"]] = sample["description"]
|
20 |
+
|
21 |
+
f = h5py.File(args.save_path, "w")
|
22 |
+
for subset, ids in {'train': train_ids, 'val': val_ids, 'test': test_ids}.items():
|
23 |
+
ids = [i for i in ids if i in followings and len(followings[i]) == 4]
|
24 |
+
length = len(ids)
|
25 |
+
|
26 |
+
group = f.create_group(subset)
|
27 |
+
images = list()
|
28 |
+
for i in range(5):
|
29 |
+
images.append(
|
30 |
+
group.create_dataset('image{}'.format(i), (length,), dtype=h5py.vlen_dtype(np.dtype('uint8'))))
|
31 |
+
text = group.create_dataset('text', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
|
32 |
+
for i, item in enumerate(tqdm(ids, leave=True, desc="saveh5")):
|
33 |
+
globalIDs = [item] + followings[item]
|
34 |
+
txt = list()
|
35 |
+
for j, globalID in enumerate(globalIDs):
|
36 |
+
img = np.load(os.path.join(args.data_dir, 'video_frames_sampled', '{}.npy'.format(globalID)))
|
37 |
+
img = np.concatenate(img, axis=0).astype(np.uint8)
|
38 |
+
img = cv2.imencode('.png', img)[1].tobytes()
|
39 |
+
img = np.frombuffer(img, np.uint8)
|
40 |
+
images[j][i] = img
|
41 |
+
txt.append(descriptions[globalID])
|
42 |
+
text[i] = '|'.join([t.replace('\n', '').replace('\t', '').strip() for t in txt])
|
43 |
+
f.close()
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == '__main__':
|
47 |
+
parser = argparse.ArgumentParser(description='arguments for flintstones hdf5 file saving')
|
48 |
+
parser.add_argument('--data_dir', type=str, required=True, help='flintstones data directory')
|
49 |
+
parser.add_argument('--save_path', type=str, required=True, help='path to save hdf5')
|
50 |
+
args = parser.parse_args()
|
51 |
+
main(args)
|
data_script/pororo_hdf5.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import h5py
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
|
11 |
+
def main(args):
|
12 |
+
# 使用numpy库的load函数来加载名为descriptions.npy的文件。该文件是一个Python字典对象,因此我们使用item()方法将其转换为字典对象。
|
13 |
+
# ——os.path.join函数用于连接文件路径
|
14 |
+
# ——args.data_dir作为基础目录,将'descriptions.npy'添加到该目录中
|
15 |
+
# ——指定allow_pickle=True,表示允许加载包含Python对象的文件
|
16 |
+
# ——指定encoding='latin1',表示使用拉丁字符编码加载该文件
|
17 |
+
descriptions = np.load(os.path.join(args.data_dir, 'descriptions.npy'), allow_pickle=True, encoding='latin1').item()
|
18 |
+
# imgs_list包含一组图像文件的路径,
|
19 |
+
# followings_list包含每个图像的一些附加信息
|
20 |
+
imgs_list = np.load(os.path.join(args.data_dir, 'img_cache4.npy'), encoding='latin1')
|
21 |
+
followings_list = np.load(os.path.join(args.data_dir, 'following_cache4.npy'))
|
22 |
+
# 使用numpy库的load函数来加载名为train_seen_unseen_ids.npy的文件
|
23 |
+
# 该文件包含三个numpy数组:train_ids、val_ids和test_ids,分别代表训练集、验证集和测试集的ID列表。
|
24 |
+
# 使用元组来一次性加载这三个数组,并将它们赋值给相应的变量。
|
25 |
+
train_ids, val_ids, test_ids = np.load(os.path.join(args.data_dir, 'train_seen_unseen_ids.npy'), allow_pickle=True)
|
26 |
+
# 按照ID的顺序逐一排序
|
27 |
+
train_ids = np.sort(train_ids)
|
28 |
+
val_ids = np.sort(val_ids)
|
29 |
+
test_ids = np.sort(test_ids)
|
30 |
+
|
31 |
+
# 创建一个新的HDF5文件,并指定文件名为args.save_path。
|
32 |
+
# 使用h5py库的File函数来创建文件对象,指定打开方式为写模式("w")。
|
33 |
+
# 在这个文件中存储处理后的图像和文本数据。
|
34 |
+
f = h5py.File(args.save_path, "w")
|
35 |
+
for subset, ids in {'train': train_ids, 'val': val_ids, 'test': test_ids}.items():
|
36 |
+
length = len(ids)
|
37 |
+
|
38 |
+
# 为每个数据集(train、val和test)创建一个组
|
39 |
+
# 针对每个数据集都创建了5个数据集,名为'image0'、'image1'、'image2'、'image3'、'image4',分别对应于当前图像及其相关联的4个图像。
|
40 |
+
# 目的:将每个图像及其相关联的图像数据保存到同一个HDF5文件中,并按照一定的组织方式存储,方便后续的数据读取和处理。
|
41 |
+
group = f.create_group(subset)
|
42 |
+
# 创建一个长度为ids列表长度的空列表images,按照image0-4顺序添加了5个HDF5数据集对象
|
43 |
+
images = list()
|
44 |
+
# 为当前数据集中的每个图像创建了五个数据集。
|
45 |
+
# 每个数据集都使用vlen_dtype(np.dtype('uint8'))作为数据类型,并将其添加到当前组group中。
|
46 |
+
# ——vlen_dtype(np.dtype('uint8'))表示可变长度的无符号8位整数数组。
|
47 |
+
for i in range(5):
|
48 |
+
images.append(
|
49 |
+
group.create_dataset('image{}'.format(i), (length,), dtype=h5py.vlen_dtype(np.dtype('uint8'))))
|
50 |
+
# 创建一个数据集text,用于存储与当前数据集中图像相关的文本描述。该数据集的数据类型为字符串,编码方式为utf-8,并将其添加到当前组group中。
|
51 |
+
text = group.create_dataset('text', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
|
52 |
+
# 遍历当前数据集中的每个图像,并将相关数据保存到HDF5文件中
|
53 |
+
for i, item in enumerate(tqdm(ids, leave=True, desc="saveh5")):
|
54 |
+
# 获取与当前图像相关的所有图像的路径,存储到列表img_paths中。
|
55 |
+
# ——imgs_list是一个字典,存储了所有图像的路径
|
56 |
+
# ——followings_list是一个字典,存储了与每个图像相关的四张图像的路径
|
57 |
+
img_paths = [str(imgs_list[item])[2:-1]] + [str(followings_list[item][i])[2:-1] for i in range(4)]
|
58 |
+
# 打开img_paths列表中的每个图像,并将其转换为RGB格式的PIL图像对象。
|
59 |
+
imgs = [Image.open(os.path.join(args.data_dir, img_path)).convert('RGB') for img_path in img_paths]
|
60 |
+
# 将每个PIL图像对象转换为numpy数组
|
61 |
+
for j, img in enumerate(imgs):
|
62 |
+
img = np.array(img).astype(np.uint8)
|
63 |
+
# 使用OpenCV将其编码为png格式的二进制数据
|
64 |
+
img = cv2.imencode('.png', img)[1].tobytes()
|
65 |
+
# 将该二进制数据转换为numpy数组
|
66 |
+
img = np.frombuffer(img, np.uint8)
|
67 |
+
# 将其存储到images列表中与当前图像相关的数据集中
|
68 |
+
images[j][i] = img
|
69 |
+
# 获取与当前图像相关的所有图像的文件名,并将其存储到列表tgt_img_ids中
|
70 |
+
tgt_img_ids = [str(img_path).replace('.png', '') for img_path in img_paths]
|
71 |
+
# 根据目标图像的文件名,获取其对应的文本描述,并将其存储到列表txt中。
|
72 |
+
txt = [descriptions[tgt_img_id][0] for tgt_img_id in tgt_img_ids]
|
73 |
+
# 将txt列表中的所有文本描述合并为一个字符串,并将其中的"\n"、"\t"等无关字符替换为空格。然后,将该字符串存储到数据集text中
|
74 |
+
text[i] = '|'.join([t.replace('\n', '').replace('\t', '').strip() for t in txt])
|
75 |
+
f.close()
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == '__main__':
|
79 |
+
parser = argparse.ArgumentParser(description='arguments for flintstones pororo file saving')
|
80 |
+
parser.add_argument('--data_dir', type=str, required=True, help='pororo data directory')
|
81 |
+
parser.add_argument('--save_path', type=str, required=True, help='path to save hdf5')
|
82 |
+
args = parser.parse_args()
|
83 |
+
main(args)
|
data_script/vist_hdf5.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import h5py
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
|
12 |
+
def main(args):
|
13 |
+
train_data = json.load(open(os.path.join(args.sis_json_dir, 'train.story-in-sequence.json')))
|
14 |
+
val_data = json.load(open(os.path.join(args.sis_json_dir, 'val.story-in-sequence.json')))
|
15 |
+
test_data = json.load(open(os.path.join(args.sis_json_dir, 'test.story-in-sequence.json')))
|
16 |
+
|
17 |
+
prefix = ["train", "val", "test"]
|
18 |
+
whole_album = {}
|
19 |
+
for i, data in enumerate([train_data, val_data, test_data]):
|
20 |
+
album_mapping = {}
|
21 |
+
for annot_new in data["annotations"]:
|
22 |
+
annot = annot_new[0]
|
23 |
+
assert len(annot_new) == 1
|
24 |
+
if annot['story_id'] not in album_mapping:
|
25 |
+
album_mapping[annot['story_id']] = {"flickr_id": [annot['photo_flickr_id']],
|
26 |
+
"sis": [annot['original_text']],
|
27 |
+
"length": 1}
|
28 |
+
else:
|
29 |
+
album_mapping[annot['story_id']]["flickr_id"].append(annot['photo_flickr_id'])
|
30 |
+
album_mapping[annot['story_id']]["sis"].append(
|
31 |
+
annot['original_text'])
|
32 |
+
album_mapping[annot['story_id']]["length"] += 1
|
33 |
+
whole_album[prefix[i]] = album_mapping
|
34 |
+
|
35 |
+
for p in prefix:
|
36 |
+
deletables = []
|
37 |
+
for story_id, story in whole_album[p].items():
|
38 |
+
if story['length'] != 5:
|
39 |
+
print("deleting {}".format(story_id))
|
40 |
+
deletables.append(story_id)
|
41 |
+
continue
|
42 |
+
d = [os.path.exists(os.path.join(args.img_dir, "{}.jpg".format(_))) for _ in story["flickr_id"]]
|
43 |
+
if sum(d) < 5:
|
44 |
+
print("deleting {}".format(story_id))
|
45 |
+
deletables.append(story_id)
|
46 |
+
else:
|
47 |
+
pass
|
48 |
+
for i in deletables:
|
49 |
+
del whole_album[p][i]
|
50 |
+
|
51 |
+
train_data = json.load(open(os.path.join(args.sis_json_dir, 'train.description-in-isolation.json')))
|
52 |
+
val_data = json.load(open(os.path.join(args.sis_json_dir, 'val.description-in-isolation.json')))
|
53 |
+
test_data = json.load(open(os.path.join(args.sis_json_dir, 'test.description-in-isolation.json')))
|
54 |
+
|
55 |
+
flickr_id2text = {}
|
56 |
+
for i, data in enumerate([train_data, val_data, test_data]):
|
57 |
+
for l in data['annotations']:
|
58 |
+
assert len(l) == 1
|
59 |
+
if l[0]['photo_flickr_id'] in flickr_id2text:
|
60 |
+
flickr_id2text[l[0]['photo_flickr_id']] = \
|
61 |
+
max([flickr_id2text[l[0]['photo_flickr_id']], l[0]['original_text']], key=len)
|
62 |
+
else:
|
63 |
+
flickr_id2text[l[0]['photo_flickr_id']] = l[0]['original_text']
|
64 |
+
|
65 |
+
for p in prefix:
|
66 |
+
deletables = []
|
67 |
+
for story_id, story in whole_album[p].items():
|
68 |
+
story['dii'] = []
|
69 |
+
for i, flickr_id in enumerate(story['flickr_id']):
|
70 |
+
if flickr_id not in flickr_id2text:
|
71 |
+
print("{} not found in story {}".format(flickr_id, story_id))
|
72 |
+
deletables.append(story_id)
|
73 |
+
break
|
74 |
+
story['dii'].append(flickr_id2text[flickr_id])
|
75 |
+
for i in deletables:
|
76 |
+
del whole_album[p][i]
|
77 |
+
|
78 |
+
f = h5py.File(args.save_path, "w")
|
79 |
+
for p in prefix:
|
80 |
+
group = f.create_group(p)
|
81 |
+
story_dict = whole_album[p]
|
82 |
+
length = len(story_dict)
|
83 |
+
images = list()
|
84 |
+
for i in range(5):
|
85 |
+
images.append(
|
86 |
+
group.create_dataset('image{}'.format(i), (length,), dtype=h5py.vlen_dtype(np.dtype('uint8'))))
|
87 |
+
sis = group.create_dataset('sis', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
|
88 |
+
dii = group.create_dataset('dii', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
|
89 |
+
for i, (story_id, story) in enumerate(tqdm(story_dict.items(), leave=True, desc="saveh5")):
|
90 |
+
imgs = [Image.open('{}/{}.jpg'.format(args.img_dir, flickr_id)).convert('RGB') for flickr_id in
|
91 |
+
story['flickr_id']]
|
92 |
+
for j, img in enumerate(imgs):
|
93 |
+
img = np.array(img).astype(np.uint8)
|
94 |
+
img = cv2.imencode('.png', img)[1].tobytes()
|
95 |
+
img = np.frombuffer(img, np.uint8)
|
96 |
+
images[j][i] = img
|
97 |
+
sis[i] = '|'.join([t.replace('\n', '').replace('\t', '').strip() for t in story['sis']])
|
98 |
+
txt_dii = [t.replace('\n', '').replace('\t', '').strip() for t in story['dii']]
|
99 |
+
txt_dii = sorted(set(txt_dii), key=txt_dii.index)
|
100 |
+
dii[i] = '|'.join(txt_dii)
|
101 |
+
f.close()
|
102 |
+
|
103 |
+
|
104 |
+
if __name__ == '__main__':
|
105 |
+
parser = argparse.ArgumentParser(description='arguments for vist hdf5 file saving')
|
106 |
+
parser.add_argument('--sis_json_dir', type=str, required=True, help='sis json file directory')
|
107 |
+
parser.add_argument('--dii_json_dir', type=str, required=True, help='dii json file directory')
|
108 |
+
parser.add_argument('--img_dir', type=str, required=True, help='json file directory')
|
109 |
+
parser.add_argument('--save_path', type=str, required=True, help='path to save hdf5')
|
110 |
+
args = parser.parse_args()
|
111 |
+
main(args)
|
data_script/vist_img_download.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import requests
|
3 |
+
from io import BytesIO
|
4 |
+
from PIL import Image
|
5 |
+
from tqdm import tqdm
|
6 |
+
from multiprocessing import Process
|
7 |
+
import os
|
8 |
+
import argparse
|
9 |
+
|
10 |
+
|
11 |
+
def download_subprocess(dii, save_dir):
|
12 |
+
for image in tqdm(dii):
|
13 |
+
key, value = image.popitem()
|
14 |
+
try:
|
15 |
+
img_data = requests.get(value).content
|
16 |
+
img = Image.open(BytesIO(img_data)).convert('RGB')
|
17 |
+
h = img.size[0]
|
18 |
+
w = img.size[1]
|
19 |
+
if min(h, w) > 512:
|
20 |
+
img = img.resize((int(h / (w / 512)), 512) if h > w else (512, int(w / (h / 512))))
|
21 |
+
img.save('{}/{}.jpg'.format(save_dir, key))
|
22 |
+
except:
|
23 |
+
print(key, value)
|
24 |
+
|
25 |
+
|
26 |
+
def main(args):
|
27 |
+
train_data = json.load(open(os.path.join(args.json_dir, 'train.description-in-isolation.json')))
|
28 |
+
val_data = json.load(open(os.path.join(args.json_dir, 'val.description-in-isolation.json')))
|
29 |
+
test_data = json.load(open(os.path.join(args.json_dir, 'test.description-in-isolation.json')))
|
30 |
+
dii = []
|
31 |
+
for subset in [train_data, val_data, test_data]:
|
32 |
+
for image in subset["images"]:
|
33 |
+
try:
|
34 |
+
dii.append({image['id']: image['url_o']})
|
35 |
+
except:
|
36 |
+
dii.append({image['id']: image['url_m']})
|
37 |
+
|
38 |
+
dii = [image for image in dii if not os.path.exists('{}/{}.jpg'.format(args.save_dir, list(image)[0]))]
|
39 |
+
print('total images: {}'.format(len(dii)))
|
40 |
+
|
41 |
+
def splitlist(inlist, chunksize):
|
42 |
+
return [inlist[x:x + chunksize] for x in range(0, len(inlist), chunksize)]
|
43 |
+
|
44 |
+
dii_splitted = splitlist(dii, int((len(dii) / args.num_process)))
|
45 |
+
process_list = []
|
46 |
+
for dii_sub_list in dii_splitted:
|
47 |
+
p = Process(target=download_subprocess, args=(dii_sub_list,))
|
48 |
+
process_list.append(p)
|
49 |
+
p.Daemon = True
|
50 |
+
p.start()
|
51 |
+
for p in process_list:
|
52 |
+
p.join()
|
53 |
+
|
54 |
+
|
55 |
+
if __name__ == "__main__":
|
56 |
+
parser = argparse.ArgumentParser(description='arguments for vist images downloading')
|
57 |
+
parser.add_argument('--json_dir', type=str, required=True, help='dii json file directory')
|
58 |
+
parser.add_argument('--img_dir', type=str, required=True, help='images saving directory')
|
59 |
+
parser.add_argument('--num_process', type=int, default=32)
|
60 |
+
args = parser.parse_args()
|
61 |
+
main(args)
|
datasets/flintstones.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import h5py
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from torchvision import transforms
|
9 |
+
from transformers import CLIPTokenizer
|
10 |
+
|
11 |
+
from models.blip_override.blip import init_tokenizer
|
12 |
+
|
13 |
+
|
14 |
+
class StoryDataset(Dataset):
|
15 |
+
"""
|
16 |
+
A custom subset class for the LRW (includes train, val, test) subset
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, subset, args):
|
20 |
+
super(StoryDataset, self).__init__()
|
21 |
+
self.args = args
|
22 |
+
|
23 |
+
self.h5_file = args.get(args.dataset).hdf5_file
|
24 |
+
self.subset = subset
|
25 |
+
|
26 |
+
self.augment = transforms.Compose([
|
27 |
+
transforms.ToPILImage(),
|
28 |
+
transforms.Resize([512, 512]),
|
29 |
+
transforms.ToTensor(),
|
30 |
+
transforms.Normalize([0.5], [0.5])
|
31 |
+
])
|
32 |
+
self.dataset = args.dataset
|
33 |
+
self.max_length = args.get(args.dataset).max_length
|
34 |
+
self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
|
35 |
+
self.blip_tokenizer = init_tokenizer()
|
36 |
+
msg = self.clip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
|
37 |
+
print("clip {} new tokens added".format(msg))
|
38 |
+
msg = self.blip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
|
39 |
+
print("blip {} new tokens added".format(msg))
|
40 |
+
|
41 |
+
self.blip_image_processor = transforms.Compose([
|
42 |
+
transforms.ToPILImage(),
|
43 |
+
transforms.Resize([224, 224]),
|
44 |
+
transforms.ToTensor(),
|
45 |
+
transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
|
46 |
+
])
|
47 |
+
|
48 |
+
def open_h5(self):
|
49 |
+
h5 = h5py.File(self.h5_file, "r")
|
50 |
+
self.h5 = h5[self.subset]
|
51 |
+
|
52 |
+
def __getitem__(self, index):
|
53 |
+
if not hasattr(self, 'h5'):
|
54 |
+
self.open_h5()
|
55 |
+
|
56 |
+
images = list()
|
57 |
+
for i in range(5):
|
58 |
+
im = self.h5['image{}'.format(i)][index]
|
59 |
+
im = cv2.imdecode(im, cv2.IMREAD_COLOR)
|
60 |
+
idx = random.randint(0, 4)
|
61 |
+
images.append(im[idx * 128: (idx + 1) * 128])
|
62 |
+
|
63 |
+
source_images = torch.stack([self.blip_image_processor(im) for im in images])
|
64 |
+
images = images[1:] if self.args.task == 'continuation' else images
|
65 |
+
images = torch.stack([self.augment(im) for im in images]) \
|
66 |
+
if self.subset in ['train', 'val'] else torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)
|
67 |
+
|
68 |
+
texts = self.h5['text'][index].decode('utf-8').split('|')
|
69 |
+
|
70 |
+
# tokenize caption using default tokenizer
|
71 |
+
tokenized = self.clip_tokenizer(
|
72 |
+
texts[1:] if self.args.task == 'continuation' else texts,
|
73 |
+
padding="max_length",
|
74 |
+
max_length=self.max_length,
|
75 |
+
truncation=False,
|
76 |
+
return_tensors="pt",
|
77 |
+
)
|
78 |
+
captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
|
79 |
+
|
80 |
+
tokenized = self.blip_tokenizer(
|
81 |
+
texts,
|
82 |
+
padding="max_length",
|
83 |
+
max_length=self.max_length,
|
84 |
+
truncation=False,
|
85 |
+
return_tensors="pt",
|
86 |
+
)
|
87 |
+
source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
|
88 |
+
return images, captions, attention_mask, source_images, source_caption, source_attention_mask
|
89 |
+
|
90 |
+
def __len__(self):
|
91 |
+
if not hasattr(self, 'h5'):
|
92 |
+
self.open_h5()
|
93 |
+
return len(self.h5['text'])
|
datasets/pororo.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from PIL import Image
|
5 |
+
import cv2
|
6 |
+
import h5py
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torch.utils.data import Dataset
|
10 |
+
from torchvision import transforms
|
11 |
+
from transformers import CLIPTokenizer
|
12 |
+
|
13 |
+
from models.blip_override.blip import init_tokenizer
|
14 |
+
|
15 |
+
|
16 |
+
class StoryDataset(Dataset):
|
17 |
+
"""
|
18 |
+
A custom subset class for the LRW (includes train, val, test) subset
|
19 |
+
"""
|
20 |
+
# StoryDataset 类的构造函数
|
21 |
+
def __init__(self, subset, args):
|
22 |
+
# 用来调用父类 Dataset 的初始化函数,确保该类能够继承 Dataset 类的所有方法和属性。
|
23 |
+
super(StoryDataset, self).__init__()
|
24 |
+
# args 则是该类的其他参数,是一个命名空间(namespace)对象
|
25 |
+
self.args = args
|
26 |
+
# 一个 HDF5 文件的路径,存储了训练、验证和测试集的图像和文本数据。
|
27 |
+
# ——args.get(args.dataset)表示从命名空间对象args中获取指定数据集(训练集、验证集或测试集)的参数。
|
28 |
+
self.h5_file = args.get(args.dataset).hdf5_file
|
29 |
+
# 初始化函数中 subset 表示要读取的子集的类型(如训练集、验证集、测试集)
|
30 |
+
self.subset = subset
|
31 |
+
|
32 |
+
# 一个图像变换函数序列(transform),用来对图像进行预处理,包括将图像转化为 PIL 格式,调整图像大小,将图像转换为 Tensor,并进行归一化。
|
33 |
+
self.augment = transforms.Compose([
|
34 |
+
transforms.ToPILImage(),
|
35 |
+
# transforms.Resize([256, 256]),
|
36 |
+
transforms.Resize([512, 512]),
|
37 |
+
transforms.ToTensor(),
|
38 |
+
transforms.Normalize([0.5], [0.5])
|
39 |
+
])
|
40 |
+
# 表示当前数据集的类型(训练集、验证集或测试集)
|
41 |
+
self.dataset = args.dataset
|
42 |
+
# 最大的 caption 长度,在进行tokenize操作时,caption中的单词数量将被填充到该长度。
|
43 |
+
self.max_length = args.get(args.dataset).max_length
|
44 |
+
# 一个使用CLIP模型进行tokenize的tokenizer
|
45 |
+
self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
|
46 |
+
# 一个自定义的tokenizer,用于处理文本输入
|
47 |
+
self.blip_tokenizer = init_tokenizer()
|
48 |
+
msg = self.clip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
|
49 |
+
print("clip {} new tokens added".format(msg))
|
50 |
+
msg = self.blip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
|
51 |
+
print("blip {} new tokens added".format(msg))
|
52 |
+
|
53 |
+
# 一个用于对输入的图像进行处理的函数序列,包括转换为PIL图像、重置图像大小、转换为tensor、归一化等。
|
54 |
+
self.blip_image_processor = transforms.Compose([
|
55 |
+
transforms.ToPILImage(),
|
56 |
+
transforms.Resize([224, 224]),
|
57 |
+
transforms.ToTensor(),
|
58 |
+
transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
|
59 |
+
])
|
60 |
+
|
61 |
+
# 打开与数据集对应的h5文件
|
62 |
+
def open_h5(self):
|
63 |
+
h5 = h5py.File(self.h5_file, "r")
|
64 |
+
self.h5 = h5[self.subset]
|
65 |
+
|
66 |
+
# 用于按索引获取数据。
|
67 |
+
|
68 |
+
# 对于每个图像,都进行数据增强操作,以进行数据增强。
|
69 |
+
# 然后,将文本输入的caption进行tokenize操作,
|
70 |
+
# 使用CLIP tokenizer和自定义tokenizer分别进行tokenize。
|
71 |
+
# 最后,将处理好的图像、caption和attention mask返回
|
72 |
+
def __getitem__(self, index):
|
73 |
+
# 首先调用open_h5()打开数据集的h5文件
|
74 |
+
if not hasattr(self, 'h5'):
|
75 |
+
self.open_h5()
|
76 |
+
#index = 1
|
77 |
+
images = list()
|
78 |
+
for i in range(5):
|
79 |
+
# 从h5文件中读取一组图像和对应的文本。
|
80 |
+
im = self.h5['image{}'.format(i)][index]
|
81 |
+
# print(im)
|
82 |
+
# pil_img = Image.fromarray(im)
|
83 |
+
# # 保存图像
|
84 |
+
# pil_img.save(os.path.join('/root/lihui/StoryVisualization/ori_test_images', '{:04d}.png'.format(i)))
|
85 |
+
# 对每个图像解码
|
86 |
+
im = cv2.imdecode(im, cv2.IMREAD_COLOR)
|
87 |
+
# 随机选择一个128像素的图像切片
|
88 |
+
idx = random.randint(0, im.shape[0] / 128 - 1)
|
89 |
+
# 将切片后的图像加到images列表中
|
90 |
+
images.append(im[idx * 128: (idx + 1) * 128])
|
91 |
+
# 深拷贝,后续不随images变化
|
92 |
+
ori_images = copy.deepcopy(images)
|
93 |
+
# 保存test原始图像
|
94 |
+
|
95 |
+
# for i, im in enumerate(images):
|
96 |
+
# file_path = '/root/lihui/StoryVisualization/ori_test_images/group{:02d}_image{:02d}.png'.format(index + 1,
|
97 |
+
# i + 1)
|
98 |
+
# cv2.imwrite(file_path, im)
|
99 |
+
# 将图像转换为张量
|
100 |
+
source_images = torch.stack([self.blip_image_processor(im) for im in images])
|
101 |
+
# 如果为continuation任务,将列表中的第一个图像从images中移除
|
102 |
+
images = images[1:] if self.args.task == 'continuation' else images
|
103 |
+
# 如果subset的值为train/val,则使用augment方法对images列表中的所有图像进行数据增强,并将其转换为张量
|
104 |
+
# 否则使用numpy.array方法将images列表转换为张量,并进行转置操作
|
105 |
+
images = torch.stack([self.augment(im) for im in images]) \
|
106 |
+
if self.subset in ['train', 'val'] else torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)
|
107 |
+
######################
|
108 |
+
# 读取当前索引处的文本,并使用decode方法将其解码为UTF-8
|
109 |
+
texts = self.h5['text'][index].decode('utf-8').split('|')
|
110 |
+
# print(f"index: {index}")
|
111 |
+
# for text in texts:
|
112 |
+
# print(f"texts: {text}")
|
113 |
+
|
114 |
+
# tokenize caption using default tokenizer
|
115 |
+
tokenized = self.clip_tokenizer(
|
116 |
+
texts[1:] if self.args.task == 'continuation' else texts,
|
117 |
+
padding="max_length",
|
118 |
+
max_length=self.max_length,
|
119 |
+
truncation=False,
|
120 |
+
return_tensors="pt",
|
121 |
+
)
|
122 |
+
captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
|
123 |
+
|
124 |
+
tokenized = self.blip_tokenizer(
|
125 |
+
texts,
|
126 |
+
padding="max_length",
|
127 |
+
max_length=self.max_length,
|
128 |
+
truncation=False,
|
129 |
+
return_tensors="pt",
|
130 |
+
)
|
131 |
+
source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
|
132 |
+
return images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_images
|
133 |
+
|
134 |
+
# 返回数据集中样本的数量
|
135 |
+
# 如果是测试集,则返回100,否则返回对应的数据集中的样本数量
|
136 |
+
def __len__(self):
|
137 |
+
if not hasattr(self, 'h5'):
|
138 |
+
self.open_h5()
|
139 |
+
if self.subset == 'test':
|
140 |
+
#print('')
|
141 |
+
return 1
|
142 |
+
# if self.subset == 'test':
|
143 |
+
# return 100
|
144 |
+
return len(self.h5['text'])
|
datasets/vistdii.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import h5py
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from torchvision import transforms
|
7 |
+
from transformers import CLIPTokenizer
|
8 |
+
|
9 |
+
from models.blip_override.blip import init_tokenizer
|
10 |
+
|
11 |
+
|
12 |
+
class StoryDataset(Dataset):
|
13 |
+
"""
|
14 |
+
A custom subset class for the LRW (includes train, val, test) subset
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, subset, args):
|
18 |
+
super(StoryDataset, self).__init__()
|
19 |
+
self.args = args
|
20 |
+
|
21 |
+
self.h5_file = args.get(args.dataset).hdf5_file
|
22 |
+
self.subset = subset
|
23 |
+
|
24 |
+
self.augment = transforms.Compose([
|
25 |
+
transforms.ToPILImage(),
|
26 |
+
transforms.Resize(512),
|
27 |
+
transforms.RandomCrop(512) if self.subset == 'train' else transforms.CenterCrop(512),
|
28 |
+
transforms.ToTensor(),
|
29 |
+
transforms.Normalize([0.5], [0.5])
|
30 |
+
]) if self.subset in ['train', 'val'] else transforms.Compose([
|
31 |
+
transforms.ToPILImage(),
|
32 |
+
transforms.Resize(64),
|
33 |
+
transforms.CenterCrop(64)
|
34 |
+
])
|
35 |
+
|
36 |
+
self.dataset = args.dataset
|
37 |
+
self.max_length = args.get(args.dataset).max_length
|
38 |
+
self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
|
39 |
+
self.blip_tokenizer = init_tokenizer()
|
40 |
+
|
41 |
+
self.blip_image_processor = transforms.Compose([
|
42 |
+
transforms.ToPILImage(),
|
43 |
+
transforms.Resize(224),
|
44 |
+
transforms.RandomCrop(224) if self.subset == 'train' else transforms.CenterCrop(224),
|
45 |
+
transforms.ToTensor(),
|
46 |
+
transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
|
47 |
+
])
|
48 |
+
|
49 |
+
def open_h5(self):
|
50 |
+
h5 = h5py.File(self.h5_file, "r")
|
51 |
+
self.h5 = h5[self.subset]
|
52 |
+
|
53 |
+
def __getitem__(self, index):
|
54 |
+
if not hasattr(self, 'h5'):
|
55 |
+
self.open_h5()
|
56 |
+
|
57 |
+
images = list()
|
58 |
+
for i in range(5):
|
59 |
+
im = self.h5['image{}'.format(i)][index]
|
60 |
+
im = cv2.imdecode(im, cv2.IMREAD_COLOR)
|
61 |
+
images.append(im)
|
62 |
+
|
63 |
+
source_images = torch.stack([self.blip_image_processor(im) for im in images])
|
64 |
+
images = images[1:] if self.args.task == 'continuation' else images
|
65 |
+
images = [self.augment(im) for im in images]
|
66 |
+
images = torch.stack(images) if self.subset in ['train', 'val'] \
|
67 |
+
else torch.from_numpy(np.array([np.array(im) for im in images])).permute(0, 3, 1, 2)
|
68 |
+
|
69 |
+
texts = self.h5['dii'][index].decode('utf-8').split('|')
|
70 |
+
|
71 |
+
# tokenize caption using default tokenizer
|
72 |
+
tokenized = self.clip_tokenizer(
|
73 |
+
texts[1:] if self.args.task == 'continuation' else texts,
|
74 |
+
padding="max_length",
|
75 |
+
max_length=self.max_length,
|
76 |
+
truncation=False,
|
77 |
+
return_tensors="pt",
|
78 |
+
)
|
79 |
+
captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
|
80 |
+
|
81 |
+
tokenized = self.blip_tokenizer(
|
82 |
+
texts,
|
83 |
+
padding="max_length",
|
84 |
+
max_length=self.max_length,
|
85 |
+
truncation=False,
|
86 |
+
return_tensors="pt",
|
87 |
+
)
|
88 |
+
source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
|
89 |
+
return images, captions, attention_mask, source_images, source_caption, source_attention_mask
|
90 |
+
|
91 |
+
def __len__(self):
|
92 |
+
if not hasattr(self, 'h5'):
|
93 |
+
self.open_h5()
|
94 |
+
return len(self.h5['dii'])
|
datasets/vistsis.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import h5py
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from torchvision import transforms
|
7 |
+
from transformers import CLIPTokenizer
|
8 |
+
|
9 |
+
from models.blip_override.blip import init_tokenizer
|
10 |
+
|
11 |
+
|
12 |
+
class StoryDataset(Dataset):
|
13 |
+
"""
|
14 |
+
A custom subset class for the LRW (includes train, val, test) subset
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, subset, args):
|
18 |
+
super(StoryDataset, self).__init__()
|
19 |
+
self.args = args
|
20 |
+
|
21 |
+
self.h5_file = args.get(args.dataset).hdf5_file
|
22 |
+
self.subset = subset
|
23 |
+
|
24 |
+
self.augment = transforms.Compose([
|
25 |
+
transforms.ToPILImage(),
|
26 |
+
transforms.Resize(512),
|
27 |
+
transforms.RandomCrop(512) if self.subset == 'train' else transforms.CenterCrop(512),
|
28 |
+
transforms.ToTensor(),
|
29 |
+
transforms.Normalize([0.5], [0.5])
|
30 |
+
]) if self.subset in ['train', 'val'] else transforms.Compose([
|
31 |
+
transforms.ToPILImage(),
|
32 |
+
transforms.Resize(64),
|
33 |
+
transforms.CenterCrop(64)
|
34 |
+
])
|
35 |
+
|
36 |
+
self.dataset = args.dataset
|
37 |
+
self.max_length = args.get(args.dataset).max_length
|
38 |
+
self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
|
39 |
+
self.blip_tokenizer = init_tokenizer()
|
40 |
+
|
41 |
+
self.blip_image_processor = transforms.Compose([
|
42 |
+
transforms.ToPILImage(),
|
43 |
+
transforms.Resize(224),
|
44 |
+
transforms.RandomCrop(224) if self.subset == 'train' else transforms.CenterCrop(224),
|
45 |
+
transforms.ToTensor(),
|
46 |
+
transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
|
47 |
+
])
|
48 |
+
|
49 |
+
def open_h5(self):
|
50 |
+
h5 = h5py.File(self.h5_file, "r")
|
51 |
+
self.h5 = h5[self.subset]
|
52 |
+
|
53 |
+
def __getitem__(self, index):
|
54 |
+
if not hasattr(self, 'h5'):
|
55 |
+
self.open_h5()
|
56 |
+
|
57 |
+
images = list()
|
58 |
+
for i in range(5):
|
59 |
+
im = self.h5['image{}'.format(i)][index]
|
60 |
+
im = cv2.imdecode(im, cv2.IMREAD_COLOR)
|
61 |
+
images.append(im)
|
62 |
+
|
63 |
+
source_images = torch.stack([self.blip_image_processor(im) for im in images])
|
64 |
+
images = images[1:] if self.args.task == 'continuation' else images
|
65 |
+
images = [self.augment(im) for im in images]
|
66 |
+
images = torch.stack(images) if self.subset in ['train', 'val'] \
|
67 |
+
else torch.from_numpy(np.array([np.array(im) for im in images])).permute(0, 3, 1, 2)
|
68 |
+
|
69 |
+
texts = self.h5['sis'][index].decode('utf-8').split('|')
|
70 |
+
|
71 |
+
# tokenize caption using default tokenizer
|
72 |
+
tokenized = self.clip_tokenizer(
|
73 |
+
texts[1:] if self.args.task == 'continuation' else texts,
|
74 |
+
padding="max_length",
|
75 |
+
max_length=self.max_length,
|
76 |
+
truncation=False,
|
77 |
+
return_tensors="pt",
|
78 |
+
)
|
79 |
+
captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
|
80 |
+
|
81 |
+
tokenized = self.blip_tokenizer(
|
82 |
+
texts,
|
83 |
+
padding="max_length",
|
84 |
+
max_length=self.max_length,
|
85 |
+
truncation=False,
|
86 |
+
return_tensors="pt",
|
87 |
+
)
|
88 |
+
source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
|
89 |
+
return images, captions, attention_mask, source_images, source_caption, source_attention_mask
|
90 |
+
|
91 |
+
def __len__(self):
|
92 |
+
if not hasattr(self, 'h5'):
|
93 |
+
self.open_h5()
|
94 |
+
return len(self.h5['sis'])
|
environment.yml
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: story
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- defaults
|
6 |
+
dependencies:
|
7 |
+
- _libgcc_mutex=0.1=main
|
8 |
+
- _openmp_mutex=5.1=1_gnu
|
9 |
+
- blas=1.0=mkl
|
10 |
+
- brotlipy=0.7.0=py38h27cfd23_1003
|
11 |
+
- bzip2=1.0.8=h7b6447c_0
|
12 |
+
- ca-certificates=2023.01.10=h06a4308_0
|
13 |
+
- certifi=2022.12.7=py38h06a4308_0
|
14 |
+
- cffi=1.15.1=py38h5eee18b_3
|
15 |
+
- cryptography=39.0.1=py38h9ce1e76_0
|
16 |
+
- cuda-cudart=11.7.99=0
|
17 |
+
- cuda-cupti=11.7.101=0
|
18 |
+
- cuda-libraries=11.7.1=0
|
19 |
+
- cuda-nvrtc=11.7.99=0
|
20 |
+
- cuda-nvtx=11.7.91=0
|
21 |
+
- cuda-runtime=11.7.1=0
|
22 |
+
- ffmpeg=4.3=hf484d3e_0
|
23 |
+
- flit-core=3.8.0=py38h06a4308_0
|
24 |
+
- freetype=2.12.1=h4a9f257_0
|
25 |
+
- giflib=5.2.1=h5eee18b_3
|
26 |
+
- gmp=6.2.1=h295c915_3
|
27 |
+
- gnutls=3.6.15=he1e5248_0
|
28 |
+
- idna=3.4=py38h06a4308_0
|
29 |
+
- intel-openmp=2021.4.0=h06a4308_3561
|
30 |
+
- jpeg=9e=h5eee18b_1
|
31 |
+
- lame=3.100=h7b6447c_0
|
32 |
+
- lcms2=2.12=h3be6417_0
|
33 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
34 |
+
- lerc=3.0=h295c915_0
|
35 |
+
- libcublas=11.10.3.66=0
|
36 |
+
- libcufft=10.7.2.124=h4fbf590_0
|
37 |
+
- libcufile=1.6.0.25=0
|
38 |
+
- libcurand=10.3.2.56=0
|
39 |
+
- libcusolver=11.4.0.1=0
|
40 |
+
- libcusparse=11.7.4.91=0
|
41 |
+
- libdeflate=1.17=h5eee18b_0
|
42 |
+
- libffi=3.4.2=h6a678d5_6
|
43 |
+
- libgcc-ng=11.2.0=h1234567_1
|
44 |
+
- libgomp=11.2.0=h1234567_1
|
45 |
+
- libiconv=1.16=h7f8727e_2
|
46 |
+
- libidn2=2.3.2=h7f8727e_0
|
47 |
+
- libnpp=11.7.4.75=0
|
48 |
+
- libnvjpeg=11.8.0.2=0
|
49 |
+
- libpng=1.6.39=h5eee18b_0
|
50 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
51 |
+
- libtasn1=4.19.0=h5eee18b_0
|
52 |
+
- libtiff=4.5.0=h6a678d5_2
|
53 |
+
- libunistring=0.9.10=h27cfd23_0
|
54 |
+
- libwebp=1.2.4=h11a3e52_1
|
55 |
+
- libwebp-base=1.2.4=h5eee18b_1
|
56 |
+
- lz4-c=1.9.4=h6a678d5_0
|
57 |
+
- mkl=2021.4.0=h06a4308_640
|
58 |
+
- mkl-service=2.4.0=py38h7f8727e_0
|
59 |
+
- mkl_fft=1.3.1=py38hd3c417c_0
|
60 |
+
- mkl_random=1.2.2=py38h51133e4_0
|
61 |
+
- ncurses=6.4=h6a678d5_0
|
62 |
+
- nettle=3.7.3=hbbd107a_1
|
63 |
+
- numpy-base=1.23.5=py38h31eccc5_0
|
64 |
+
- openh264=2.1.1=h4ff587b_0
|
65 |
+
- openssl=1.1.1t=h7f8727e_0
|
66 |
+
- pip=23.0.1=py38h06a4308_0
|
67 |
+
- pycparser=2.21=pyhd3eb1b0_0
|
68 |
+
- pyopenssl=23.0.0=py38h06a4308_0
|
69 |
+
- pysocks=1.7.1=py38h06a4308_0
|
70 |
+
- python=3.8.16=h7a1cb2a_3
|
71 |
+
- pytorch=1.13.1=py3.8_cuda11.7_cudnn8.5.0_0
|
72 |
+
- pytorch-cuda=11.7=h778d358_3
|
73 |
+
- pytorch-mutex=1.0=cuda
|
74 |
+
- readline=8.2=h5eee18b_0
|
75 |
+
- six=1.16.0=pyhd3eb1b0_1
|
76 |
+
- sqlite=3.41.1=h5eee18b_0
|
77 |
+
- tk=8.6.12=h1ccaba5_0
|
78 |
+
- typing_extensions=4.4.0=py38h06a4308_0
|
79 |
+
- urllib3=1.26.15=py38h06a4308_0
|
80 |
+
- wheel=0.38.4=py38h06a4308_0
|
81 |
+
- xz=5.2.10=h5eee18b_1
|
82 |
+
- zlib=1.2.13=h5eee18b_0
|
83 |
+
- zstd=1.5.4=hc292b87_0
|
84 |
+
- pip:
|
85 |
+
- absl-py==1.4.0
|
86 |
+
- accelerate==0.17.1
|
87 |
+
- aiofiles==23.1.0
|
88 |
+
- aiohttp==3.8.4
|
89 |
+
- aiosignal==1.3.1
|
90 |
+
- altair==4.2.2
|
91 |
+
- antlr4-python3-runtime==4.9.3
|
92 |
+
- anyio==3.6.2
|
93 |
+
- appdirs==1.4.4
|
94 |
+
- argon2-cffi==21.3.0
|
95 |
+
- argon2-cffi-bindings==21.2.0
|
96 |
+
- arrow==1.2.3
|
97 |
+
- asttokens==2.2.1
|
98 |
+
- async-timeout==4.0.2
|
99 |
+
- attrs==22.2.0
|
100 |
+
- backcall==0.2.0
|
101 |
+
- beautifulsoup4==4.11.2
|
102 |
+
- bleach==6.0.0
|
103 |
+
- cachetools==5.3.0
|
104 |
+
- chardet==5.1.0
|
105 |
+
- charset-normalizer==3.1.0
|
106 |
+
- click==8.1.3
|
107 |
+
- comm==0.1.2
|
108 |
+
- contourpy==1.0.7
|
109 |
+
- cycler==0.11.0
|
110 |
+
- debugpy==1.6.6
|
111 |
+
- decorator==5.1.1
|
112 |
+
- defusedxml==0.7.1
|
113 |
+
- diffusers==0.9.0
|
114 |
+
- docker-pycreds==0.4.0
|
115 |
+
- entrypoints==0.4
|
116 |
+
- executing==1.2.0
|
117 |
+
- fastapi==0.95.0
|
118 |
+
- fastjsonschema==2.16.3
|
119 |
+
- ffmpy==0.3.0
|
120 |
+
- filelock==3.10.0
|
121 |
+
- fire==0.5.0
|
122 |
+
- flatbuffers==23.3.3
|
123 |
+
- fonttools==4.39.3
|
124 |
+
- fqdn==1.5.1
|
125 |
+
- frozenlist==1.3.3
|
126 |
+
- fsspec==2023.3.0
|
127 |
+
- ftfy==6.1.1
|
128 |
+
- gitdb==4.0.10
|
129 |
+
- gitpython==3.1.31
|
130 |
+
- google-auth==2.16.2
|
131 |
+
- google-auth-oauthlib==0.4.6
|
132 |
+
- gradio==3.24.1
|
133 |
+
- gradio-client==0.0.5
|
134 |
+
- grpcio==1.51.3
|
135 |
+
- h11==0.14.0
|
136 |
+
- h5py==3.8.0
|
137 |
+
- httpcore==0.16.3
|
138 |
+
- httpx==0.23.3
|
139 |
+
- huggingface-hub==0.13.2
|
140 |
+
- hydra-core==1.3.2
|
141 |
+
- importlib-metadata==6.1.0
|
142 |
+
- importlib-resources==5.12.0
|
143 |
+
- ipykernel==6.21.3
|
144 |
+
- ipython==8.11.0
|
145 |
+
- ipython-genutils==0.2.0
|
146 |
+
- ipywidgets==8.0.4
|
147 |
+
- isoduration==20.11.0
|
148 |
+
- jedi==0.18.2
|
149 |
+
- jinja2==3.1.2
|
150 |
+
- jsonpointer==2.3
|
151 |
+
- jsonschema==4.17.3
|
152 |
+
- jupyter==1.0.0
|
153 |
+
- jupyter-client==8.0.3
|
154 |
+
- jupyter-console==6.6.3
|
155 |
+
- jupyter-core==5.3.0
|
156 |
+
- jupyter-events==0.6.3
|
157 |
+
- jupyter-server==2.5.0
|
158 |
+
- jupyter-server-terminals==0.4.4
|
159 |
+
- jupyterlab-pygments==0.2.2
|
160 |
+
- jupyterlab-widgets==3.0.5
|
161 |
+
- kiwisolver==1.4.4
|
162 |
+
- lightning-bolts==0.5.0
|
163 |
+
- linkify-it-py==2.0.0
|
164 |
+
- lora-diffusion==0.1.7
|
165 |
+
- markdown==3.4.1
|
166 |
+
- markdown-it-py==2.2.0
|
167 |
+
- markupsafe==2.1.2
|
168 |
+
- matplotlib==3.7.1
|
169 |
+
- matplotlib-inline==0.1.6
|
170 |
+
- mdit-py-plugins==0.3.3
|
171 |
+
- mdurl==0.1.2
|
172 |
+
- mediapipe==0.9.1.0
|
173 |
+
- mistune==2.0.5
|
174 |
+
- multidict==6.0.4
|
175 |
+
- nbclassic==0.5.3
|
176 |
+
- nbclient==0.7.2
|
177 |
+
- nbconvert==7.2.10
|
178 |
+
- nbformat==5.7.3
|
179 |
+
- nest-asyncio==1.5.6
|
180 |
+
- notebook==6.5.3
|
181 |
+
- notebook-shim==0.2.2
|
182 |
+
- numpy==1.24.2
|
183 |
+
- oauthlib==3.2.2
|
184 |
+
- omegaconf==2.3.0
|
185 |
+
- opencv-contrib-python==4.7.0.72
|
186 |
+
- opencv-python==4.7.0.72
|
187 |
+
- orjson==3.8.9
|
188 |
+
- packaging==23.0
|
189 |
+
- pandas==1.5.3
|
190 |
+
- pandocfilters==1.5.0
|
191 |
+
- parso==0.8.3
|
192 |
+
- pathtools==0.1.2
|
193 |
+
- pexpect==4.8.0
|
194 |
+
- pickleshare==0.7.5
|
195 |
+
- pillow==9.4.0
|
196 |
+
- pkgutil-resolve-name==1.3.10
|
197 |
+
- platformdirs==3.1.1
|
198 |
+
- prometheus-client==0.16.0
|
199 |
+
- prompt-toolkit==3.0.38
|
200 |
+
- protobuf==3.20.1
|
201 |
+
- psutil==5.9.4
|
202 |
+
- ptyprocess==0.7.0
|
203 |
+
- pure-eval==0.2.2
|
204 |
+
- pyasn1==0.4.8
|
205 |
+
- pyasn1-modules==0.2.8
|
206 |
+
- pydantic==1.10.7
|
207 |
+
- pydeprecate==0.3.2
|
208 |
+
- pydub==0.25.1
|
209 |
+
- pygments==2.14.0
|
210 |
+
- pyparsing==3.0.9
|
211 |
+
- pyrsistent==0.19.3
|
212 |
+
- python-dateutil==2.8.2
|
213 |
+
- python-json-logger==2.0.7
|
214 |
+
- python-multipart==0.0.6
|
215 |
+
- pytorch-lightning==1.6.5
|
216 |
+
- pytz==2023.3
|
217 |
+
- pyyaml==6.0
|
218 |
+
- pyzmq==25.0.1
|
219 |
+
- qtconsole==5.4.1
|
220 |
+
- qtpy==2.3.0
|
221 |
+
- regex==2022.10.31
|
222 |
+
- requests==2.28.2
|
223 |
+
- requests-oauthlib==1.3.1
|
224 |
+
- rfc3339-validator==0.1.4
|
225 |
+
- rfc3986==1.5.0
|
226 |
+
- rfc3986-validator==0.1.1
|
227 |
+
- rsa==4.9
|
228 |
+
- safetensors==0.3.0
|
229 |
+
- scipy==1.10.1
|
230 |
+
- semantic-version==2.10.0
|
231 |
+
- send2trash==1.8.0
|
232 |
+
- sentry-sdk==1.17.0
|
233 |
+
- setproctitle==1.3.2
|
234 |
+
- setuptools==59.5.0
|
235 |
+
- smmap==5.0.0
|
236 |
+
- sniffio==1.3.0
|
237 |
+
- soupsieve==2.4
|
238 |
+
- stack-data==0.6.2
|
239 |
+
- starlette==0.26.1
|
240 |
+
- tensorboard==2.12.0
|
241 |
+
- tensorboard-data-server==0.7.0
|
242 |
+
- tensorboard-plugin-wit==1.8.1
|
243 |
+
- termcolor==2.2.0
|
244 |
+
- terminado==0.17.1
|
245 |
+
- timm==0.6.12
|
246 |
+
- tinycss2==1.2.1
|
247 |
+
- tokenizers==0.13.2
|
248 |
+
- toolz==0.12.0
|
249 |
+
- torch==1.9.0
|
250 |
+
- torchaudio==0.9.0
|
251 |
+
- torchmetrics==0.11.4
|
252 |
+
- torchvision==0.10.0+cu111
|
253 |
+
- tornado==6.2
|
254 |
+
- tqdm==4.65.0
|
255 |
+
- traitlets==5.9.0
|
256 |
+
- transformers==4.28.1
|
257 |
+
- typing-extensions==4.5.0
|
258 |
+
- uc-micro-py==1.0.1
|
259 |
+
- uri-template==1.2.0
|
260 |
+
- uvicorn==0.21.1
|
261 |
+
- wandb==0.14.0
|
262 |
+
- wcwidth==0.2.6
|
263 |
+
- webcolors==1.12
|
264 |
+
- webencodings==0.5.1
|
265 |
+
- websocket-client==1.5.1
|
266 |
+
- websockets==11.0
|
267 |
+
- werkzeug==2.2.3
|
268 |
+
- widgetsnbextension==4.0.5
|
269 |
+
- yarl==1.8.2
|
270 |
+
- zipp==3.15.0
|
271 |
+
prefix: /root/anaconda3/envs/story
|
fid_utils.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy import linalg
|
3 |
+
|
4 |
+
|
5 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
6 |
+
mu1 = np.atleast_1d(mu1)
|
7 |
+
mu2 = np.atleast_1d(mu2)
|
8 |
+
|
9 |
+
sigma1 = np.atleast_2d(sigma1)
|
10 |
+
sigma2 = np.atleast_2d(sigma2)
|
11 |
+
|
12 |
+
assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
|
13 |
+
assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'
|
14 |
+
|
15 |
+
diff = mu1 - mu2
|
16 |
+
|
17 |
+
# Product might be almost singular
|
18 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
19 |
+
if not np.isfinite(covmean).all():
|
20 |
+
print('fid calculation produces singular product; adding %s to diagonal of cov estimates' % eps)
|
21 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
22 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
23 |
+
|
24 |
+
# Numerical error might give slight imaginary component
|
25 |
+
if np.iscomplexobj(covmean):
|
26 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
27 |
+
m = np.max(np.abs(covmean.imag))
|
28 |
+
raise ValueError('Imaginary component {}'.format(m))
|
29 |
+
covmean = covmean.real
|
30 |
+
|
31 |
+
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
|
32 |
+
|
33 |
+
|
34 |
+
def calculate_fid_given_features(feature1, feature2):
|
35 |
+
mu1 = np.mean(feature1, axis=0)
|
36 |
+
sigma1 = np.cov(feature1, rowvar=False)
|
37 |
+
mu2 = np.mean(feature2, axis=0)
|
38 |
+
sigma2 = np.cov(feature2, rowvar=False)
|
39 |
+
fid_value = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
|
40 |
+
|
41 |
+
return fid_value
|
main.py
ADDED
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import os
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import hydra
|
6 |
+
import numpy as np
|
7 |
+
import pytorch_lightning as pl
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.utils.checkpoint
|
11 |
+
from PIL import Image
|
12 |
+
from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler
|
13 |
+
from omegaconf import DictConfig
|
14 |
+
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
|
15 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
|
16 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
17 |
+
from pytorch_lightning.strategies import DDPStrategy
|
18 |
+
from torch import nn
|
19 |
+
from torch.utils.data import DataLoader
|
20 |
+
from torchvision import transforms
|
21 |
+
from transformers import CLIPTokenizer, CLIPTextModel
|
22 |
+
|
23 |
+
from fid_utils import calculate_fid_given_features
|
24 |
+
from lora_diffusion import monkeypatch_or_replace_lora, tune_lora_scale
|
25 |
+
|
26 |
+
from models.blip_override.blip import blip_feature_extractor, init_tokenizer
|
27 |
+
from models.diffusers_override.unet_2d_condition import UNet2DConditionModel
|
28 |
+
from models.inception import InceptionV3
|
29 |
+
unet_target_replace_module = {"CrossAttention", "Attention", "GEGLU"}
|
30 |
+
#!/usr/bin/env python3
|
31 |
+
from transformers import CLIPProcessor
|
32 |
+
import transformers
|
33 |
+
from PIL import Image
|
34 |
+
import PIL.Image
|
35 |
+
import numpy as np
|
36 |
+
import torchvision.transforms as tvtrans
|
37 |
+
import requests
|
38 |
+
from io import BytesIO
|
39 |
+
|
40 |
+
class LightningDataset(pl.LightningDataModule):
|
41 |
+
def __init__(self, args: DictConfig):
|
42 |
+
super(LightningDataset, self).__init__()
|
43 |
+
self.kwargs = {"num_workers": args.num_workers, "persistent_workers": True if args.num_workers > 0 else False,
|
44 |
+
"pin_memory": True}
|
45 |
+
self.args = args
|
46 |
+
|
47 |
+
def setup(self, stage="fit"):
|
48 |
+
if self.args.dataset == "pororo":
|
49 |
+
import datasets.pororo as data
|
50 |
+
elif self.args.dataset == 'flintstones':
|
51 |
+
import datasets.flintstones as data
|
52 |
+
elif self.args.dataset == 'vistsis':
|
53 |
+
import datasets.vistsis as data
|
54 |
+
elif self.args.dataset == 'vistdii':
|
55 |
+
import datasets.vistdii as data
|
56 |
+
else:
|
57 |
+
raise ValueError("Unknown dataset: {}".format(self.args.dataset))
|
58 |
+
if stage == "fit":
|
59 |
+
self.train_data = data.StoryDataset("train", self.args)
|
60 |
+
self.val_data = data.StoryDataset("val", self.args)
|
61 |
+
if stage == "test":
|
62 |
+
self.test_data = data.StoryDataset("test", self.args)
|
63 |
+
|
64 |
+
def train_dataloader(self):
|
65 |
+
if not hasattr(self, 'trainloader'):
|
66 |
+
self.trainloader = DataLoader(self.train_data, batch_size=self.args.batch_size, shuffle=True, **self.kwargs)
|
67 |
+
return self.trainloader
|
68 |
+
|
69 |
+
def val_dataloader(self):
|
70 |
+
return DataLoader(self.val_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)
|
71 |
+
|
72 |
+
def test_dataloader(self):
|
73 |
+
return DataLoader(self.test_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)
|
74 |
+
|
75 |
+
def predict_dataloader(self):
|
76 |
+
return DataLoader(self.test_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)
|
77 |
+
|
78 |
+
def get_length_of_train_dataloader(self):
|
79 |
+
if not hasattr(self, 'trainloader'):
|
80 |
+
self.trainloader = DataLoader(self.train_data, batch_size=self.args.batch_size, shuffle=True, **self.kwargs)
|
81 |
+
return len(self.trainloader)
|
82 |
+
|
83 |
+
|
84 |
+
class ARLDM(pl.LightningModule):
|
85 |
+
def __init__(self, args: DictConfig, steps_per_epoch=1):
|
86 |
+
super(ARLDM, self).__init__()
|
87 |
+
self.args = args
|
88 |
+
self.steps_per_epoch = steps_per_epoch
|
89 |
+
"""
|
90 |
+
Configurations
|
91 |
+
"""
|
92 |
+
self.task = args.task
|
93 |
+
|
94 |
+
if args.mode == 'sample':
|
95 |
+
if args.scheduler == "pndm":
|
96 |
+
self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
97 |
+
skip_prk_steps=True)
|
98 |
+
elif args.scheduler == "ddim":
|
99 |
+
self.scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
100 |
+
clip_sample=False, set_alpha_to_one=True)
|
101 |
+
else:
|
102 |
+
raise ValueError("Scheduler not supported")
|
103 |
+
self.fid_augment = transforms.Compose([
|
104 |
+
transforms.Resize([64, 64]),
|
105 |
+
transforms.ToTensor(),
|
106 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
107 |
+
])
|
108 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
|
109 |
+
self.inception = InceptionV3([block_idx])
|
110 |
+
|
111 |
+
self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
|
112 |
+
##############################
|
113 |
+
#self.clip_tokenizer.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/tokenizer')
|
114 |
+
self.blip_tokenizer = init_tokenizer()
|
115 |
+
self.blip_image_processor = transforms.Compose([
|
116 |
+
transforms.Resize([224, 224]),
|
117 |
+
transforms.ToTensor(),
|
118 |
+
transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
|
119 |
+
])
|
120 |
+
self.max_length = args.get(args.dataset).max_length
|
121 |
+
|
122 |
+
blip_image_null_token = self.blip_image_processor(
|
123 |
+
Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))).unsqueeze(0).float()
|
124 |
+
clip_text_null_token = self.clip_tokenizer([""], padding="max_length", max_length=self.max_length,
|
125 |
+
return_tensors="pt").input_ids
|
126 |
+
blip_text_null_token = self.blip_tokenizer([""], padding="max_length", max_length=self.max_length,
|
127 |
+
return_tensors="pt").input_ids
|
128 |
+
|
129 |
+
self.register_buffer('clip_text_null_token', clip_text_null_token)
|
130 |
+
self.register_buffer('blip_text_null_token', blip_text_null_token)
|
131 |
+
self.register_buffer('blip_image_null_token', blip_image_null_token)
|
132 |
+
|
133 |
+
self.text_encoder = CLIPTextModel.from_pretrained('runwayml/stable-diffusion-v1-5',
|
134 |
+
subfolder="text_encoder")
|
135 |
+
############################################
|
136 |
+
#self.text_encoder.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/text_encoder')
|
137 |
+
self.text_encoder.resize_token_embeddings(args.get(args.dataset).clip_embedding_tokens)
|
138 |
+
# resize_position_embeddings
|
139 |
+
old_embeddings = self.text_encoder.text_model.embeddings.position_embedding
|
140 |
+
new_embeddings = self.text_encoder._get_resized_embeddings(old_embeddings, self.max_length)
|
141 |
+
self.text_encoder.text_model.embeddings.position_embedding = new_embeddings
|
142 |
+
self.text_encoder.config.max_position_embeddings = self.max_length
|
143 |
+
self.text_encoder.max_position_embeddings = self.max_length
|
144 |
+
self.text_encoder.text_model.embeddings.position_ids = torch.arange(self.max_length).expand((1, -1))
|
145 |
+
|
146 |
+
self.modal_type_embeddings = nn.Embedding(2, 768)
|
147 |
+
self.time_embeddings = nn.Embedding(5, 768)
|
148 |
+
self.mm_encoder = blip_feature_extractor(
|
149 |
+
# pretrained='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth',
|
150 |
+
pretrained='/root/lihui/StoryVisualization/save_pretrained/model_large.pth',
|
151 |
+
image_size=224, vit='large')#, local_files_only=True)
|
152 |
+
self.mm_encoder.text_encoder.resize_token_embeddings(args.get(args.dataset).blip_embedding_tokens)
|
153 |
+
|
154 |
+
self.vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="vae")
|
155 |
+
self.unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="unet")
|
156 |
+
|
157 |
+
self.noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
158 |
+
num_train_timesteps=1000)
|
159 |
+
# monkeypatch_or_replace_lora(
|
160 |
+
# self.unet,
|
161 |
+
# torch.load("lora/example_loras/analog_svd_rank4.safetensors"),
|
162 |
+
# r=4,
|
163 |
+
# target_replace_module=unet_target_replace_module,
|
164 |
+
# )
|
165 |
+
#
|
166 |
+
# tune_lora_scale(self.unet, 1.00)
|
167 |
+
#tune_lora_scale(self.text_encoder, 1.00)
|
168 |
+
|
169 |
+
# torch.manual_seed(0)
|
170 |
+
###################################
|
171 |
+
#self.vae.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/vae')
|
172 |
+
#self.unet.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/unet')
|
173 |
+
|
174 |
+
# Freeze vae and unet
|
175 |
+
self.freeze_params(self.vae.parameters())
|
176 |
+
if args.freeze_resnet:
|
177 |
+
self.freeze_params([p for n, p in self.unet.named_parameters() if "attentions" not in n])
|
178 |
+
|
179 |
+
if args.freeze_blip and hasattr(self, "mm_encoder"):
|
180 |
+
self.freeze_params(self.mm_encoder.parameters())
|
181 |
+
self.unfreeze_params(self.mm_encoder.text_encoder.embeddings.word_embeddings.parameters())
|
182 |
+
|
183 |
+
if args.freeze_clip and hasattr(self, "text_encoder"):
|
184 |
+
self.freeze_params(self.text_encoder.parameters())
|
185 |
+
self.unfreeze_params(self.text_encoder.text_model.embeddings.token_embedding.parameters())
|
186 |
+
|
187 |
+
@staticmethod
|
188 |
+
def freeze_params(params):
|
189 |
+
for param in params:
|
190 |
+
param.requires_grad = False
|
191 |
+
|
192 |
+
@staticmethod
|
193 |
+
def unfreeze_params(params):
|
194 |
+
for param in params:
|
195 |
+
param.requires_grad = True
|
196 |
+
|
197 |
+
def configure_optimizers(self):
|
198 |
+
optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.init_lr, weight_decay=1e-4) # optim_bits=8
|
199 |
+
scheduler = LinearWarmupCosineAnnealingLR(optimizer,
|
200 |
+
warmup_epochs=self.args.warmup_epochs * self.steps_per_epoch,
|
201 |
+
max_epochs=self.args.max_epochs * self.steps_per_epoch)
|
202 |
+
optim_dict = {
|
203 |
+
'optimizer': optimizer,
|
204 |
+
'lr_scheduler': {
|
205 |
+
'scheduler': scheduler, # The LR scheduler instance (required)
|
206 |
+
'interval': 'step', # The unit of the scheduler's step size
|
207 |
+
}
|
208 |
+
}
|
209 |
+
return optim_dict
|
210 |
+
|
211 |
+
def forward(self, batch):
|
212 |
+
if self.args.freeze_clip and hasattr(self, "text_encoder"):
|
213 |
+
self.text_encoder.eval()
|
214 |
+
if self.args.freeze_blip and hasattr(self, "mm_encoder"):
|
215 |
+
self.mm_encoder.eval()
|
216 |
+
images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_images = batch
|
217 |
+
B, V, S = captions.shape
|
218 |
+
src_V = V + 1 if self.task == 'continuation' else V
|
219 |
+
images = torch.flatten(images, 0, 1)
|
220 |
+
captions = torch.flatten(captions, 0, 1)
|
221 |
+
attention_mask = torch.flatten(attention_mask, 0, 1)
|
222 |
+
source_images = torch.flatten(source_images, 0, 1)
|
223 |
+
source_caption = torch.flatten(source_caption, 0, 1)
|
224 |
+
source_attention_mask = torch.flatten(source_attention_mask, 0, 1)
|
225 |
+
# 1 is not masked, 0 is maske
|
226 |
+
|
227 |
+
classifier_free_idx = np.random.rand(B * V) < 0.1
|
228 |
+
|
229 |
+
caption_embeddings = self.text_encoder(captions, attention_mask).last_hidden_state # B * V, S, D
|
230 |
+
source_embeddings = self.mm_encoder(source_images, source_caption, source_attention_mask,
|
231 |
+
mode='multimodal').reshape(B, src_V * S, -1)
|
232 |
+
source_embeddings = source_embeddings.repeat_interleave(V, dim=0)
|
233 |
+
caption_embeddings[classifier_free_idx] = \
|
234 |
+
self.text_encoder(self.clip_text_null_token).last_hidden_state[0]
|
235 |
+
source_embeddings[classifier_free_idx] = \
|
236 |
+
self.mm_encoder(self.blip_image_null_token, self.blip_text_null_token, attention_mask=None,
|
237 |
+
mode='multimodal')[0].repeat(src_V, 1)
|
238 |
+
caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
|
239 |
+
source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
|
240 |
+
source_embeddings += self.time_embeddings(
|
241 |
+
torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
|
242 |
+
encoder_hidden_states = torch.cat([caption_embeddings, source_embeddings], dim=1)
|
243 |
+
|
244 |
+
attention_mask = torch.cat(
|
245 |
+
[attention_mask, source_attention_mask.reshape(B, src_V * S).repeat_interleave(V, dim=0)], dim=1)
|
246 |
+
attention_mask = ~(attention_mask.bool()) # B * V, (src_V + 1) * S
|
247 |
+
attention_mask[classifier_free_idx] = False
|
248 |
+
|
249 |
+
# B, V, V, S
|
250 |
+
square_mask = torch.triu(torch.ones((V, V), device=self.device)).bool()
|
251 |
+
square_mask = square_mask.unsqueeze(0).unsqueeze(-1).expand(B, V, V, S)
|
252 |
+
square_mask = square_mask.reshape(B * V, V * S)
|
253 |
+
attention_mask[:, -V * S:] = torch.logical_or(square_mask, attention_mask[:, -V * S:])
|
254 |
+
|
255 |
+
latents = self.vae.encode(images).latent_dist.sample()
|
256 |
+
latents = latents * 0.18215
|
257 |
+
|
258 |
+
noise = torch.randn(latents.shape, device=self.device)
|
259 |
+
bsz = latents.shape[0]
|
260 |
+
timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=self.device).long()
|
261 |
+
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
|
262 |
+
|
263 |
+
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, attention_mask).sample
|
264 |
+
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
|
265 |
+
return loss
|
266 |
+
|
267 |
+
def sample(self, batch):
|
268 |
+
original_images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_test_images = batch
|
269 |
+
B, V, S = captions.shape
|
270 |
+
src_V = V + 1 if self.task == 'continuation' else V
|
271 |
+
original_images = torch.flatten(original_images, 0, 1)
|
272 |
+
captions = torch.flatten(captions, 0, 1)
|
273 |
+
attention_mask = torch.flatten(attention_mask, 0, 1)
|
274 |
+
source_images = torch.flatten(source_images, 0, 1)
|
275 |
+
source_caption = torch.flatten(source_caption, 0, 1)
|
276 |
+
source_attention_mask = torch.flatten(source_attention_mask, 0, 1)
|
277 |
+
|
278 |
+
caption_embeddings = self.text_encoder(captions, attention_mask).last_hidden_state # B * V, S, D
|
279 |
+
source_embeddings = self.mm_encoder(source_images, source_caption, source_attention_mask,
|
280 |
+
mode='multimodal').reshape(B, src_V * S, -1)
|
281 |
+
caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
|
282 |
+
source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
|
283 |
+
source_embeddings += self.time_embeddings(
|
284 |
+
torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
|
285 |
+
source_embeddings = source_embeddings.repeat_interleave(V, dim=0)
|
286 |
+
encoder_hidden_states = torch.cat([caption_embeddings, source_embeddings], dim=1)
|
287 |
+
|
288 |
+
attention_mask = torch.cat(
|
289 |
+
[attention_mask, source_attention_mask.reshape(B, src_V * S).repeat_interleave(V, dim=0)], dim=1)
|
290 |
+
attention_mask = ~(attention_mask.bool()) # B * V, (src_V + 1) * S
|
291 |
+
# B, V, V, S
|
292 |
+
square_mask = torch.triu(torch.ones((V, V), device=self.device)).bool()
|
293 |
+
square_mask = square_mask.unsqueeze(0).unsqueeze(-1).expand(B, V, V, S)
|
294 |
+
square_mask = square_mask.reshape(B * V, V * S)
|
295 |
+
attention_mask[:, -V * S:] = torch.logical_or(square_mask, attention_mask[:, -V * S:])
|
296 |
+
|
297 |
+
uncond_caption_embeddings = self.text_encoder(self.clip_text_null_token).last_hidden_state
|
298 |
+
uncond_source_embeddings = self.mm_encoder(self.blip_image_null_token, self.blip_text_null_token,
|
299 |
+
attention_mask=None, mode='multimodal').repeat(1, src_V, 1)
|
300 |
+
uncond_caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
|
301 |
+
uncond_source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
|
302 |
+
uncond_source_embeddings += self.time_embeddings(
|
303 |
+
torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
|
304 |
+
uncond_embeddings = torch.cat([uncond_caption_embeddings, uncond_source_embeddings], dim=1)
|
305 |
+
uncond_embeddings = uncond_embeddings.expand(B * V, -1, -1)
|
306 |
+
|
307 |
+
encoder_hidden_states = torch.cat([uncond_embeddings, encoder_hidden_states])
|
308 |
+
uncond_attention_mask = torch.zeros((B * V, (src_V + 1) * S), device=self.device).bool()
|
309 |
+
uncond_attention_mask[:, -V * S:] = square_mask
|
310 |
+
attention_mask = torch.cat([uncond_attention_mask, attention_mask], dim=0)
|
311 |
+
|
312 |
+
attention_mask = attention_mask.reshape(2, B, V, (src_V + 1) * S)
|
313 |
+
images = list()
|
314 |
+
for i in range(V):
|
315 |
+
encoder_hidden_states = encoder_hidden_states.reshape(2, B, V, (src_V + 1) * S, -1)
|
316 |
+
new_image = self.diffusion(encoder_hidden_states[:, :, i].reshape(2 * B, (src_V + 1) * S, -1),
|
317 |
+
attention_mask[:, :, i].reshape(2 * B, (src_V + 1) * S),
|
318 |
+
512, 512, self.args.num_inference_steps, self.args.guidance_scale, 0.0)
|
319 |
+
images += new_image
|
320 |
+
|
321 |
+
new_image = torch.stack([self.blip_image_processor(im) for im in new_image]).to(self.device)
|
322 |
+
|
323 |
+
new_embedding = self.mm_encoder(new_image, # B,C,H,W
|
324 |
+
source_caption.reshape(B, src_V, S)[:, i + src_V - V],
|
325 |
+
source_attention_mask.reshape(B, src_V, S)[:, i + src_V - V],
|
326 |
+
mode='multimodal') # B, S, D
|
327 |
+
new_embedding = new_embedding.repeat_interleave(V, dim=0)
|
328 |
+
new_embedding += self.modal_type_embeddings(torch.tensor(1, device=self.device))
|
329 |
+
new_embedding += self.time_embeddings(torch.tensor(i + src_V - V, device=self.device))
|
330 |
+
|
331 |
+
encoder_hidden_states = encoder_hidden_states[1].reshape(B * V, (src_V + 1) * S, -1)
|
332 |
+
encoder_hidden_states[:, (i + 1 + src_V - V) * S:(i + 2 + src_V - V) * S] = new_embedding
|
333 |
+
encoder_hidden_states = torch.cat([uncond_embeddings, encoder_hidden_states])
|
334 |
+
|
335 |
+
return original_images, images, texts, ori_test_images
|
336 |
+
|
337 |
+
|
338 |
+
def training_step(self, batch, batch_idx):
|
339 |
+
loss = self(batch)
|
340 |
+
self.log('loss/train_loss', loss, on_step=True, on_epoch=False, sync_dist=True, prog_bar=True)
|
341 |
+
return loss
|
342 |
+
|
343 |
+
def validation_step(self, batch, batch_idx):
|
344 |
+
loss = self(batch)
|
345 |
+
self.log('loss/val_loss', loss, on_step=False, on_epoch=True, sync_dist=True, prog_bar=True)
|
346 |
+
|
347 |
+
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
348 |
+
original_images, images, texts, ori_test_images = self.sample(batch)
|
349 |
+
if self.args.calculate_fid:
|
350 |
+
original_images = original_images.cpu().numpy().astype('uint8')
|
351 |
+
original_images = [Image.fromarray(im, 'RGB') for im in original_images]
|
352 |
+
|
353 |
+
# ori_test_images = torch.stack(ori_test_images).cpu().numpy().astype('uint8')
|
354 |
+
# ori_test_images = [Image.fromarray(im, 'RGB') for im in ori_test_images]
|
355 |
+
ori = self.inception_feature(original_images).cpu().numpy()
|
356 |
+
gen = self.inception_feature(images).cpu().numpy()
|
357 |
+
else:
|
358 |
+
ori = None
|
359 |
+
gen = None
|
360 |
+
|
361 |
+
return images, ori, gen, ori_test_images, texts
|
362 |
+
|
363 |
+
def diffusion(self, encoder_hidden_states, attention_mask, height, width, num_inference_steps, guidance_scale, eta):
|
364 |
+
latents = torch.randn((encoder_hidden_states.shape[0] // 2, self.unet.in_channels, height // 8, width // 8),
|
365 |
+
device=self.device)
|
366 |
+
|
367 |
+
# set timesteps
|
368 |
+
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
|
369 |
+
extra_set_kwargs = {}
|
370 |
+
if accepts_offset:
|
371 |
+
extra_set_kwargs["offset"] = 1
|
372 |
+
|
373 |
+
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
374 |
+
|
375 |
+
# if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
|
376 |
+
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
377 |
+
latents = latents * self.scheduler.sigmas[0]
|
378 |
+
|
379 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
380 |
+
extra_step_kwargs = {}
|
381 |
+
if accepts_eta:
|
382 |
+
extra_step_kwargs["eta"] = eta
|
383 |
+
|
384 |
+
for i, t in enumerate(self.scheduler.timesteps):
|
385 |
+
# expand the latents if we are doing classifier free guidance
|
386 |
+
latent_model_input = torch.cat([latents] * 2)
|
387 |
+
|
388 |
+
# noise_pred = self.unet(latent_model_input, t, encoder_hidden_states).sample
|
389 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states, attention_mask).sample
|
390 |
+
|
391 |
+
# perform guidance
|
392 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
393 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
394 |
+
|
395 |
+
# compute the previous noisy sample x_t -> x_t-1
|
396 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
397 |
+
|
398 |
+
# scale and decode the image latents with vae
|
399 |
+
latents = 1 / 0.18215 * latents
|
400 |
+
image = self.vae.decode(latents).sample
|
401 |
+
|
402 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
403 |
+
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
404 |
+
|
405 |
+
return self.numpy_to_pil(image)
|
406 |
+
|
407 |
+
@staticmethod
|
408 |
+
def numpy_to_pil(images):
|
409 |
+
"""
|
410 |
+
Convert a numpy image or a batch of images to a PIL image.
|
411 |
+
"""
|
412 |
+
if images.ndim == 3:
|
413 |
+
images = images[None, ...]
|
414 |
+
images = (images * 255).round().astype("uint8")
|
415 |
+
pil_images = [Image.fromarray(image, 'RGB') for image in images]
|
416 |
+
|
417 |
+
return pil_images
|
418 |
+
|
419 |
+
def inception_feature(self, images):
|
420 |
+
images = torch.stack([self.fid_augment(image) for image in images])
|
421 |
+
images = images.type(torch.FloatTensor).to(self.device)
|
422 |
+
images = (images + 1) / 2
|
423 |
+
images = F.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)
|
424 |
+
pred = self.inception(images)[0]
|
425 |
+
|
426 |
+
if pred.shape[2] != 1 or pred.shape[3] != 1:
|
427 |
+
pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))
|
428 |
+
return pred.reshape(-1, 2048)
|
429 |
+
|
430 |
+
|
431 |
+
def train(args: DictConfig) -> None:
|
432 |
+
dataloader = LightningDataset(args)
|
433 |
+
dataloader.setup('fit')
|
434 |
+
# dataloader.
|
435 |
+
model = ARLDM(args, steps_per_epoch=dataloader.get_length_of_train_dataloader())
|
436 |
+
|
437 |
+
logger = TensorBoardLogger(save_dir=os.path.join(args.ckpt_dir, args.run_name), name='log', default_hp_metric=False)
|
438 |
+
|
439 |
+
checkpoint_callback = ModelCheckpoint(
|
440 |
+
dirpath=os.path.join(args.ckpt_dir, args.run_name),
|
441 |
+
save_top_k=0,
|
442 |
+
save_last=True
|
443 |
+
)
|
444 |
+
|
445 |
+
lr_monitor = LearningRateMonitor(logging_interval='step')
|
446 |
+
|
447 |
+
callback_list = [lr_monitor, checkpoint_callback]
|
448 |
+
|
449 |
+
trainer = pl.Trainer(
|
450 |
+
accelerator='gpu',
|
451 |
+
devices=args.gpu_ids,
|
452 |
+
max_epochs=args.max_epochs,
|
453 |
+
benchmark=True,
|
454 |
+
logger=logger,
|
455 |
+
log_every_n_steps=1,
|
456 |
+
callbacks=callback_list,
|
457 |
+
strategy=DDPStrategy(find_unused_parameters=False)
|
458 |
+
)
|
459 |
+
trainer.fit(model, dataloader, ckpt_path=args.train_model_file)
|
460 |
+
|
461 |
+
|
462 |
+
def sample(args: DictConfig) -> None:
|
463 |
+
|
464 |
+
assert args.test_model_file is not None, "test_model_file cannot be None"
|
465 |
+
assert args.gpu_ids == 1 or len(args.gpu_ids) == 1, "Only one GPU is supported in test mode"
|
466 |
+
dataloader = LightningDataset(args)
|
467 |
+
dataloader.setup('test')
|
468 |
+
model = ARLDM.load_from_checkpoint(args.test_model_file, args=args, strict=False)
|
469 |
+
|
470 |
+
predictor = pl.Trainer(
|
471 |
+
accelerator='gpu',
|
472 |
+
devices=args.gpu_ids,
|
473 |
+
max_epochs=-1,
|
474 |
+
benchmark=True
|
475 |
+
)
|
476 |
+
predictions = predictor.predict(model, dataloader)
|
477 |
+
images = [elem for sublist in predictions for elem in sublist[0]]
|
478 |
+
ori_images = [elem for sublist in predictions for elem in sublist[3]]
|
479 |
+
ori_test_images = list()
|
480 |
+
if not os.path.exists(args.sample_output_dir):
|
481 |
+
try:
|
482 |
+
os.mkdir(args.sample_output_dir)
|
483 |
+
except:
|
484 |
+
pass
|
485 |
+
|
486 |
+
text_list = [elem for sublist in predictions for elem in sublist[4]]
|
487 |
+
################################
|
488 |
+
# print(f"index: {index}")
|
489 |
+
num_images = len(images)
|
490 |
+
num_groups = (num_images + 4) // 5 # 计算总共需要的组数
|
491 |
+
|
492 |
+
for g in range(num_groups):
|
493 |
+
print('Story {}:'.format(g + 1)) # 打印组号
|
494 |
+
start_index = g * 5 # 当前组的起始索引
|
495 |
+
end_index = min(start_index + 5, num_images) # 当前组的结束索引
|
496 |
+
for i in range(start_index, end_index):
|
497 |
+
print(text_list[i]) # 打印对应的文本
|
498 |
+
images[i].save(
|
499 |
+
os.path.join(args.sample_output_dir, 'group{:02d}_image{:02d}.png'.format(g + 1, i - start_index + 1)))
|
500 |
+
# ori_images[i] = ori_images[i]
|
501 |
+
ori_images_pil = Image.fromarray(np.uint8(ori_images[i].detach().cpu().squeeze().float().numpy())).convert("RGB")
|
502 |
+
ori_test_images.append(ori_images_pil)
|
503 |
+
ori_images_pil.save(
|
504 |
+
os.path.join('/root/lihui/StoryVisualization/ori_test_images_epoch10', 'group{:02d}_image{:02d}.png'.format(g + 1, i - start_index + 1)))
|
505 |
+
# for i, im in enumerate(ori_images):
|
506 |
+
# file_path = '/root/lihui/StoryVisualization/ori_test_images/image{}.png'.format(i)
|
507 |
+
# cv2.imwrite(file_path, im)
|
508 |
+
|
509 |
+
|
510 |
+
if args.calculate_fid:
|
511 |
+
ori = np.array([elem for sublist in predictions for elem in sublist[1]])
|
512 |
+
gen = np.array([elem for sublist in predictions for elem in sublist[2]])
|
513 |
+
fid = calculate_fid_given_features(ori, gen)
|
514 |
+
print('FID: {}'.format(fid))
|
515 |
+
|
516 |
+
|
517 |
+
|
518 |
+
|
519 |
+
|
520 |
+
@hydra.main(config_path=".", config_name="config")
|
521 |
+
def main(args: DictConfig) -> None:
|
522 |
+
pl.seed_everything(args.seed)
|
523 |
+
if args.num_cpu_cores > 0:
|
524 |
+
torch.set_num_threads(args.num_cpu_cores)
|
525 |
+
|
526 |
+
if args.mode == 'train':
|
527 |
+
############################
|
528 |
+
train(args)
|
529 |
+
elif args.mode == 'sample':
|
530 |
+
# dataloader = LightningDataset(args)
|
531 |
+
# dataloader.setup('test')
|
532 |
+
sample(args)
|
533 |
+
|
534 |
+
|
535 |
+
|
536 |
+
if __name__ == '__main__':
|
537 |
+
main()
|
models/blip_override/blip.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
warnings.filterwarnings("ignore")
|
11 |
+
|
12 |
+
from .vit import VisionTransformer, interpolate_pos_embed
|
13 |
+
from .med import BertModel, BertLMHeadModel
|
14 |
+
from transformers import BertTokenizer, BertConfig
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
|
19 |
+
import os
|
20 |
+
from urllib.parse import urlparse
|
21 |
+
from timm.models.hub import download_cached_file
|
22 |
+
|
23 |
+
|
24 |
+
class BLIP_Base(nn.Module):
|
25 |
+
def __init__(self,
|
26 |
+
med_config='models/blip_override/med_config.json',
|
27 |
+
image_size=224,
|
28 |
+
vit='base',
|
29 |
+
vit_grad_ckpt=False,
|
30 |
+
vit_ckpt_layer=0,
|
31 |
+
):
|
32 |
+
"""
|
33 |
+
Args:
|
34 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
35 |
+
image_size (int): input image size
|
36 |
+
vit (str): model size of vision transformer
|
37 |
+
"""
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
|
41 |
+
self.tokenizer = init_tokenizer()
|
42 |
+
med_config = BertConfig.from_json_file(med_config)
|
43 |
+
med_config.encoder_width = vision_width
|
44 |
+
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
45 |
+
|
46 |
+
def forward(self, image, text, attention_mask, mode):
|
47 |
+
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
|
48 |
+
if mode == 'image':
|
49 |
+
# return image features
|
50 |
+
image_embeds = self.visual_encoder(image)
|
51 |
+
return image_embeds
|
52 |
+
|
53 |
+
elif mode == 'text':
|
54 |
+
# return text features
|
55 |
+
text_output = self.text_encoder(text, attention_mask=attention_mask, return_dict=True, mode='text')
|
56 |
+
return text_output.last_hidden_state
|
57 |
+
|
58 |
+
elif mode == 'multimodal':
|
59 |
+
# return multimodel features
|
60 |
+
image_embeds = self.visual_encoder(image)
|
61 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
62 |
+
|
63 |
+
text[:, 0] = self.tokenizer.enc_token_id
|
64 |
+
output = self.text_encoder(text,
|
65 |
+
attention_mask=attention_mask,
|
66 |
+
encoder_hidden_states=image_embeds,
|
67 |
+
encoder_attention_mask=image_atts,
|
68 |
+
return_dict=True,
|
69 |
+
)
|
70 |
+
return output.last_hidden_state
|
71 |
+
|
72 |
+
|
73 |
+
class BLIP_Decoder(nn.Module):
|
74 |
+
def __init__(self,
|
75 |
+
med_config='models/blip_override/med_config.json',
|
76 |
+
image_size=384,
|
77 |
+
vit='base',
|
78 |
+
vit_grad_ckpt=False,
|
79 |
+
vit_ckpt_layer=0,
|
80 |
+
prompt='a picture of ',
|
81 |
+
):
|
82 |
+
"""
|
83 |
+
Args:
|
84 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
85 |
+
image_size (int): input image size
|
86 |
+
vit (str): model size of vision transformer
|
87 |
+
"""
|
88 |
+
super().__init__()
|
89 |
+
|
90 |
+
self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
|
91 |
+
self.tokenizer = init_tokenizer()
|
92 |
+
med_config = BertConfig.from_json_file(med_config)
|
93 |
+
med_config.encoder_width = vision_width
|
94 |
+
self.text_decoder = BertLMHeadModel(config=med_config)
|
95 |
+
|
96 |
+
self.prompt = prompt
|
97 |
+
self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
|
98 |
+
|
99 |
+
def forward(self, image, caption):
|
100 |
+
|
101 |
+
image_embeds = self.visual_encoder(image)
|
102 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
103 |
+
|
104 |
+
text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(
|
105 |
+
image.device)
|
106 |
+
|
107 |
+
text.input_ids[:, 0] = self.tokenizer.bos_token_id
|
108 |
+
|
109 |
+
decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
|
110 |
+
decoder_targets[:, :self.prompt_length] = -100
|
111 |
+
|
112 |
+
decoder_output = self.text_decoder(text.input_ids,
|
113 |
+
attention_mask=text.attention_mask,
|
114 |
+
encoder_hidden_states=image_embeds,
|
115 |
+
encoder_attention_mask=image_atts,
|
116 |
+
labels=decoder_targets,
|
117 |
+
return_dict=True,
|
118 |
+
)
|
119 |
+
loss_lm = decoder_output.loss
|
120 |
+
|
121 |
+
return loss_lm
|
122 |
+
|
123 |
+
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9,
|
124 |
+
repetition_penalty=1.0):
|
125 |
+
image_embeds = self.visual_encoder(image)
|
126 |
+
|
127 |
+
if not sample:
|
128 |
+
image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
|
129 |
+
|
130 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
131 |
+
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask": image_atts}
|
132 |
+
|
133 |
+
prompt = [self.prompt] * image.size(0)
|
134 |
+
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
|
135 |
+
input_ids[:, 0] = self.tokenizer.bos_token_id
|
136 |
+
input_ids = input_ids[:, :-1]
|
137 |
+
|
138 |
+
if sample:
|
139 |
+
# nucleus sampling
|
140 |
+
outputs = self.text_decoder.generate(input_ids=input_ids,
|
141 |
+
max_length=max_length,
|
142 |
+
min_length=min_length,
|
143 |
+
do_sample=True,
|
144 |
+
top_p=top_p,
|
145 |
+
num_return_sequences=1,
|
146 |
+
eos_token_id=self.tokenizer.sep_token_id,
|
147 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
148 |
+
repetition_penalty=1.1,
|
149 |
+
**model_kwargs)
|
150 |
+
else:
|
151 |
+
# beam search
|
152 |
+
outputs = self.text_decoder.generate(input_ids=input_ids,
|
153 |
+
max_length=max_length,
|
154 |
+
min_length=min_length,
|
155 |
+
num_beams=num_beams,
|
156 |
+
eos_token_id=self.tokenizer.sep_token_id,
|
157 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
158 |
+
repetition_penalty=repetition_penalty,
|
159 |
+
**model_kwargs)
|
160 |
+
|
161 |
+
captions = []
|
162 |
+
for output in outputs:
|
163 |
+
caption = self.tokenizer.decode(output, skip_special_tokens=True)
|
164 |
+
captions.append(caption[len(self.prompt):])
|
165 |
+
return captions
|
166 |
+
|
167 |
+
|
168 |
+
def blip_decoder(pretrained='', **kwargs):
|
169 |
+
model = BLIP_Decoder(**kwargs)
|
170 |
+
if pretrained:
|
171 |
+
model, msg = load_checkpoint(model, pretrained)
|
172 |
+
assert (len(msg.missing_keys) == 0)
|
173 |
+
return model
|
174 |
+
|
175 |
+
|
176 |
+
def blip_feature_extractor(pretrained='', **kwargs):
|
177 |
+
model = BLIP_Base(**kwargs)
|
178 |
+
if pretrained:
|
179 |
+
model, msg = load_checkpoint(model, pretrained)
|
180 |
+
assert (len(msg.missing_keys) == 0)
|
181 |
+
return model
|
182 |
+
|
183 |
+
|
184 |
+
def init_tokenizer():
|
185 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
186 |
+
tokenizer.add_special_tokens({'bos_token': '[DEC]'})
|
187 |
+
tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
|
188 |
+
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
189 |
+
return tokenizer
|
190 |
+
|
191 |
+
|
192 |
+
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
|
193 |
+
assert vit in ['base', 'large'], "vit parameter must be base or large"
|
194 |
+
assert use_grad_checkpointing is False, 'grad checkpointing is not supported yet'
|
195 |
+
if vit == 'base':
|
196 |
+
vision_width = 768
|
197 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
|
198 |
+
num_heads=12, use_grad_checkpointing=use_grad_checkpointing,
|
199 |
+
ckpt_layer=ckpt_layer,
|
200 |
+
drop_path_rate=0 or drop_path_rate
|
201 |
+
)
|
202 |
+
elif vit == 'large':
|
203 |
+
vision_width = 1024
|
204 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
|
205 |
+
num_heads=16, use_grad_checkpointing=use_grad_checkpointing,
|
206 |
+
ckpt_layer=ckpt_layer,
|
207 |
+
drop_path_rate=0.1 or drop_path_rate
|
208 |
+
)
|
209 |
+
return visual_encoder, vision_width
|
210 |
+
|
211 |
+
|
212 |
+
def is_url(url_or_filename):
|
213 |
+
parsed = urlparse(url_or_filename)
|
214 |
+
return parsed.scheme in ("http", "https")
|
215 |
+
|
216 |
+
|
217 |
+
def load_checkpoint(model, url_or_filename):
|
218 |
+
if is_url(url_or_filename):
|
219 |
+
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
220 |
+
checkpoint = torch.load(cached_file, map_location='cpu')
|
221 |
+
elif os.path.isfile(url_or_filename):
|
222 |
+
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
223 |
+
else:
|
224 |
+
raise RuntimeError('checkpoint url or path is invalid')
|
225 |
+
|
226 |
+
state_dict = checkpoint['model']
|
227 |
+
|
228 |
+
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],
|
229 |
+
model.visual_encoder)
|
230 |
+
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
|
231 |
+
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
|
232 |
+
model.visual_encoder_m)
|
233 |
+
for key in model.state_dict().keys():
|
234 |
+
if key in state_dict.keys():
|
235 |
+
if state_dict[key].shape != model.state_dict()[key].shape:
|
236 |
+
del state_dict[key]
|
237 |
+
|
238 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
239 |
+
print('load checkpoint from %s' % url_or_filename)
|
240 |
+
return model, msg
|
models/blip_override/med.py
ADDED
@@ -0,0 +1,955 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
* Based on huggingface code base
|
8 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
9 |
+
'''
|
10 |
+
|
11 |
+
import math
|
12 |
+
import os
|
13 |
+
import warnings
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional, Tuple
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import Tensor, device, dtype, nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import CrossEntropyLoss
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from transformers.activations import ACT2FN
|
25 |
+
from transformers.file_utils import (
|
26 |
+
ModelOutput,
|
27 |
+
)
|
28 |
+
from transformers.modeling_outputs import (
|
29 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
30 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
31 |
+
CausalLMOutputWithCrossAttentions,
|
32 |
+
MaskedLMOutput,
|
33 |
+
MultipleChoiceModelOutput,
|
34 |
+
NextSentencePredictorOutput,
|
35 |
+
QuestionAnsweringModelOutput,
|
36 |
+
SequenceClassifierOutput,
|
37 |
+
TokenClassifierOutput,
|
38 |
+
)
|
39 |
+
from transformers.modeling_utils import (
|
40 |
+
PreTrainedModel,
|
41 |
+
apply_chunking_to_forward,
|
42 |
+
find_pruneable_heads_and_indices,
|
43 |
+
prune_linear_layer,
|
44 |
+
)
|
45 |
+
from transformers.utils import logging
|
46 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
47 |
+
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__)
|
50 |
+
|
51 |
+
|
52 |
+
class BertEmbeddings(nn.Module):
|
53 |
+
"""Construct the embeddings from word and position embeddings."""
|
54 |
+
|
55 |
+
def __init__(self, config):
|
56 |
+
super().__init__()
|
57 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
58 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
59 |
+
|
60 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
61 |
+
# any TensorFlow checkpoint file
|
62 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
63 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
64 |
+
|
65 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
66 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
67 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
68 |
+
|
69 |
+
self.config = config
|
70 |
+
|
71 |
+
def forward(
|
72 |
+
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
73 |
+
):
|
74 |
+
if input_ids is not None:
|
75 |
+
input_shape = input_ids.size()
|
76 |
+
else:
|
77 |
+
input_shape = inputs_embeds.size()[:-1]
|
78 |
+
|
79 |
+
seq_length = input_shape[1]
|
80 |
+
|
81 |
+
if position_ids is None:
|
82 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
83 |
+
|
84 |
+
if inputs_embeds is None:
|
85 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
86 |
+
|
87 |
+
embeddings = inputs_embeds
|
88 |
+
|
89 |
+
if self.position_embedding_type == "absolute":
|
90 |
+
position_embeddings = self.position_embeddings(position_ids)
|
91 |
+
embeddings += position_embeddings
|
92 |
+
embeddings = self.LayerNorm(embeddings)
|
93 |
+
embeddings = self.dropout(embeddings)
|
94 |
+
return embeddings
|
95 |
+
|
96 |
+
|
97 |
+
class BertSelfAttention(nn.Module):
|
98 |
+
def __init__(self, config, is_cross_attention):
|
99 |
+
super().__init__()
|
100 |
+
self.config = config
|
101 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
102 |
+
raise ValueError(
|
103 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
104 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
105 |
+
)
|
106 |
+
|
107 |
+
self.num_attention_heads = config.num_attention_heads
|
108 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
109 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
110 |
+
|
111 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
112 |
+
if is_cross_attention:
|
113 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
114 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
115 |
+
else:
|
116 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
117 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
118 |
+
|
119 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
120 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
121 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
122 |
+
self.max_position_embeddings = config.max_position_embeddings
|
123 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
124 |
+
self.save_attention = False
|
125 |
+
|
126 |
+
def save_attn_gradients(self, attn_gradients):
|
127 |
+
self.attn_gradients = attn_gradients
|
128 |
+
|
129 |
+
def get_attn_gradients(self):
|
130 |
+
return self.attn_gradients
|
131 |
+
|
132 |
+
def save_attention_map(self, attention_map):
|
133 |
+
self.attention_map = attention_map
|
134 |
+
|
135 |
+
def get_attention_map(self):
|
136 |
+
return self.attention_map
|
137 |
+
|
138 |
+
def transpose_for_scores(self, x):
|
139 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
140 |
+
x = x.view(*new_x_shape)
|
141 |
+
return x.permute(0, 2, 1, 3)
|
142 |
+
|
143 |
+
def forward(
|
144 |
+
self,
|
145 |
+
hidden_states,
|
146 |
+
attention_mask=None,
|
147 |
+
head_mask=None,
|
148 |
+
encoder_hidden_states=None,
|
149 |
+
encoder_attention_mask=None,
|
150 |
+
past_key_value=None,
|
151 |
+
output_attentions=False,
|
152 |
+
):
|
153 |
+
mixed_query_layer = self.query(hidden_states)
|
154 |
+
|
155 |
+
# If this is instantiated as a cross-attention module, the keys
|
156 |
+
# and values come from an encoder; the attention mask needs to be
|
157 |
+
# such that the encoder's padding tokens are not attended to.
|
158 |
+
is_cross_attention = encoder_hidden_states is not None
|
159 |
+
|
160 |
+
if is_cross_attention:
|
161 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
162 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
163 |
+
attention_mask = encoder_attention_mask
|
164 |
+
elif past_key_value is not None:
|
165 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
166 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
167 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
168 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
169 |
+
else:
|
170 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
171 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
172 |
+
|
173 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
174 |
+
|
175 |
+
past_key_value = (key_layer, value_layer)
|
176 |
+
|
177 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
178 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
179 |
+
|
180 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
181 |
+
seq_length = hidden_states.size()[1]
|
182 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
183 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
184 |
+
distance = position_ids_l - position_ids_r
|
185 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
186 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
187 |
+
|
188 |
+
if self.position_embedding_type == "relative_key":
|
189 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
190 |
+
attention_scores = attention_scores + relative_position_scores
|
191 |
+
elif self.position_embedding_type == "relative_key_query":
|
192 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
193 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
194 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
195 |
+
|
196 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
197 |
+
if attention_mask is not None:
|
198 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
199 |
+
attention_scores = attention_scores + attention_mask
|
200 |
+
|
201 |
+
# Normalize the attention scores to probabilities.
|
202 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
203 |
+
|
204 |
+
if is_cross_attention and self.save_attention:
|
205 |
+
self.save_attention_map(attention_probs)
|
206 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
207 |
+
|
208 |
+
# This is actually dropping out entire tokens to attend to, which might
|
209 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
210 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
211 |
+
|
212 |
+
# Mask heads if we want to
|
213 |
+
if head_mask is not None:
|
214 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
215 |
+
|
216 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
217 |
+
|
218 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
219 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
220 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
221 |
+
|
222 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
223 |
+
|
224 |
+
outputs = outputs + (past_key_value,)
|
225 |
+
return outputs
|
226 |
+
|
227 |
+
|
228 |
+
class BertSelfOutput(nn.Module):
|
229 |
+
def __init__(self, config):
|
230 |
+
super().__init__()
|
231 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
232 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
233 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
234 |
+
|
235 |
+
def forward(self, hidden_states, input_tensor):
|
236 |
+
hidden_states = self.dense(hidden_states)
|
237 |
+
hidden_states = self.dropout(hidden_states)
|
238 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
239 |
+
return hidden_states
|
240 |
+
|
241 |
+
|
242 |
+
class BertAttention(nn.Module):
|
243 |
+
def __init__(self, config, is_cross_attention=False):
|
244 |
+
super().__init__()
|
245 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
246 |
+
self.output = BertSelfOutput(config)
|
247 |
+
self.pruned_heads = set()
|
248 |
+
|
249 |
+
def prune_heads(self, heads):
|
250 |
+
if len(heads) == 0:
|
251 |
+
return
|
252 |
+
heads, index = find_pruneable_heads_and_indices(
|
253 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
254 |
+
)
|
255 |
+
|
256 |
+
# Prune linear layers
|
257 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
258 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
259 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
260 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
261 |
+
|
262 |
+
# Update hyper params and store pruned heads
|
263 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
264 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
265 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
266 |
+
|
267 |
+
def forward(
|
268 |
+
self,
|
269 |
+
hidden_states,
|
270 |
+
attention_mask=None,
|
271 |
+
head_mask=None,
|
272 |
+
encoder_hidden_states=None,
|
273 |
+
encoder_attention_mask=None,
|
274 |
+
past_key_value=None,
|
275 |
+
output_attentions=False,
|
276 |
+
):
|
277 |
+
self_outputs = self.self(
|
278 |
+
hidden_states,
|
279 |
+
attention_mask,
|
280 |
+
head_mask,
|
281 |
+
encoder_hidden_states,
|
282 |
+
encoder_attention_mask,
|
283 |
+
past_key_value,
|
284 |
+
output_attentions,
|
285 |
+
)
|
286 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
287 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
288 |
+
return outputs
|
289 |
+
|
290 |
+
|
291 |
+
class BertIntermediate(nn.Module):
|
292 |
+
def __init__(self, config):
|
293 |
+
super().__init__()
|
294 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
295 |
+
if isinstance(config.hidden_act, str):
|
296 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
297 |
+
else:
|
298 |
+
self.intermediate_act_fn = config.hidden_act
|
299 |
+
|
300 |
+
def forward(self, hidden_states):
|
301 |
+
hidden_states = self.dense(hidden_states)
|
302 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
303 |
+
return hidden_states
|
304 |
+
|
305 |
+
|
306 |
+
class BertOutput(nn.Module):
|
307 |
+
def __init__(self, config):
|
308 |
+
super().__init__()
|
309 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
310 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
311 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
312 |
+
|
313 |
+
def forward(self, hidden_states, input_tensor):
|
314 |
+
hidden_states = self.dense(hidden_states)
|
315 |
+
hidden_states = self.dropout(hidden_states)
|
316 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
317 |
+
return hidden_states
|
318 |
+
|
319 |
+
|
320 |
+
class BertLayer(nn.Module):
|
321 |
+
def __init__(self, config, layer_num):
|
322 |
+
super().__init__()
|
323 |
+
self.config = config
|
324 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
325 |
+
self.seq_len_dim = 1
|
326 |
+
self.attention = BertAttention(config)
|
327 |
+
self.layer_num = layer_num
|
328 |
+
if self.config.add_cross_attention:
|
329 |
+
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
330 |
+
self.intermediate = BertIntermediate(config)
|
331 |
+
self.output = BertOutput(config)
|
332 |
+
|
333 |
+
def forward(
|
334 |
+
self,
|
335 |
+
hidden_states,
|
336 |
+
attention_mask=None,
|
337 |
+
head_mask=None,
|
338 |
+
encoder_hidden_states=None,
|
339 |
+
encoder_attention_mask=None,
|
340 |
+
past_key_value=None,
|
341 |
+
output_attentions=False,
|
342 |
+
mode=None,
|
343 |
+
):
|
344 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
345 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
346 |
+
self_attention_outputs = self.attention(
|
347 |
+
hidden_states,
|
348 |
+
attention_mask,
|
349 |
+
head_mask,
|
350 |
+
output_attentions=output_attentions,
|
351 |
+
past_key_value=self_attn_past_key_value,
|
352 |
+
)
|
353 |
+
attention_output = self_attention_outputs[0]
|
354 |
+
|
355 |
+
outputs = self_attention_outputs[1:-1]
|
356 |
+
present_key_value = self_attention_outputs[-1]
|
357 |
+
|
358 |
+
if mode=='multimodal':
|
359 |
+
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
360 |
+
|
361 |
+
cross_attention_outputs = self.crossattention(
|
362 |
+
attention_output,
|
363 |
+
attention_mask,
|
364 |
+
head_mask,
|
365 |
+
encoder_hidden_states,
|
366 |
+
encoder_attention_mask,
|
367 |
+
output_attentions=output_attentions,
|
368 |
+
)
|
369 |
+
attention_output = cross_attention_outputs[0]
|
370 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
371 |
+
layer_output = apply_chunking_to_forward(
|
372 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
373 |
+
)
|
374 |
+
outputs = (layer_output,) + outputs
|
375 |
+
|
376 |
+
outputs = outputs + (present_key_value,)
|
377 |
+
|
378 |
+
return outputs
|
379 |
+
|
380 |
+
def feed_forward_chunk(self, attention_output):
|
381 |
+
intermediate_output = self.intermediate(attention_output)
|
382 |
+
layer_output = self.output(intermediate_output, attention_output)
|
383 |
+
return layer_output
|
384 |
+
|
385 |
+
|
386 |
+
class BertEncoder(nn.Module):
|
387 |
+
def __init__(self, config):
|
388 |
+
super().__init__()
|
389 |
+
self.config = config
|
390 |
+
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
391 |
+
self.gradient_checkpointing = False
|
392 |
+
|
393 |
+
def forward(
|
394 |
+
self,
|
395 |
+
hidden_states,
|
396 |
+
attention_mask=None,
|
397 |
+
head_mask=None,
|
398 |
+
encoder_hidden_states=None,
|
399 |
+
encoder_attention_mask=None,
|
400 |
+
past_key_values=None,
|
401 |
+
use_cache=None,
|
402 |
+
output_attentions=False,
|
403 |
+
output_hidden_states=False,
|
404 |
+
return_dict=True,
|
405 |
+
mode='multimodal',
|
406 |
+
):
|
407 |
+
all_hidden_states = () if output_hidden_states else None
|
408 |
+
all_self_attentions = () if output_attentions else None
|
409 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
410 |
+
|
411 |
+
next_decoder_cache = () if use_cache else None
|
412 |
+
|
413 |
+
for i in range(self.config.num_hidden_layers):
|
414 |
+
layer_module = self.layer[i]
|
415 |
+
if output_hidden_states:
|
416 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
417 |
+
|
418 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
419 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
420 |
+
|
421 |
+
if self.gradient_checkpointing and self.training:
|
422 |
+
|
423 |
+
if use_cache:
|
424 |
+
logger.warn(
|
425 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
426 |
+
)
|
427 |
+
use_cache = False
|
428 |
+
|
429 |
+
def create_custom_forward(module):
|
430 |
+
def custom_forward(*inputs):
|
431 |
+
return module(*inputs, past_key_value, output_attentions)
|
432 |
+
|
433 |
+
return custom_forward
|
434 |
+
|
435 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
436 |
+
create_custom_forward(layer_module),
|
437 |
+
hidden_states,
|
438 |
+
attention_mask,
|
439 |
+
layer_head_mask,
|
440 |
+
encoder_hidden_states,
|
441 |
+
encoder_attention_mask,
|
442 |
+
mode=mode,
|
443 |
+
)
|
444 |
+
else:
|
445 |
+
layer_outputs = layer_module(
|
446 |
+
hidden_states,
|
447 |
+
attention_mask,
|
448 |
+
layer_head_mask,
|
449 |
+
encoder_hidden_states,
|
450 |
+
encoder_attention_mask,
|
451 |
+
past_key_value,
|
452 |
+
output_attentions,
|
453 |
+
mode=mode,
|
454 |
+
)
|
455 |
+
|
456 |
+
hidden_states = layer_outputs[0]
|
457 |
+
if use_cache:
|
458 |
+
next_decoder_cache += (layer_outputs[-1],)
|
459 |
+
if output_attentions:
|
460 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
461 |
+
|
462 |
+
if output_hidden_states:
|
463 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
464 |
+
|
465 |
+
if not return_dict:
|
466 |
+
return tuple(
|
467 |
+
v
|
468 |
+
for v in [
|
469 |
+
hidden_states,
|
470 |
+
next_decoder_cache,
|
471 |
+
all_hidden_states,
|
472 |
+
all_self_attentions,
|
473 |
+
all_cross_attentions,
|
474 |
+
]
|
475 |
+
if v is not None
|
476 |
+
)
|
477 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
478 |
+
last_hidden_state=hidden_states,
|
479 |
+
past_key_values=next_decoder_cache,
|
480 |
+
hidden_states=all_hidden_states,
|
481 |
+
attentions=all_self_attentions,
|
482 |
+
cross_attentions=all_cross_attentions,
|
483 |
+
)
|
484 |
+
|
485 |
+
|
486 |
+
class BertPooler(nn.Module):
|
487 |
+
def __init__(self, config):
|
488 |
+
super().__init__()
|
489 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
490 |
+
self.activation = nn.Tanh()
|
491 |
+
|
492 |
+
def forward(self, hidden_states):
|
493 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
494 |
+
# to the first token.
|
495 |
+
first_token_tensor = hidden_states[:, 0]
|
496 |
+
pooled_output = self.dense(first_token_tensor)
|
497 |
+
pooled_output = self.activation(pooled_output)
|
498 |
+
return pooled_output
|
499 |
+
|
500 |
+
|
501 |
+
class BertPredictionHeadTransform(nn.Module):
|
502 |
+
def __init__(self, config):
|
503 |
+
super().__init__()
|
504 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
505 |
+
if isinstance(config.hidden_act, str):
|
506 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
507 |
+
else:
|
508 |
+
self.transform_act_fn = config.hidden_act
|
509 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
510 |
+
|
511 |
+
def forward(self, hidden_states):
|
512 |
+
hidden_states = self.dense(hidden_states)
|
513 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
514 |
+
hidden_states = self.LayerNorm(hidden_states)
|
515 |
+
return hidden_states
|
516 |
+
|
517 |
+
|
518 |
+
class BertLMPredictionHead(nn.Module):
|
519 |
+
def __init__(self, config):
|
520 |
+
super().__init__()
|
521 |
+
self.transform = BertPredictionHeadTransform(config)
|
522 |
+
|
523 |
+
# The output weights are the same as the input embeddings, but there is
|
524 |
+
# an output-only bias for each token.
|
525 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
526 |
+
|
527 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
528 |
+
|
529 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
530 |
+
self.decoder.bias = self.bias
|
531 |
+
|
532 |
+
def forward(self, hidden_states):
|
533 |
+
hidden_states = self.transform(hidden_states)
|
534 |
+
hidden_states = self.decoder(hidden_states)
|
535 |
+
return hidden_states
|
536 |
+
|
537 |
+
|
538 |
+
class BertOnlyMLMHead(nn.Module):
|
539 |
+
def __init__(self, config):
|
540 |
+
super().__init__()
|
541 |
+
self.predictions = BertLMPredictionHead(config)
|
542 |
+
|
543 |
+
def forward(self, sequence_output):
|
544 |
+
prediction_scores = self.predictions(sequence_output)
|
545 |
+
return prediction_scores
|
546 |
+
|
547 |
+
|
548 |
+
class BertPreTrainedModel(PreTrainedModel):
|
549 |
+
"""
|
550 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
551 |
+
models.
|
552 |
+
"""
|
553 |
+
|
554 |
+
config_class = BertConfig
|
555 |
+
base_model_prefix = "bert"
|
556 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
557 |
+
|
558 |
+
def _init_weights(self, module):
|
559 |
+
""" Initialize the weights """
|
560 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
561 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
562 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
563 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
564 |
+
elif isinstance(module, nn.LayerNorm):
|
565 |
+
module.bias.data.zero_()
|
566 |
+
module.weight.data.fill_(1.0)
|
567 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
568 |
+
module.bias.data.zero_()
|
569 |
+
|
570 |
+
|
571 |
+
class BertModel(BertPreTrainedModel):
|
572 |
+
"""
|
573 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
574 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
575 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
576 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
577 |
+
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
578 |
+
input to the forward pass.
|
579 |
+
"""
|
580 |
+
|
581 |
+
def __init__(self, config, add_pooling_layer=True):
|
582 |
+
super().__init__(config)
|
583 |
+
self.config = config
|
584 |
+
|
585 |
+
self.embeddings = BertEmbeddings(config)
|
586 |
+
|
587 |
+
self.encoder = BertEncoder(config)
|
588 |
+
|
589 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
590 |
+
|
591 |
+
self.init_weights()
|
592 |
+
|
593 |
+
|
594 |
+
def get_input_embeddings(self):
|
595 |
+
return self.embeddings.word_embeddings
|
596 |
+
|
597 |
+
def set_input_embeddings(self, value):
|
598 |
+
self.embeddings.word_embeddings = value
|
599 |
+
|
600 |
+
def _prune_heads(self, heads_to_prune):
|
601 |
+
"""
|
602 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
603 |
+
class PreTrainedModel
|
604 |
+
"""
|
605 |
+
for layer, heads in heads_to_prune.items():
|
606 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
607 |
+
|
608 |
+
|
609 |
+
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
610 |
+
"""
|
611 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
612 |
+
|
613 |
+
Arguments:
|
614 |
+
attention_mask (:obj:`torch.Tensor`):
|
615 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
616 |
+
input_shape (:obj:`Tuple[int]`):
|
617 |
+
The shape of the input to the model.
|
618 |
+
device: (:obj:`torch.device`):
|
619 |
+
The device of the input to the model.
|
620 |
+
|
621 |
+
Returns:
|
622 |
+
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
623 |
+
"""
|
624 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
625 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
626 |
+
if attention_mask.dim() == 3:
|
627 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
628 |
+
elif attention_mask.dim() == 2:
|
629 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
630 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
631 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
632 |
+
if is_decoder:
|
633 |
+
batch_size, seq_length = input_shape
|
634 |
+
|
635 |
+
seq_ids = torch.arange(seq_length, device=device)
|
636 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
637 |
+
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
638 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
639 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
640 |
+
|
641 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
642 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
643 |
+
causal_mask = torch.cat(
|
644 |
+
[
|
645 |
+
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
646 |
+
causal_mask,
|
647 |
+
],
|
648 |
+
axis=-1,
|
649 |
+
)
|
650 |
+
|
651 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
652 |
+
else:
|
653 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
654 |
+
else:
|
655 |
+
raise ValueError(
|
656 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
657 |
+
input_shape, attention_mask.shape
|
658 |
+
)
|
659 |
+
)
|
660 |
+
|
661 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
662 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
663 |
+
# positions we want to attend and -10000.0 for masked positions.
|
664 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
665 |
+
# effectively the same as removing these entirely.
|
666 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
667 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
668 |
+
return extended_attention_mask
|
669 |
+
|
670 |
+
def forward(
|
671 |
+
self,
|
672 |
+
input_ids=None,
|
673 |
+
attention_mask=None,
|
674 |
+
position_ids=None,
|
675 |
+
head_mask=None,
|
676 |
+
inputs_embeds=None,
|
677 |
+
encoder_embeds=None,
|
678 |
+
encoder_hidden_states=None,
|
679 |
+
encoder_attention_mask=None,
|
680 |
+
past_key_values=None,
|
681 |
+
use_cache=None,
|
682 |
+
output_attentions=None,
|
683 |
+
output_hidden_states=None,
|
684 |
+
return_dict=None,
|
685 |
+
is_decoder=False,
|
686 |
+
mode='multimodal',
|
687 |
+
):
|
688 |
+
r"""
|
689 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
690 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
691 |
+
the model is configured as a decoder.
|
692 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
693 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
694 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
695 |
+
- 1 for tokens that are **not masked**,
|
696 |
+
- 0 for tokens that are **masked**.
|
697 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
698 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
699 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
700 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
701 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
702 |
+
use_cache (:obj:`bool`, `optional`):
|
703 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
704 |
+
decoding (see :obj:`past_key_values`).
|
705 |
+
"""
|
706 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
707 |
+
output_hidden_states = (
|
708 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
709 |
+
)
|
710 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
711 |
+
|
712 |
+
if is_decoder:
|
713 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
714 |
+
else:
|
715 |
+
use_cache = False
|
716 |
+
|
717 |
+
if input_ids is not None and inputs_embeds is not None:
|
718 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
719 |
+
elif input_ids is not None:
|
720 |
+
input_shape = input_ids.size()
|
721 |
+
batch_size, seq_length = input_shape
|
722 |
+
device = input_ids.device
|
723 |
+
elif inputs_embeds is not None:
|
724 |
+
input_shape = inputs_embeds.size()[:-1]
|
725 |
+
batch_size, seq_length = input_shape
|
726 |
+
device = inputs_embeds.device
|
727 |
+
elif encoder_embeds is not None:
|
728 |
+
input_shape = encoder_embeds.size()[:-1]
|
729 |
+
batch_size, seq_length = input_shape
|
730 |
+
device = encoder_embeds.device
|
731 |
+
else:
|
732 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
733 |
+
|
734 |
+
# past_key_values_length
|
735 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
736 |
+
|
737 |
+
if attention_mask is None:
|
738 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
739 |
+
|
740 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
741 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
742 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
743 |
+
device, is_decoder)
|
744 |
+
|
745 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
746 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
747 |
+
if encoder_hidden_states is not None:
|
748 |
+
if type(encoder_hidden_states) == list:
|
749 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
750 |
+
else:
|
751 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
752 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
753 |
+
|
754 |
+
if type(encoder_attention_mask) == list:
|
755 |
+
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
756 |
+
elif encoder_attention_mask is None:
|
757 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
758 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
759 |
+
else:
|
760 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
761 |
+
else:
|
762 |
+
encoder_extended_attention_mask = None
|
763 |
+
|
764 |
+
# Prepare head mask if needed
|
765 |
+
# 1.0 in head_mask indicate we keep the head
|
766 |
+
# attention_probs has shape bsz x n_heads x N x N
|
767 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
768 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
769 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
770 |
+
|
771 |
+
if encoder_embeds is None:
|
772 |
+
embedding_output = self.embeddings(
|
773 |
+
input_ids=input_ids,
|
774 |
+
position_ids=position_ids,
|
775 |
+
inputs_embeds=inputs_embeds,
|
776 |
+
past_key_values_length=past_key_values_length,
|
777 |
+
)
|
778 |
+
else:
|
779 |
+
embedding_output = encoder_embeds
|
780 |
+
|
781 |
+
encoder_outputs = self.encoder(
|
782 |
+
embedding_output,
|
783 |
+
attention_mask=extended_attention_mask,
|
784 |
+
head_mask=head_mask,
|
785 |
+
encoder_hidden_states=encoder_hidden_states,
|
786 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
787 |
+
past_key_values=past_key_values,
|
788 |
+
use_cache=use_cache,
|
789 |
+
output_attentions=output_attentions,
|
790 |
+
output_hidden_states=output_hidden_states,
|
791 |
+
return_dict=return_dict,
|
792 |
+
mode=mode,
|
793 |
+
)
|
794 |
+
sequence_output = encoder_outputs[0]
|
795 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
796 |
+
|
797 |
+
if not return_dict:
|
798 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
799 |
+
|
800 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
801 |
+
last_hidden_state=sequence_output,
|
802 |
+
pooler_output=pooled_output,
|
803 |
+
past_key_values=encoder_outputs.past_key_values,
|
804 |
+
hidden_states=encoder_outputs.hidden_states,
|
805 |
+
attentions=encoder_outputs.attentions,
|
806 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
807 |
+
)
|
808 |
+
|
809 |
+
|
810 |
+
|
811 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
812 |
+
|
813 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
814 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
815 |
+
|
816 |
+
def __init__(self, config):
|
817 |
+
super().__init__(config)
|
818 |
+
|
819 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
820 |
+
self.cls = BertOnlyMLMHead(config)
|
821 |
+
|
822 |
+
self.init_weights()
|
823 |
+
|
824 |
+
def get_output_embeddings(self):
|
825 |
+
return self.cls.predictions.decoder
|
826 |
+
|
827 |
+
def set_output_embeddings(self, new_embeddings):
|
828 |
+
self.cls.predictions.decoder = new_embeddings
|
829 |
+
|
830 |
+
def forward(
|
831 |
+
self,
|
832 |
+
input_ids=None,
|
833 |
+
attention_mask=None,
|
834 |
+
position_ids=None,
|
835 |
+
head_mask=None,
|
836 |
+
inputs_embeds=None,
|
837 |
+
encoder_hidden_states=None,
|
838 |
+
encoder_attention_mask=None,
|
839 |
+
labels=None,
|
840 |
+
past_key_values=None,
|
841 |
+
use_cache=None,
|
842 |
+
output_attentions=None,
|
843 |
+
output_hidden_states=None,
|
844 |
+
return_dict=None,
|
845 |
+
return_logits=False,
|
846 |
+
is_decoder=True,
|
847 |
+
reduction='mean',
|
848 |
+
mode='multimodal',
|
849 |
+
):
|
850 |
+
r"""
|
851 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
852 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
853 |
+
the model is configured as a decoder.
|
854 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
855 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
856 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
857 |
+
- 1 for tokens that are **not masked**,
|
858 |
+
- 0 for tokens that are **masked**.
|
859 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
860 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
861 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
862 |
+
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
863 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
864 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
865 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
866 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
867 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
868 |
+
use_cache (:obj:`bool`, `optional`):
|
869 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
870 |
+
decoding (see :obj:`past_key_values`).
|
871 |
+
Returns:
|
872 |
+
Example::
|
873 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
874 |
+
>>> import torch
|
875 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
876 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
877 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
878 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
879 |
+
>>> outputs = model(**inputs)
|
880 |
+
>>> prediction_logits = outputs.logits
|
881 |
+
"""
|
882 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
883 |
+
if labels is not None:
|
884 |
+
use_cache = False
|
885 |
+
|
886 |
+
outputs = self.bert(
|
887 |
+
input_ids,
|
888 |
+
attention_mask=attention_mask,
|
889 |
+
position_ids=position_ids,
|
890 |
+
head_mask=head_mask,
|
891 |
+
inputs_embeds=inputs_embeds,
|
892 |
+
encoder_hidden_states=encoder_hidden_states,
|
893 |
+
encoder_attention_mask=encoder_attention_mask,
|
894 |
+
past_key_values=past_key_values,
|
895 |
+
use_cache=use_cache,
|
896 |
+
output_attentions=output_attentions,
|
897 |
+
output_hidden_states=output_hidden_states,
|
898 |
+
return_dict=return_dict,
|
899 |
+
is_decoder=is_decoder,
|
900 |
+
mode=mode,
|
901 |
+
)
|
902 |
+
|
903 |
+
sequence_output = outputs[0]
|
904 |
+
prediction_scores = self.cls(sequence_output)
|
905 |
+
|
906 |
+
if return_logits:
|
907 |
+
return prediction_scores[:, :-1, :].contiguous()
|
908 |
+
|
909 |
+
lm_loss = None
|
910 |
+
if labels is not None:
|
911 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
912 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
913 |
+
labels = labels[:, 1:].contiguous()
|
914 |
+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
915 |
+
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
916 |
+
if reduction=='none':
|
917 |
+
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
|
918 |
+
|
919 |
+
if not return_dict:
|
920 |
+
output = (prediction_scores,) + outputs[2:]
|
921 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
922 |
+
|
923 |
+
return CausalLMOutputWithCrossAttentions(
|
924 |
+
loss=lm_loss,
|
925 |
+
logits=prediction_scores,
|
926 |
+
past_key_values=outputs.past_key_values,
|
927 |
+
hidden_states=outputs.hidden_states,
|
928 |
+
attentions=outputs.attentions,
|
929 |
+
cross_attentions=outputs.cross_attentions,
|
930 |
+
)
|
931 |
+
|
932 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
933 |
+
input_shape = input_ids.shape
|
934 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
935 |
+
if attention_mask is None:
|
936 |
+
attention_mask = input_ids.new_ones(input_shape)
|
937 |
+
|
938 |
+
# cut decoder_input_ids if past is used
|
939 |
+
if past is not None:
|
940 |
+
input_ids = input_ids[:, -1:]
|
941 |
+
|
942 |
+
return {
|
943 |
+
"input_ids": input_ids,
|
944 |
+
"attention_mask": attention_mask,
|
945 |
+
"past_key_values": past,
|
946 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
947 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
948 |
+
"is_decoder": True,
|
949 |
+
}
|
950 |
+
|
951 |
+
def _reorder_cache(self, past, beam_idx):
|
952 |
+
reordered_past = ()
|
953 |
+
for layer_past in past:
|
954 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
955 |
+
return reordered_past
|
models/blip_override/med_config.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"BertModel"
|
4 |
+
],
|
5 |
+
"attention_probs_dropout_prob": 0.1,
|
6 |
+
"hidden_act": "gelu",
|
7 |
+
"hidden_dropout_prob": 0.1,
|
8 |
+
"hidden_size": 768,
|
9 |
+
"initializer_range": 0.02,
|
10 |
+
"intermediate_size": 3072,
|
11 |
+
"layer_norm_eps": 1e-12,
|
12 |
+
"max_position_embeddings": 512,
|
13 |
+
"model_type": "bert",
|
14 |
+
"num_attention_heads": 12,
|
15 |
+
"num_hidden_layers": 12,
|
16 |
+
"pad_token_id": 0,
|
17 |
+
"type_vocab_size": 2,
|
18 |
+
"vocab_size": 30524,
|
19 |
+
"encoder_width": 768,
|
20 |
+
"add_cross_attention": true
|
21 |
+
}
|
models/blip_override/vit.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
* Based on timm code base
|
8 |
+
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
9 |
+
'''
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from functools import partial
|
15 |
+
|
16 |
+
from timm.models.vision_transformer import _cfg, PatchEmbed
|
17 |
+
from timm.models.registry import register_model
|
18 |
+
from timm.models.layers import trunc_normal_, DropPath
|
19 |
+
from timm.models.helpers import named_apply, adapt_input_conv
|
20 |
+
|
21 |
+
|
22 |
+
class Mlp(nn.Module):
|
23 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
27 |
+
super().__init__()
|
28 |
+
out_features = out_features or in_features
|
29 |
+
hidden_features = hidden_features or in_features
|
30 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
31 |
+
self.act = act_layer()
|
32 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
33 |
+
self.drop = nn.Dropout(drop)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
x = self.fc1(x)
|
37 |
+
x = self.act(x)
|
38 |
+
x = self.drop(x)
|
39 |
+
x = self.fc2(x)
|
40 |
+
x = self.drop(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class Attention(nn.Module):
|
45 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
46 |
+
super().__init__()
|
47 |
+
self.num_heads = num_heads
|
48 |
+
head_dim = dim // num_heads
|
49 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
50 |
+
self.scale = qk_scale or head_dim ** -0.5
|
51 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
52 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
53 |
+
self.proj = nn.Linear(dim, dim)
|
54 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
55 |
+
self.attn_gradients = None
|
56 |
+
self.attention_map = None
|
57 |
+
|
58 |
+
def save_attn_gradients(self, attn_gradients):
|
59 |
+
self.attn_gradients = attn_gradients
|
60 |
+
|
61 |
+
def get_attn_gradients(self):
|
62 |
+
return self.attn_gradients
|
63 |
+
|
64 |
+
def save_attention_map(self, attention_map):
|
65 |
+
self.attention_map = attention_map
|
66 |
+
|
67 |
+
def get_attention_map(self):
|
68 |
+
return self.attention_map
|
69 |
+
|
70 |
+
def forward(self, x, register_hook=False):
|
71 |
+
B, N, C = x.shape
|
72 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
73 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
74 |
+
|
75 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
76 |
+
attn = attn.softmax(dim=-1)
|
77 |
+
attn = self.attn_drop(attn)
|
78 |
+
|
79 |
+
if register_hook:
|
80 |
+
self.save_attention_map(attn)
|
81 |
+
attn.register_hook(self.save_attn_gradients)
|
82 |
+
|
83 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
84 |
+
x = self.proj(x)
|
85 |
+
x = self.proj_drop(x)
|
86 |
+
return x
|
87 |
+
|
88 |
+
|
89 |
+
class Block(nn.Module):
|
90 |
+
|
91 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
92 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
|
93 |
+
super().__init__()
|
94 |
+
self.norm1 = norm_layer(dim)
|
95 |
+
self.attn = Attention(
|
96 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
97 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
98 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
99 |
+
self.norm2 = norm_layer(dim)
|
100 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
101 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
102 |
+
|
103 |
+
def forward(self, x, register_hook=False):
|
104 |
+
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
105 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
106 |
+
return x
|
107 |
+
|
108 |
+
|
109 |
+
class VisionTransformer(nn.Module):
|
110 |
+
""" Vision Transformer
|
111 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
112 |
+
https://arxiv.org/abs/2010.11929
|
113 |
+
"""
|
114 |
+
|
115 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
116 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
117 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
|
118 |
+
use_grad_checkpointing=False, ckpt_layer=0):
|
119 |
+
"""
|
120 |
+
Args:
|
121 |
+
img_size (int, tuple): input image size
|
122 |
+
patch_size (int, tuple): patch size
|
123 |
+
in_chans (int): number of input channels
|
124 |
+
num_classes (int): number of classes for classification head
|
125 |
+
embed_dim (int): embedding dimension
|
126 |
+
depth (int): depth of transformer
|
127 |
+
num_heads (int): number of attention heads
|
128 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
129 |
+
qkv_bias (bool): enable bias for qkv if True
|
130 |
+
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
131 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
132 |
+
drop_rate (float): dropout rate
|
133 |
+
attn_drop_rate (float): attention dropout rate
|
134 |
+
drop_path_rate (float): stochastic depth rate
|
135 |
+
norm_layer: (nn.Module): normalization layer
|
136 |
+
"""
|
137 |
+
super().__init__()
|
138 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
139 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
140 |
+
|
141 |
+
self.patch_embed = PatchEmbed(
|
142 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
143 |
+
|
144 |
+
num_patches = self.patch_embed.num_patches
|
145 |
+
|
146 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
147 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
148 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
149 |
+
|
150 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
151 |
+
self.blocks = nn.ModuleList([
|
152 |
+
Block(
|
153 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
154 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
155 |
+
use_grad_checkpointing=(use_grad_checkpointing and i >= depth - ckpt_layer)
|
156 |
+
)
|
157 |
+
for i in range(depth)])
|
158 |
+
self.norm = norm_layer(embed_dim)
|
159 |
+
|
160 |
+
trunc_normal_(self.pos_embed, std=.02)
|
161 |
+
trunc_normal_(self.cls_token, std=.02)
|
162 |
+
self.apply(self._init_weights)
|
163 |
+
|
164 |
+
def _init_weights(self, m):
|
165 |
+
if isinstance(m, nn.Linear):
|
166 |
+
trunc_normal_(m.weight, std=.02)
|
167 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
168 |
+
nn.init.constant_(m.bias, 0)
|
169 |
+
elif isinstance(m, nn.LayerNorm):
|
170 |
+
nn.init.constant_(m.bias, 0)
|
171 |
+
nn.init.constant_(m.weight, 1.0)
|
172 |
+
|
173 |
+
@torch.jit.ignore
|
174 |
+
def no_weight_decay(self):
|
175 |
+
return {'pos_embed', 'cls_token'}
|
176 |
+
|
177 |
+
def forward(self, x, register_blk=-1):
|
178 |
+
B = x.shape[0]
|
179 |
+
x = self.patch_embed(x)
|
180 |
+
|
181 |
+
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
182 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
183 |
+
|
184 |
+
x = x + self.pos_embed[:, :x.size(1), :]
|
185 |
+
x = self.pos_drop(x)
|
186 |
+
|
187 |
+
for i, blk in enumerate(self.blocks):
|
188 |
+
x = blk(x, register_blk == i)
|
189 |
+
x = self.norm(x)
|
190 |
+
|
191 |
+
return x
|
192 |
+
|
193 |
+
@torch.jit.ignore()
|
194 |
+
def load_pretrained(self, checkpoint_path, prefix=''):
|
195 |
+
_load_weights(self, checkpoint_path, prefix)
|
196 |
+
|
197 |
+
|
198 |
+
@torch.no_grad()
|
199 |
+
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
200 |
+
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
201 |
+
"""
|
202 |
+
import numpy as np
|
203 |
+
|
204 |
+
def _n2p(w, t=True):
|
205 |
+
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
206 |
+
w = w.flatten()
|
207 |
+
if t:
|
208 |
+
if w.ndim == 4:
|
209 |
+
w = w.transpose([3, 2, 0, 1])
|
210 |
+
elif w.ndim == 3:
|
211 |
+
w = w.transpose([2, 0, 1])
|
212 |
+
elif w.ndim == 2:
|
213 |
+
w = w.transpose([1, 0])
|
214 |
+
return torch.from_numpy(w)
|
215 |
+
|
216 |
+
w = np.load(checkpoint_path)
|
217 |
+
if not prefix and 'opt/target/embedding/kernel' in w:
|
218 |
+
prefix = 'opt/target/'
|
219 |
+
|
220 |
+
if hasattr(model.patch_embed, 'backbone'):
|
221 |
+
# hybrid
|
222 |
+
backbone = model.patch_embed.backbone
|
223 |
+
stem_only = not hasattr(backbone, 'stem')
|
224 |
+
stem = backbone if stem_only else backbone.stem
|
225 |
+
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
226 |
+
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
227 |
+
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
228 |
+
if not stem_only:
|
229 |
+
for i, stage in enumerate(backbone.stages):
|
230 |
+
for j, block in enumerate(stage.blocks):
|
231 |
+
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
232 |
+
for r in range(3):
|
233 |
+
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
234 |
+
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
235 |
+
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
236 |
+
if block.downsample is not None:
|
237 |
+
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
238 |
+
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
239 |
+
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
240 |
+
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
241 |
+
else:
|
242 |
+
embed_conv_w = adapt_input_conv(
|
243 |
+
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
244 |
+
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
245 |
+
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
246 |
+
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
247 |
+
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
248 |
+
if pos_embed_w.shape != model.pos_embed.shape:
|
249 |
+
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
250 |
+
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
251 |
+
model.pos_embed.copy_(pos_embed_w)
|
252 |
+
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
253 |
+
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
254 |
+
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
255 |
+
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
256 |
+
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
257 |
+
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
258 |
+
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
259 |
+
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
260 |
+
for i, block in enumerate(model.blocks.children()):
|
261 |
+
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
262 |
+
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
263 |
+
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
264 |
+
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
265 |
+
block.attn.qkv.weight.copy_(torch.cat([
|
266 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
267 |
+
block.attn.qkv.bias.copy_(torch.cat([
|
268 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
269 |
+
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
270 |
+
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
271 |
+
for r in range(2):
|
272 |
+
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
273 |
+
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
274 |
+
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
275 |
+
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
276 |
+
|
277 |
+
|
278 |
+
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
279 |
+
# interpolate position embedding
|
280 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
281 |
+
num_patches = visual_encoder.patch_embed.num_patches
|
282 |
+
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
283 |
+
# height (== width) for the checkpoint position embedding
|
284 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
285 |
+
# height (== width) for the new position embedding
|
286 |
+
new_size = int(num_patches ** 0.5)
|
287 |
+
|
288 |
+
if orig_size != new_size:
|
289 |
+
# class_token and dist_token are kept unchanged
|
290 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
291 |
+
# only the position tokens are interpolated
|
292 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
293 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
294 |
+
pos_tokens = torch.nn.functional.interpolate(
|
295 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
296 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
297 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
298 |
+
print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2))
|
299 |
+
|
300 |
+
return new_pos_embed
|
301 |
+
else:
|
302 |
+
return pos_embed_checkpoint
|
models/diffusers_override/attention.py
ADDED
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Optional
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from torch import nn
|
21 |
+
|
22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from diffusers.modeling_utils import ModelMixin
|
24 |
+
from diffusers.models.embeddings import ImagePositionalEmbeddings
|
25 |
+
from diffusers.utils import BaseOutput
|
26 |
+
from diffusers.utils.import_utils import is_xformers_available
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass
|
30 |
+
class Transformer2DModelOutput(BaseOutput):
|
31 |
+
"""
|
32 |
+
Args:
|
33 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
34 |
+
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
|
35 |
+
for the unnoised latent pixels.
|
36 |
+
"""
|
37 |
+
|
38 |
+
sample: torch.FloatTensor
|
39 |
+
|
40 |
+
|
41 |
+
if is_xformers_available():
|
42 |
+
import xformers
|
43 |
+
import xformers.ops
|
44 |
+
else:
|
45 |
+
xformers = None
|
46 |
+
|
47 |
+
|
48 |
+
class Transformer2DModel(ModelMixin, ConfigMixin):
|
49 |
+
"""
|
50 |
+
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
|
51 |
+
embeddings) inputs.
|
52 |
+
|
53 |
+
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
|
54 |
+
transformer action. Finally, reshape to image.
|
55 |
+
|
56 |
+
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
|
57 |
+
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
|
58 |
+
classes of unnoised image.
|
59 |
+
|
60 |
+
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
|
61 |
+
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
|
62 |
+
|
63 |
+
Parameters:
|
64 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
65 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
66 |
+
in_channels (`int`, *optional*):
|
67 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
68 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
69 |
+
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
70 |
+
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
|
71 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
72 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
73 |
+
`ImagePositionalEmbeddings`.
|
74 |
+
num_vector_embeds (`int`, *optional*):
|
75 |
+
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
76 |
+
Includes the class for the masked latent pixel.
|
77 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
78 |
+
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
79 |
+
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
80 |
+
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
81 |
+
up to but not more than steps than `num_embeds_ada_norm`.
|
82 |
+
attention_bias (`bool`, *optional*):
|
83 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
84 |
+
"""
|
85 |
+
|
86 |
+
@register_to_config
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
num_attention_heads: int = 16,
|
90 |
+
attention_head_dim: int = 88,
|
91 |
+
in_channels: Optional[int] = None,
|
92 |
+
num_layers: int = 1,
|
93 |
+
dropout: float = 0.0,
|
94 |
+
norm_num_groups: int = 32,
|
95 |
+
cross_attention_dim: Optional[int] = None,
|
96 |
+
attention_bias: bool = False,
|
97 |
+
sample_size: Optional[int] = None,
|
98 |
+
num_vector_embeds: Optional[int] = None,
|
99 |
+
activation_fn: str = "geglu",
|
100 |
+
num_embeds_ada_norm: Optional[int] = None,
|
101 |
+
):
|
102 |
+
super().__init__()
|
103 |
+
self.num_attention_heads = num_attention_heads
|
104 |
+
self.attention_head_dim = attention_head_dim
|
105 |
+
inner_dim = num_attention_heads * attention_head_dim
|
106 |
+
|
107 |
+
# 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
108 |
+
# Define whether input is continuous or discrete depending on configuration
|
109 |
+
self.is_input_continuous = in_channels is not None
|
110 |
+
self.is_input_vectorized = num_vector_embeds is not None
|
111 |
+
|
112 |
+
if self.is_input_continuous and self.is_input_vectorized:
|
113 |
+
raise ValueError(
|
114 |
+
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
115 |
+
" sure that either `in_channels` or `num_vector_embeds` is None."
|
116 |
+
)
|
117 |
+
elif not self.is_input_continuous and not self.is_input_vectorized:
|
118 |
+
raise ValueError(
|
119 |
+
f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
|
120 |
+
" sure that either `in_channels` or `num_vector_embeds` is not None."
|
121 |
+
)
|
122 |
+
|
123 |
+
# 2. Define input layers
|
124 |
+
if self.is_input_continuous:
|
125 |
+
self.in_channels = in_channels
|
126 |
+
|
127 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
128 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
129 |
+
elif self.is_input_vectorized:
|
130 |
+
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
131 |
+
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
132 |
+
|
133 |
+
self.height = sample_size
|
134 |
+
self.width = sample_size
|
135 |
+
self.num_vector_embeds = num_vector_embeds
|
136 |
+
self.num_latent_pixels = self.height * self.width
|
137 |
+
|
138 |
+
self.latent_image_embedding = ImagePositionalEmbeddings(
|
139 |
+
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
140 |
+
)
|
141 |
+
|
142 |
+
# 3. Define transformers blocks
|
143 |
+
self.transformer_blocks = nn.ModuleList(
|
144 |
+
[
|
145 |
+
BasicTransformerBlock(
|
146 |
+
inner_dim,
|
147 |
+
num_attention_heads,
|
148 |
+
attention_head_dim,
|
149 |
+
dropout=dropout,
|
150 |
+
cross_attention_dim=cross_attention_dim,
|
151 |
+
activation_fn=activation_fn,
|
152 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
153 |
+
attention_bias=attention_bias,
|
154 |
+
)
|
155 |
+
for d in range(num_layers)
|
156 |
+
]
|
157 |
+
)
|
158 |
+
|
159 |
+
# 4. Define output layers
|
160 |
+
if self.is_input_continuous:
|
161 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
162 |
+
elif self.is_input_vectorized:
|
163 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
164 |
+
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
165 |
+
|
166 |
+
def _set_attention_slice(self, slice_size):
|
167 |
+
for block in self.transformer_blocks:
|
168 |
+
block._set_attention_slice(slice_size)
|
169 |
+
|
170 |
+
def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None, timestep=None,
|
171 |
+
return_dict: bool = True):
|
172 |
+
"""
|
173 |
+
Args:
|
174 |
+
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
175 |
+
When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
176 |
+
hidden_states
|
177 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
|
178 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
179 |
+
self-attention.
|
180 |
+
encoder_attention_mask ( `torch.LongTensor` of shape `(batch size, context)`, *optional*):
|
181 |
+
Attention mask for cross attention layer.
|
182 |
+
timestep ( `torch.long`, *optional*):
|
183 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
184 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
185 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
|
189 |
+
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
|
190 |
+
tensor.
|
191 |
+
"""
|
192 |
+
# 1. Input
|
193 |
+
if self.is_input_continuous:
|
194 |
+
batch, channel, height, weight = hidden_states.shape
|
195 |
+
residual = hidden_states
|
196 |
+
hidden_states = self.norm(hidden_states)
|
197 |
+
hidden_states = self.proj_in(hidden_states)
|
198 |
+
inner_dim = hidden_states.shape[1]
|
199 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
200 |
+
elif self.is_input_vectorized:
|
201 |
+
hidden_states = self.latent_image_embedding(hidden_states)
|
202 |
+
|
203 |
+
# 2. Blocks
|
204 |
+
for block in self.transformer_blocks:
|
205 |
+
hidden_states = block(hidden_states, context=encoder_hidden_states, mask=encoder_attention_mask,
|
206 |
+
timestep=timestep)
|
207 |
+
|
208 |
+
# 3. Output
|
209 |
+
if self.is_input_continuous:
|
210 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
|
211 |
+
hidden_states = self.proj_out(hidden_states)
|
212 |
+
output = hidden_states + residual
|
213 |
+
elif self.is_input_vectorized:
|
214 |
+
hidden_states = self.norm_out(hidden_states)
|
215 |
+
logits = self.out(hidden_states)
|
216 |
+
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
217 |
+
logits = logits.permute(0, 2, 1)
|
218 |
+
|
219 |
+
# log(p(x_0))
|
220 |
+
output = F.log_softmax(logits.double(), dim=1).float()
|
221 |
+
|
222 |
+
if not return_dict:
|
223 |
+
return (output,)
|
224 |
+
|
225 |
+
return Transformer2DModelOutput(sample=output)
|
226 |
+
|
227 |
+
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
228 |
+
for block in self.transformer_blocks:
|
229 |
+
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
230 |
+
|
231 |
+
|
232 |
+
class AttentionBlock(nn.Module):
|
233 |
+
"""
|
234 |
+
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
|
235 |
+
to the N-d case.
|
236 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
237 |
+
Uses three q, k, v linear layers to compute attention.
|
238 |
+
|
239 |
+
Parameters:
|
240 |
+
channels (`int`): The number of channels in the input and output.
|
241 |
+
num_head_channels (`int`, *optional*):
|
242 |
+
The number of channels in each head. If None, then `num_heads` = 1.
|
243 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
|
244 |
+
rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
|
245 |
+
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
246 |
+
"""
|
247 |
+
|
248 |
+
def __init__(
|
249 |
+
self,
|
250 |
+
channels: int,
|
251 |
+
num_head_channels: Optional[int] = None,
|
252 |
+
norm_num_groups: int = 32,
|
253 |
+
rescale_output_factor: float = 1.0,
|
254 |
+
eps: float = 1e-5,
|
255 |
+
):
|
256 |
+
super().__init__()
|
257 |
+
self.channels = channels
|
258 |
+
|
259 |
+
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
|
260 |
+
self.num_head_size = num_head_channels
|
261 |
+
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
|
262 |
+
|
263 |
+
# define q,k,v as linear layers
|
264 |
+
self.query = nn.Linear(channels, channels)
|
265 |
+
self.key = nn.Linear(channels, channels)
|
266 |
+
self.value = nn.Linear(channels, channels)
|
267 |
+
|
268 |
+
self.rescale_output_factor = rescale_output_factor
|
269 |
+
self.proj_attn = nn.Linear(channels, channels, 1)
|
270 |
+
|
271 |
+
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
|
272 |
+
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
|
273 |
+
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
|
274 |
+
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
|
275 |
+
return new_projection
|
276 |
+
|
277 |
+
def forward(self, hidden_states):
|
278 |
+
residual = hidden_states
|
279 |
+
batch, channel, height, width = hidden_states.shape
|
280 |
+
|
281 |
+
# norm
|
282 |
+
hidden_states = self.group_norm(hidden_states)
|
283 |
+
|
284 |
+
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
|
285 |
+
|
286 |
+
# proj to q, k, v
|
287 |
+
query_proj = self.query(hidden_states)
|
288 |
+
key_proj = self.key(hidden_states)
|
289 |
+
value_proj = self.value(hidden_states)
|
290 |
+
|
291 |
+
# transpose
|
292 |
+
query_states = self.transpose_for_scores(query_proj)
|
293 |
+
key_states = self.transpose_for_scores(key_proj)
|
294 |
+
value_states = self.transpose_for_scores(value_proj)
|
295 |
+
|
296 |
+
# get scores
|
297 |
+
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
|
298 |
+
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
|
299 |
+
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
|
300 |
+
|
301 |
+
# compute attention output
|
302 |
+
hidden_states = torch.matmul(attention_probs, value_states)
|
303 |
+
|
304 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
|
305 |
+
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
306 |
+
hidden_states = hidden_states.view(new_hidden_states_shape)
|
307 |
+
|
308 |
+
# compute next hidden_states
|
309 |
+
hidden_states = self.proj_attn(hidden_states)
|
310 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
|
311 |
+
|
312 |
+
# res connect and rescale
|
313 |
+
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
314 |
+
return hidden_states
|
315 |
+
|
316 |
+
|
317 |
+
class BasicTransformerBlock(nn.Module):
|
318 |
+
r"""
|
319 |
+
A basic Transformer block.
|
320 |
+
|
321 |
+
Parameters:
|
322 |
+
dim (`int`): The number of channels in the input and output.
|
323 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
324 |
+
attention_head_dim (`int`): The number of channels in each head.
|
325 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
326 |
+
cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
|
327 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
328 |
+
num_embeds_ada_norm (:
|
329 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
330 |
+
attention_bias (:
|
331 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
332 |
+
"""
|
333 |
+
|
334 |
+
def __init__(
|
335 |
+
self,
|
336 |
+
dim: int,
|
337 |
+
num_attention_heads: int,
|
338 |
+
attention_head_dim: int,
|
339 |
+
dropout=0.0,
|
340 |
+
cross_attention_dim: Optional[int] = None,
|
341 |
+
activation_fn: str = "geglu",
|
342 |
+
num_embeds_ada_norm: Optional[int] = None,
|
343 |
+
attention_bias: bool = False,
|
344 |
+
):
|
345 |
+
super().__init__()
|
346 |
+
self.attn1 = CrossAttention(
|
347 |
+
query_dim=dim,
|
348 |
+
heads=num_attention_heads,
|
349 |
+
dim_head=attention_head_dim,
|
350 |
+
dropout=dropout,
|
351 |
+
bias=attention_bias,
|
352 |
+
) # is a self-attention
|
353 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
354 |
+
self.attn2 = CrossAttention(
|
355 |
+
query_dim=dim,
|
356 |
+
cross_attention_dim=cross_attention_dim,
|
357 |
+
heads=num_attention_heads,
|
358 |
+
dim_head=attention_head_dim,
|
359 |
+
dropout=dropout,
|
360 |
+
bias=attention_bias,
|
361 |
+
) # is self-attn if context is none
|
362 |
+
|
363 |
+
# layer norms
|
364 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
365 |
+
if self.use_ada_layer_norm:
|
366 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
367 |
+
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
368 |
+
else:
|
369 |
+
self.norm1 = nn.LayerNorm(dim)
|
370 |
+
self.norm2 = nn.LayerNorm(dim)
|
371 |
+
self.norm3 = nn.LayerNorm(dim)
|
372 |
+
|
373 |
+
def _set_attention_slice(self, slice_size):
|
374 |
+
self.attn1._slice_size = slice_size
|
375 |
+
self.attn2._slice_size = slice_size
|
376 |
+
|
377 |
+
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
378 |
+
if not is_xformers_available():
|
379 |
+
print("Here is how to install it")
|
380 |
+
raise ModuleNotFoundError(
|
381 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
382 |
+
" xformers",
|
383 |
+
name="xformers",
|
384 |
+
)
|
385 |
+
elif not torch.cuda.is_available():
|
386 |
+
raise ValueError(
|
387 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
|
388 |
+
" available for GPU "
|
389 |
+
)
|
390 |
+
else:
|
391 |
+
try:
|
392 |
+
# Make sure we can run the memory efficient attention
|
393 |
+
_ = xformers.ops.memory_efficient_attention(
|
394 |
+
torch.randn((1, 2, 40), device="cuda"),
|
395 |
+
torch.randn((1, 2, 40), device="cuda"),
|
396 |
+
torch.randn((1, 2, 40), device="cuda"),
|
397 |
+
)
|
398 |
+
except Exception as e:
|
399 |
+
raise e
|
400 |
+
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
401 |
+
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
402 |
+
|
403 |
+
def forward(self, hidden_states, context=None, mask=None, timestep=None):
|
404 |
+
# 1. Self-Attention
|
405 |
+
norm_hidden_states = (
|
406 |
+
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
407 |
+
)
|
408 |
+
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
409 |
+
|
410 |
+
# 2. Cross-Attention
|
411 |
+
norm_hidden_states = (
|
412 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
413 |
+
)
|
414 |
+
hidden_states = self.attn2(norm_hidden_states, context=context, mask=mask) + hidden_states
|
415 |
+
|
416 |
+
# 3. Feed-forward
|
417 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
418 |
+
|
419 |
+
return hidden_states
|
420 |
+
|
421 |
+
|
422 |
+
class CrossAttention(nn.Module):
|
423 |
+
r"""
|
424 |
+
A cross attention layer.
|
425 |
+
|
426 |
+
Parameters:
|
427 |
+
query_dim (`int`): The number of channels in the query.
|
428 |
+
cross_attention_dim (`int`, *optional*):
|
429 |
+
The number of channels in the context. If not given, defaults to `query_dim`.
|
430 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
431 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
432 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
433 |
+
bias (`bool`, *optional*, defaults to False):
|
434 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
435 |
+
"""
|
436 |
+
|
437 |
+
def __init__(
|
438 |
+
self,
|
439 |
+
query_dim: int,
|
440 |
+
cross_attention_dim: Optional[int] = None,
|
441 |
+
heads: int = 8,
|
442 |
+
dim_head: int = 64,
|
443 |
+
dropout: float = 0.0,
|
444 |
+
bias=False,
|
445 |
+
):
|
446 |
+
super().__init__()
|
447 |
+
inner_dim = dim_head * heads
|
448 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
449 |
+
|
450 |
+
self.scale = dim_head ** -0.5
|
451 |
+
self.heads = heads
|
452 |
+
# for slice_size > 0 the attention score computation
|
453 |
+
# is split across the batch axis to save memory
|
454 |
+
# You can set slice_size with `set_attention_slice`
|
455 |
+
self._slice_size = None
|
456 |
+
self._use_memory_efficient_attention_xformers = False
|
457 |
+
|
458 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
459 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
460 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
461 |
+
|
462 |
+
self.to_out = nn.ModuleList([])
|
463 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
464 |
+
self.to_out.append(nn.Dropout(dropout))
|
465 |
+
|
466 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
467 |
+
batch_size, seq_len, dim = tensor.shape
|
468 |
+
head_size = self.heads
|
469 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
470 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
471 |
+
return tensor
|
472 |
+
|
473 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
474 |
+
batch_size, seq_len, dim = tensor.shape
|
475 |
+
head_size = self.heads
|
476 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
477 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
478 |
+
return tensor
|
479 |
+
|
480 |
+
def forward(self, hidden_states, context=None, mask=None):
|
481 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
482 |
+
|
483 |
+
query = self.to_q(hidden_states)
|
484 |
+
context = context if context is not None else hidden_states
|
485 |
+
key = self.to_k(context)
|
486 |
+
value = self.to_v(context)
|
487 |
+
|
488 |
+
dim = query.shape[-1]
|
489 |
+
|
490 |
+
query = self.reshape_heads_to_batch_dim(query)
|
491 |
+
key = self.reshape_heads_to_batch_dim(key)
|
492 |
+
value = self.reshape_heads_to_batch_dim(value)
|
493 |
+
mask = mask.repeat_interleave(self.heads, dim=0).unsqueeze(1) if mask is not None else None
|
494 |
+
|
495 |
+
# attention, what we cannot get enough of
|
496 |
+
if self._use_memory_efficient_attention_xformers:
|
497 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
|
498 |
+
else:
|
499 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
500 |
+
hidden_states = self._attention(query, key, value, mask)
|
501 |
+
else:
|
502 |
+
assert mask is None, "masking is not supported for sliced attention"
|
503 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
|
504 |
+
|
505 |
+
# linear proj
|
506 |
+
hidden_states = self.to_out[0](hidden_states)
|
507 |
+
# dropout
|
508 |
+
hidden_states = self.to_out[1](hidden_states)
|
509 |
+
return hidden_states
|
510 |
+
|
511 |
+
def _attention(self, query, key, value, mask):
|
512 |
+
# TODO: use baddbmm for better performance
|
513 |
+
if query.device.type == "mps":
|
514 |
+
# Better performance on mps (~20-25%)
|
515 |
+
attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
|
516 |
+
else:
|
517 |
+
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
|
518 |
+
attention_scores = attention_scores.masked_fill(mask.expand(attention_scores.shape), value=float("-inf")) \
|
519 |
+
if mask is not None else attention_scores
|
520 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
521 |
+
# compute attention output
|
522 |
+
|
523 |
+
if query.device.type == "mps":
|
524 |
+
hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
|
525 |
+
else:
|
526 |
+
hidden_states = torch.matmul(attention_probs, value)
|
527 |
+
|
528 |
+
# reshape hidden_states
|
529 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
530 |
+
return hidden_states
|
531 |
+
|
532 |
+
def _sliced_attention(self, query, key, value, sequence_length, dim):
|
533 |
+
batch_size_attention = query.shape[0]
|
534 |
+
hidden_states = torch.zeros(
|
535 |
+
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
536 |
+
)
|
537 |
+
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
|
538 |
+
for i in range(hidden_states.shape[0] // slice_size):
|
539 |
+
start_idx = i * slice_size
|
540 |
+
end_idx = (i + 1) * slice_size
|
541 |
+
if query.device.type == "mps":
|
542 |
+
# Better performance on mps (~20-25%)
|
543 |
+
attn_slice = (
|
544 |
+
torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
|
545 |
+
* self.scale
|
546 |
+
)
|
547 |
+
else:
|
548 |
+
attn_slice = (
|
549 |
+
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
|
550 |
+
) # TODO: use baddbmm for better performance
|
551 |
+
attn_slice = attn_slice.softmax(dim=-1)
|
552 |
+
if query.device.type == "mps":
|
553 |
+
attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
|
554 |
+
else:
|
555 |
+
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
|
556 |
+
|
557 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
558 |
+
|
559 |
+
# reshape hidden_states
|
560 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
561 |
+
return hidden_states
|
562 |
+
|
563 |
+
def _memory_efficient_attention_xformers(self, query, key, value):
|
564 |
+
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
|
565 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
566 |
+
return hidden_states
|
567 |
+
|
568 |
+
|
569 |
+
class FeedForward(nn.Module):
|
570 |
+
r"""
|
571 |
+
A feed-forward layer.
|
572 |
+
|
573 |
+
Parameters:
|
574 |
+
dim (`int`): The number of channels in the input.
|
575 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
576 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
577 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
578 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
579 |
+
"""
|
580 |
+
|
581 |
+
def __init__(
|
582 |
+
self,
|
583 |
+
dim: int,
|
584 |
+
dim_out: Optional[int] = None,
|
585 |
+
mult: int = 4,
|
586 |
+
dropout: float = 0.0,
|
587 |
+
activation_fn: str = "geglu",
|
588 |
+
):
|
589 |
+
super().__init__()
|
590 |
+
inner_dim = int(dim * mult)
|
591 |
+
dim_out = dim_out if dim_out is not None else dim
|
592 |
+
|
593 |
+
if activation_fn == "geglu":
|
594 |
+
geglu = GEGLU(dim, inner_dim)
|
595 |
+
elif activation_fn == "geglu-approximate":
|
596 |
+
geglu = ApproximateGELU(dim, inner_dim)
|
597 |
+
|
598 |
+
self.net = nn.ModuleList([])
|
599 |
+
# project in
|
600 |
+
self.net.append(geglu)
|
601 |
+
# project dropout
|
602 |
+
self.net.append(nn.Dropout(dropout))
|
603 |
+
# project out
|
604 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
605 |
+
|
606 |
+
def forward(self, hidden_states):
|
607 |
+
for module in self.net:
|
608 |
+
hidden_states = module(hidden_states)
|
609 |
+
return hidden_states
|
610 |
+
|
611 |
+
|
612 |
+
# feedforward
|
613 |
+
class GEGLU(nn.Module):
|
614 |
+
r"""
|
615 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
616 |
+
|
617 |
+
Parameters:
|
618 |
+
dim_in (`int`): The number of channels in the input.
|
619 |
+
dim_out (`int`): The number of channels in the output.
|
620 |
+
"""
|
621 |
+
|
622 |
+
def __init__(self, dim_in: int, dim_out: int):
|
623 |
+
super().__init__()
|
624 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
625 |
+
|
626 |
+
def gelu(self, gate):
|
627 |
+
if gate.device.type != "mps":
|
628 |
+
return F.gelu(gate)
|
629 |
+
# mps: gelu is not implemented for float16
|
630 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
631 |
+
|
632 |
+
def forward(self, hidden_states):
|
633 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
634 |
+
return hidden_states * self.gelu(gate)
|
635 |
+
|
636 |
+
|
637 |
+
class ApproximateGELU(nn.Module):
|
638 |
+
"""
|
639 |
+
The approximate form of Gaussian Error Linear Unit (GELU)
|
640 |
+
|
641 |
+
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
642 |
+
"""
|
643 |
+
|
644 |
+
def __init__(self, dim_in: int, dim_out: int):
|
645 |
+
super().__init__()
|
646 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
647 |
+
|
648 |
+
def forward(self, x):
|
649 |
+
x = self.proj(x)
|
650 |
+
return x * torch.sigmoid(1.702 * x)
|
651 |
+
|
652 |
+
|
653 |
+
class AdaLayerNorm(nn.Module):
|
654 |
+
"""
|
655 |
+
Norm layer modified to incorporate timestep embeddings.
|
656 |
+
"""
|
657 |
+
|
658 |
+
def __init__(self, embedding_dim, num_embeddings):
|
659 |
+
super().__init__()
|
660 |
+
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
661 |
+
self.silu = nn.SiLU()
|
662 |
+
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
663 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
|
664 |
+
|
665 |
+
def forward(self, x, timestep):
|
666 |
+
emb = self.linear(self.silu(self.emb(timestep)))
|
667 |
+
scale, shift = torch.chunk(emb, 2)
|
668 |
+
x = self.norm(x) * (1 + scale) + shift
|
669 |
+
return x
|
models/diffusers_override/unet_2d_blocks.py
ADDED
@@ -0,0 +1,1602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
from .attention import AttentionBlock, Transformer2DModel
|
19 |
+
from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
20 |
+
|
21 |
+
|
22 |
+
def get_down_block(
|
23 |
+
down_block_type,
|
24 |
+
num_layers,
|
25 |
+
in_channels,
|
26 |
+
out_channels,
|
27 |
+
temb_channels,
|
28 |
+
add_downsample,
|
29 |
+
resnet_eps,
|
30 |
+
resnet_act_fn,
|
31 |
+
attn_num_head_channels,
|
32 |
+
resnet_groups=None,
|
33 |
+
cross_attention_dim=None,
|
34 |
+
downsample_padding=None,
|
35 |
+
):
|
36 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
37 |
+
if down_block_type == "DownBlock2D":
|
38 |
+
return DownBlock2D(
|
39 |
+
num_layers=num_layers,
|
40 |
+
in_channels=in_channels,
|
41 |
+
out_channels=out_channels,
|
42 |
+
temb_channels=temb_channels,
|
43 |
+
add_downsample=add_downsample,
|
44 |
+
resnet_eps=resnet_eps,
|
45 |
+
resnet_act_fn=resnet_act_fn,
|
46 |
+
resnet_groups=resnet_groups,
|
47 |
+
downsample_padding=downsample_padding,
|
48 |
+
)
|
49 |
+
elif down_block_type == "AttnDownBlock2D":
|
50 |
+
return AttnDownBlock2D(
|
51 |
+
num_layers=num_layers,
|
52 |
+
in_channels=in_channels,
|
53 |
+
out_channels=out_channels,
|
54 |
+
temb_channels=temb_channels,
|
55 |
+
add_downsample=add_downsample,
|
56 |
+
resnet_eps=resnet_eps,
|
57 |
+
resnet_act_fn=resnet_act_fn,
|
58 |
+
resnet_groups=resnet_groups,
|
59 |
+
downsample_padding=downsample_padding,
|
60 |
+
attn_num_head_channels=attn_num_head_channels,
|
61 |
+
)
|
62 |
+
elif down_block_type == "CrossAttnDownBlock2D":
|
63 |
+
if cross_attention_dim is None:
|
64 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
|
65 |
+
return CrossAttnDownBlock2D(
|
66 |
+
num_layers=num_layers,
|
67 |
+
in_channels=in_channels,
|
68 |
+
out_channels=out_channels,
|
69 |
+
temb_channels=temb_channels,
|
70 |
+
add_downsample=add_downsample,
|
71 |
+
resnet_eps=resnet_eps,
|
72 |
+
resnet_act_fn=resnet_act_fn,
|
73 |
+
resnet_groups=resnet_groups,
|
74 |
+
downsample_padding=downsample_padding,
|
75 |
+
cross_attention_dim=cross_attention_dim,
|
76 |
+
attn_num_head_channels=attn_num_head_channels,
|
77 |
+
)
|
78 |
+
elif down_block_type == "SkipDownBlock2D":
|
79 |
+
return SkipDownBlock2D(
|
80 |
+
num_layers=num_layers,
|
81 |
+
in_channels=in_channels,
|
82 |
+
out_channels=out_channels,
|
83 |
+
temb_channels=temb_channels,
|
84 |
+
add_downsample=add_downsample,
|
85 |
+
resnet_eps=resnet_eps,
|
86 |
+
resnet_act_fn=resnet_act_fn,
|
87 |
+
downsample_padding=downsample_padding,
|
88 |
+
)
|
89 |
+
elif down_block_type == "AttnSkipDownBlock2D":
|
90 |
+
return AttnSkipDownBlock2D(
|
91 |
+
num_layers=num_layers,
|
92 |
+
in_channels=in_channels,
|
93 |
+
out_channels=out_channels,
|
94 |
+
temb_channels=temb_channels,
|
95 |
+
add_downsample=add_downsample,
|
96 |
+
resnet_eps=resnet_eps,
|
97 |
+
resnet_act_fn=resnet_act_fn,
|
98 |
+
downsample_padding=downsample_padding,
|
99 |
+
attn_num_head_channels=attn_num_head_channels,
|
100 |
+
)
|
101 |
+
elif down_block_type == "DownEncoderBlock2D":
|
102 |
+
return DownEncoderBlock2D(
|
103 |
+
num_layers=num_layers,
|
104 |
+
in_channels=in_channels,
|
105 |
+
out_channels=out_channels,
|
106 |
+
add_downsample=add_downsample,
|
107 |
+
resnet_eps=resnet_eps,
|
108 |
+
resnet_act_fn=resnet_act_fn,
|
109 |
+
resnet_groups=resnet_groups,
|
110 |
+
downsample_padding=downsample_padding,
|
111 |
+
)
|
112 |
+
elif down_block_type == "AttnDownEncoderBlock2D":
|
113 |
+
return AttnDownEncoderBlock2D(
|
114 |
+
num_layers=num_layers,
|
115 |
+
in_channels=in_channels,
|
116 |
+
out_channels=out_channels,
|
117 |
+
add_downsample=add_downsample,
|
118 |
+
resnet_eps=resnet_eps,
|
119 |
+
resnet_act_fn=resnet_act_fn,
|
120 |
+
resnet_groups=resnet_groups,
|
121 |
+
downsample_padding=downsample_padding,
|
122 |
+
attn_num_head_channels=attn_num_head_channels,
|
123 |
+
)
|
124 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
125 |
+
|
126 |
+
|
127 |
+
def get_up_block(
|
128 |
+
up_block_type,
|
129 |
+
num_layers,
|
130 |
+
in_channels,
|
131 |
+
out_channels,
|
132 |
+
prev_output_channel,
|
133 |
+
temb_channels,
|
134 |
+
add_upsample,
|
135 |
+
resnet_eps,
|
136 |
+
resnet_act_fn,
|
137 |
+
attn_num_head_channels,
|
138 |
+
resnet_groups=None,
|
139 |
+
cross_attention_dim=None,
|
140 |
+
):
|
141 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
142 |
+
if up_block_type == "UpBlock2D":
|
143 |
+
return UpBlock2D(
|
144 |
+
num_layers=num_layers,
|
145 |
+
in_channels=in_channels,
|
146 |
+
out_channels=out_channels,
|
147 |
+
prev_output_channel=prev_output_channel,
|
148 |
+
temb_channels=temb_channels,
|
149 |
+
add_upsample=add_upsample,
|
150 |
+
resnet_eps=resnet_eps,
|
151 |
+
resnet_act_fn=resnet_act_fn,
|
152 |
+
resnet_groups=resnet_groups,
|
153 |
+
)
|
154 |
+
elif up_block_type == "CrossAttnUpBlock2D":
|
155 |
+
if cross_attention_dim is None:
|
156 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
|
157 |
+
return CrossAttnUpBlock2D(
|
158 |
+
num_layers=num_layers,
|
159 |
+
in_channels=in_channels,
|
160 |
+
out_channels=out_channels,
|
161 |
+
prev_output_channel=prev_output_channel,
|
162 |
+
temb_channels=temb_channels,
|
163 |
+
add_upsample=add_upsample,
|
164 |
+
resnet_eps=resnet_eps,
|
165 |
+
resnet_act_fn=resnet_act_fn,
|
166 |
+
resnet_groups=resnet_groups,
|
167 |
+
cross_attention_dim=cross_attention_dim,
|
168 |
+
attn_num_head_channels=attn_num_head_channels,
|
169 |
+
)
|
170 |
+
elif up_block_type == "AttnUpBlock2D":
|
171 |
+
return AttnUpBlock2D(
|
172 |
+
num_layers=num_layers,
|
173 |
+
in_channels=in_channels,
|
174 |
+
out_channels=out_channels,
|
175 |
+
prev_output_channel=prev_output_channel,
|
176 |
+
temb_channels=temb_channels,
|
177 |
+
add_upsample=add_upsample,
|
178 |
+
resnet_eps=resnet_eps,
|
179 |
+
resnet_act_fn=resnet_act_fn,
|
180 |
+
resnet_groups=resnet_groups,
|
181 |
+
attn_num_head_channels=attn_num_head_channels,
|
182 |
+
)
|
183 |
+
elif up_block_type == "SkipUpBlock2D":
|
184 |
+
return SkipUpBlock2D(
|
185 |
+
num_layers=num_layers,
|
186 |
+
in_channels=in_channels,
|
187 |
+
out_channels=out_channels,
|
188 |
+
prev_output_channel=prev_output_channel,
|
189 |
+
temb_channels=temb_channels,
|
190 |
+
add_upsample=add_upsample,
|
191 |
+
resnet_eps=resnet_eps,
|
192 |
+
resnet_act_fn=resnet_act_fn,
|
193 |
+
)
|
194 |
+
elif up_block_type == "AttnSkipUpBlock2D":
|
195 |
+
return AttnSkipUpBlock2D(
|
196 |
+
num_layers=num_layers,
|
197 |
+
in_channels=in_channels,
|
198 |
+
out_channels=out_channels,
|
199 |
+
prev_output_channel=prev_output_channel,
|
200 |
+
temb_channels=temb_channels,
|
201 |
+
add_upsample=add_upsample,
|
202 |
+
resnet_eps=resnet_eps,
|
203 |
+
resnet_act_fn=resnet_act_fn,
|
204 |
+
attn_num_head_channels=attn_num_head_channels,
|
205 |
+
)
|
206 |
+
elif up_block_type == "UpDecoderBlock2D":
|
207 |
+
return UpDecoderBlock2D(
|
208 |
+
num_layers=num_layers,
|
209 |
+
in_channels=in_channels,
|
210 |
+
out_channels=out_channels,
|
211 |
+
add_upsample=add_upsample,
|
212 |
+
resnet_eps=resnet_eps,
|
213 |
+
resnet_act_fn=resnet_act_fn,
|
214 |
+
resnet_groups=resnet_groups,
|
215 |
+
)
|
216 |
+
elif up_block_type == "AttnUpDecoderBlock2D":
|
217 |
+
return AttnUpDecoderBlock2D(
|
218 |
+
num_layers=num_layers,
|
219 |
+
in_channels=in_channels,
|
220 |
+
out_channels=out_channels,
|
221 |
+
add_upsample=add_upsample,
|
222 |
+
resnet_eps=resnet_eps,
|
223 |
+
resnet_act_fn=resnet_act_fn,
|
224 |
+
resnet_groups=resnet_groups,
|
225 |
+
attn_num_head_channels=attn_num_head_channels,
|
226 |
+
)
|
227 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
228 |
+
|
229 |
+
|
230 |
+
class UNetMidBlock2D(nn.Module):
|
231 |
+
def __init__(
|
232 |
+
self,
|
233 |
+
in_channels: int,
|
234 |
+
temb_channels: int,
|
235 |
+
dropout: float = 0.0,
|
236 |
+
num_layers: int = 1,
|
237 |
+
resnet_eps: float = 1e-6,
|
238 |
+
resnet_time_scale_shift: str = "default",
|
239 |
+
resnet_act_fn: str = "swish",
|
240 |
+
resnet_groups: int = 32,
|
241 |
+
resnet_pre_norm: bool = True,
|
242 |
+
attn_num_head_channels=1,
|
243 |
+
attention_type="default",
|
244 |
+
output_scale_factor=1.0,
|
245 |
+
**kwargs,
|
246 |
+
):
|
247 |
+
super().__init__()
|
248 |
+
|
249 |
+
self.attention_type = attention_type
|
250 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
251 |
+
|
252 |
+
# there is always at least one resnet
|
253 |
+
resnets = [
|
254 |
+
ResnetBlock2D(
|
255 |
+
in_channels=in_channels,
|
256 |
+
out_channels=in_channels,
|
257 |
+
temb_channels=temb_channels,
|
258 |
+
eps=resnet_eps,
|
259 |
+
groups=resnet_groups,
|
260 |
+
dropout=dropout,
|
261 |
+
time_embedding_norm=resnet_time_scale_shift,
|
262 |
+
non_linearity=resnet_act_fn,
|
263 |
+
output_scale_factor=output_scale_factor,
|
264 |
+
pre_norm=resnet_pre_norm,
|
265 |
+
)
|
266 |
+
]
|
267 |
+
attentions = []
|
268 |
+
|
269 |
+
for _ in range(num_layers):
|
270 |
+
attentions.append(
|
271 |
+
AttentionBlock(
|
272 |
+
in_channels,
|
273 |
+
num_head_channels=attn_num_head_channels,
|
274 |
+
rescale_output_factor=output_scale_factor,
|
275 |
+
eps=resnet_eps,
|
276 |
+
norm_num_groups=resnet_groups,
|
277 |
+
)
|
278 |
+
)
|
279 |
+
resnets.append(
|
280 |
+
ResnetBlock2D(
|
281 |
+
in_channels=in_channels,
|
282 |
+
out_channels=in_channels,
|
283 |
+
temb_channels=temb_channels,
|
284 |
+
eps=resnet_eps,
|
285 |
+
groups=resnet_groups,
|
286 |
+
dropout=dropout,
|
287 |
+
time_embedding_norm=resnet_time_scale_shift,
|
288 |
+
non_linearity=resnet_act_fn,
|
289 |
+
output_scale_factor=output_scale_factor,
|
290 |
+
pre_norm=resnet_pre_norm,
|
291 |
+
)
|
292 |
+
)
|
293 |
+
|
294 |
+
self.attentions = nn.ModuleList(attentions)
|
295 |
+
self.resnets = nn.ModuleList(resnets)
|
296 |
+
|
297 |
+
def forward(self, hidden_states, temb=None, encoder_states=None):
|
298 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
299 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
300 |
+
if self.attention_type == "default":
|
301 |
+
hidden_states = attn(hidden_states)
|
302 |
+
else:
|
303 |
+
hidden_states = attn(hidden_states, encoder_states)
|
304 |
+
hidden_states = resnet(hidden_states, temb)
|
305 |
+
|
306 |
+
return hidden_states
|
307 |
+
|
308 |
+
|
309 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
310 |
+
def __init__(
|
311 |
+
self,
|
312 |
+
in_channels: int,
|
313 |
+
temb_channels: int,
|
314 |
+
dropout: float = 0.0,
|
315 |
+
num_layers: int = 1,
|
316 |
+
resnet_eps: float = 1e-6,
|
317 |
+
resnet_time_scale_shift: str = "default",
|
318 |
+
resnet_act_fn: str = "swish",
|
319 |
+
resnet_groups: int = 32,
|
320 |
+
resnet_pre_norm: bool = True,
|
321 |
+
attn_num_head_channels=1,
|
322 |
+
attention_type="default",
|
323 |
+
output_scale_factor=1.0,
|
324 |
+
cross_attention_dim=1280,
|
325 |
+
**kwargs,
|
326 |
+
):
|
327 |
+
super().__init__()
|
328 |
+
|
329 |
+
self.attention_type = attention_type
|
330 |
+
self.attn_num_head_channels = attn_num_head_channels
|
331 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
332 |
+
|
333 |
+
# there is always at least one resnet
|
334 |
+
resnets = [
|
335 |
+
ResnetBlock2D(
|
336 |
+
in_channels=in_channels,
|
337 |
+
out_channels=in_channels,
|
338 |
+
temb_channels=temb_channels,
|
339 |
+
eps=resnet_eps,
|
340 |
+
groups=resnet_groups,
|
341 |
+
dropout=dropout,
|
342 |
+
time_embedding_norm=resnet_time_scale_shift,
|
343 |
+
non_linearity=resnet_act_fn,
|
344 |
+
output_scale_factor=output_scale_factor,
|
345 |
+
pre_norm=resnet_pre_norm,
|
346 |
+
)
|
347 |
+
]
|
348 |
+
attentions = []
|
349 |
+
|
350 |
+
for _ in range(num_layers):
|
351 |
+
attentions.append(
|
352 |
+
Transformer2DModel(
|
353 |
+
attn_num_head_channels,
|
354 |
+
in_channels // attn_num_head_channels,
|
355 |
+
in_channels=in_channels,
|
356 |
+
num_layers=1,
|
357 |
+
cross_attention_dim=cross_attention_dim,
|
358 |
+
norm_num_groups=resnet_groups,
|
359 |
+
)
|
360 |
+
)
|
361 |
+
resnets.append(
|
362 |
+
ResnetBlock2D(
|
363 |
+
in_channels=in_channels,
|
364 |
+
out_channels=in_channels,
|
365 |
+
temb_channels=temb_channels,
|
366 |
+
eps=resnet_eps,
|
367 |
+
groups=resnet_groups,
|
368 |
+
dropout=dropout,
|
369 |
+
time_embedding_norm=resnet_time_scale_shift,
|
370 |
+
non_linearity=resnet_act_fn,
|
371 |
+
output_scale_factor=output_scale_factor,
|
372 |
+
pre_norm=resnet_pre_norm,
|
373 |
+
)
|
374 |
+
)
|
375 |
+
|
376 |
+
self.attentions = nn.ModuleList(attentions)
|
377 |
+
self.resnets = nn.ModuleList(resnets)
|
378 |
+
|
379 |
+
def set_attention_slice(self, slice_size):
|
380 |
+
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
381 |
+
raise ValueError(
|
382 |
+
f"Make sure slice_size {slice_size} is a divisor of "
|
383 |
+
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
384 |
+
)
|
385 |
+
if slice_size is not None and slice_size > self.attn_num_head_channels:
|
386 |
+
raise ValueError(
|
387 |
+
f"Chunk_size {slice_size} has to be smaller or equal to "
|
388 |
+
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
389 |
+
)
|
390 |
+
|
391 |
+
for attn in self.attentions:
|
392 |
+
attn._set_attention_slice(slice_size)
|
393 |
+
|
394 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
395 |
+
for attn in self.attentions:
|
396 |
+
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
397 |
+
|
398 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
399 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
400 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
401 |
+
hidden_states = attn(hidden_states, encoder_hidden_states, encoder_attention_mask).sample
|
402 |
+
hidden_states = resnet(hidden_states, temb)
|
403 |
+
|
404 |
+
return hidden_states
|
405 |
+
|
406 |
+
|
407 |
+
class AttnDownBlock2D(nn.Module):
|
408 |
+
def __init__(
|
409 |
+
self,
|
410 |
+
in_channels: int,
|
411 |
+
out_channels: int,
|
412 |
+
temb_channels: int,
|
413 |
+
dropout: float = 0.0,
|
414 |
+
num_layers: int = 1,
|
415 |
+
resnet_eps: float = 1e-6,
|
416 |
+
resnet_time_scale_shift: str = "default",
|
417 |
+
resnet_act_fn: str = "swish",
|
418 |
+
resnet_groups: int = 32,
|
419 |
+
resnet_pre_norm: bool = True,
|
420 |
+
attn_num_head_channels=1,
|
421 |
+
attention_type="default",
|
422 |
+
output_scale_factor=1.0,
|
423 |
+
downsample_padding=1,
|
424 |
+
add_downsample=True,
|
425 |
+
):
|
426 |
+
super().__init__()
|
427 |
+
resnets = []
|
428 |
+
attentions = []
|
429 |
+
|
430 |
+
self.attention_type = attention_type
|
431 |
+
|
432 |
+
for i in range(num_layers):
|
433 |
+
in_channels = in_channels if i == 0 else out_channels
|
434 |
+
resnets.append(
|
435 |
+
ResnetBlock2D(
|
436 |
+
in_channels=in_channels,
|
437 |
+
out_channels=out_channels,
|
438 |
+
temb_channels=temb_channels,
|
439 |
+
eps=resnet_eps,
|
440 |
+
groups=resnet_groups,
|
441 |
+
dropout=dropout,
|
442 |
+
time_embedding_norm=resnet_time_scale_shift,
|
443 |
+
non_linearity=resnet_act_fn,
|
444 |
+
output_scale_factor=output_scale_factor,
|
445 |
+
pre_norm=resnet_pre_norm,
|
446 |
+
)
|
447 |
+
)
|
448 |
+
attentions.append(
|
449 |
+
AttentionBlock(
|
450 |
+
out_channels,
|
451 |
+
num_head_channels=attn_num_head_channels,
|
452 |
+
rescale_output_factor=output_scale_factor,
|
453 |
+
eps=resnet_eps,
|
454 |
+
norm_num_groups=resnet_groups,
|
455 |
+
)
|
456 |
+
)
|
457 |
+
|
458 |
+
self.attentions = nn.ModuleList(attentions)
|
459 |
+
self.resnets = nn.ModuleList(resnets)
|
460 |
+
|
461 |
+
if add_downsample:
|
462 |
+
self.downsamplers = nn.ModuleList(
|
463 |
+
[
|
464 |
+
Downsample2D(
|
465 |
+
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
466 |
+
)
|
467 |
+
]
|
468 |
+
)
|
469 |
+
else:
|
470 |
+
self.downsamplers = None
|
471 |
+
|
472 |
+
def forward(self, hidden_states, temb=None):
|
473 |
+
output_states = ()
|
474 |
+
|
475 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
476 |
+
hidden_states = resnet(hidden_states, temb)
|
477 |
+
hidden_states = attn(hidden_states)
|
478 |
+
output_states += (hidden_states,)
|
479 |
+
|
480 |
+
if self.downsamplers is not None:
|
481 |
+
for downsampler in self.downsamplers:
|
482 |
+
hidden_states = downsampler(hidden_states)
|
483 |
+
|
484 |
+
output_states += (hidden_states,)
|
485 |
+
|
486 |
+
return hidden_states, output_states
|
487 |
+
|
488 |
+
|
489 |
+
class CrossAttnDownBlock2D(nn.Module):
|
490 |
+
def __init__(
|
491 |
+
self,
|
492 |
+
in_channels: int,
|
493 |
+
out_channels: int,
|
494 |
+
temb_channels: int,
|
495 |
+
dropout: float = 0.0,
|
496 |
+
num_layers: int = 1,
|
497 |
+
resnet_eps: float = 1e-6,
|
498 |
+
resnet_time_scale_shift: str = "default",
|
499 |
+
resnet_act_fn: str = "swish",
|
500 |
+
resnet_groups: int = 32,
|
501 |
+
resnet_pre_norm: bool = True,
|
502 |
+
attn_num_head_channels=1,
|
503 |
+
cross_attention_dim=1280,
|
504 |
+
attention_type="default",
|
505 |
+
output_scale_factor=1.0,
|
506 |
+
downsample_padding=1,
|
507 |
+
add_downsample=True,
|
508 |
+
):
|
509 |
+
super().__init__()
|
510 |
+
resnets = []
|
511 |
+
attentions = []
|
512 |
+
|
513 |
+
self.attention_type = attention_type
|
514 |
+
self.attn_num_head_channels = attn_num_head_channels
|
515 |
+
|
516 |
+
for i in range(num_layers):
|
517 |
+
in_channels = in_channels if i == 0 else out_channels
|
518 |
+
resnets.append(
|
519 |
+
ResnetBlock2D(
|
520 |
+
in_channels=in_channels,
|
521 |
+
out_channels=out_channels,
|
522 |
+
temb_channels=temb_channels,
|
523 |
+
eps=resnet_eps,
|
524 |
+
groups=resnet_groups,
|
525 |
+
dropout=dropout,
|
526 |
+
time_embedding_norm=resnet_time_scale_shift,
|
527 |
+
non_linearity=resnet_act_fn,
|
528 |
+
output_scale_factor=output_scale_factor,
|
529 |
+
pre_norm=resnet_pre_norm,
|
530 |
+
)
|
531 |
+
)
|
532 |
+
attentions.append(
|
533 |
+
Transformer2DModel(
|
534 |
+
attn_num_head_channels,
|
535 |
+
out_channels // attn_num_head_channels,
|
536 |
+
in_channels=out_channels,
|
537 |
+
num_layers=1,
|
538 |
+
cross_attention_dim=cross_attention_dim,
|
539 |
+
norm_num_groups=resnet_groups,
|
540 |
+
)
|
541 |
+
)
|
542 |
+
self.attentions = nn.ModuleList(attentions)
|
543 |
+
self.resnets = nn.ModuleList(resnets)
|
544 |
+
|
545 |
+
if add_downsample:
|
546 |
+
self.downsamplers = nn.ModuleList(
|
547 |
+
[
|
548 |
+
Downsample2D(
|
549 |
+
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
550 |
+
)
|
551 |
+
]
|
552 |
+
)
|
553 |
+
else:
|
554 |
+
self.downsamplers = None
|
555 |
+
|
556 |
+
self.gradient_checkpointing = False
|
557 |
+
|
558 |
+
def set_attention_slice(self, slice_size):
|
559 |
+
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
560 |
+
raise ValueError(
|
561 |
+
f"Make sure slice_size {slice_size} is a divisor of "
|
562 |
+
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
563 |
+
)
|
564 |
+
if slice_size is not None and slice_size > self.attn_num_head_channels:
|
565 |
+
raise ValueError(
|
566 |
+
f"Chunk_size {slice_size} has to be smaller or equal to "
|
567 |
+
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
568 |
+
)
|
569 |
+
|
570 |
+
for attn in self.attentions:
|
571 |
+
attn._set_attention_slice(slice_size)
|
572 |
+
|
573 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
574 |
+
for attn in self.attentions:
|
575 |
+
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
576 |
+
|
577 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
578 |
+
output_states = ()
|
579 |
+
|
580 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
581 |
+
if self.training and self.gradient_checkpointing:
|
582 |
+
|
583 |
+
def create_custom_forward(module, return_dict=None):
|
584 |
+
def custom_forward(*inputs):
|
585 |
+
if return_dict is not None:
|
586 |
+
return module(*inputs, return_dict=return_dict)
|
587 |
+
else:
|
588 |
+
return module(*inputs)
|
589 |
+
|
590 |
+
return custom_forward
|
591 |
+
|
592 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
593 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
594 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states,
|
595 |
+
encoder_attention_mask
|
596 |
+
)[0]
|
597 |
+
else:
|
598 |
+
hidden_states = resnet(hidden_states, temb)
|
599 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
|
600 |
+
encoder_attention_mask=encoder_attention_mask).sample
|
601 |
+
|
602 |
+
output_states += (hidden_states,)
|
603 |
+
|
604 |
+
if self.downsamplers is not None:
|
605 |
+
for downsampler in self.downsamplers:
|
606 |
+
hidden_states = downsampler(hidden_states)
|
607 |
+
|
608 |
+
output_states += (hidden_states,)
|
609 |
+
|
610 |
+
return hidden_states, output_states
|
611 |
+
|
612 |
+
|
613 |
+
class DownBlock2D(nn.Module):
|
614 |
+
def __init__(
|
615 |
+
self,
|
616 |
+
in_channels: int,
|
617 |
+
out_channels: int,
|
618 |
+
temb_channels: int,
|
619 |
+
dropout: float = 0.0,
|
620 |
+
num_layers: int = 1,
|
621 |
+
resnet_eps: float = 1e-6,
|
622 |
+
resnet_time_scale_shift: str = "default",
|
623 |
+
resnet_act_fn: str = "swish",
|
624 |
+
resnet_groups: int = 32,
|
625 |
+
resnet_pre_norm: bool = True,
|
626 |
+
output_scale_factor=1.0,
|
627 |
+
add_downsample=True,
|
628 |
+
downsample_padding=1,
|
629 |
+
):
|
630 |
+
super().__init__()
|
631 |
+
resnets = []
|
632 |
+
|
633 |
+
for i in range(num_layers):
|
634 |
+
in_channels = in_channels if i == 0 else out_channels
|
635 |
+
resnets.append(
|
636 |
+
ResnetBlock2D(
|
637 |
+
in_channels=in_channels,
|
638 |
+
out_channels=out_channels,
|
639 |
+
temb_channels=temb_channels,
|
640 |
+
eps=resnet_eps,
|
641 |
+
groups=resnet_groups,
|
642 |
+
dropout=dropout,
|
643 |
+
time_embedding_norm=resnet_time_scale_shift,
|
644 |
+
non_linearity=resnet_act_fn,
|
645 |
+
output_scale_factor=output_scale_factor,
|
646 |
+
pre_norm=resnet_pre_norm,
|
647 |
+
)
|
648 |
+
)
|
649 |
+
|
650 |
+
self.resnets = nn.ModuleList(resnets)
|
651 |
+
|
652 |
+
if add_downsample:
|
653 |
+
self.downsamplers = nn.ModuleList(
|
654 |
+
[
|
655 |
+
Downsample2D(
|
656 |
+
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
657 |
+
)
|
658 |
+
]
|
659 |
+
)
|
660 |
+
else:
|
661 |
+
self.downsamplers = None
|
662 |
+
|
663 |
+
self.gradient_checkpointing = False
|
664 |
+
|
665 |
+
def forward(self, hidden_states, temb=None):
|
666 |
+
output_states = ()
|
667 |
+
|
668 |
+
for resnet in self.resnets:
|
669 |
+
if self.training and self.gradient_checkpointing:
|
670 |
+
|
671 |
+
def create_custom_forward(module):
|
672 |
+
def custom_forward(*inputs):
|
673 |
+
return module(*inputs)
|
674 |
+
|
675 |
+
return custom_forward
|
676 |
+
|
677 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
678 |
+
else:
|
679 |
+
hidden_states = resnet(hidden_states, temb)
|
680 |
+
|
681 |
+
output_states += (hidden_states,)
|
682 |
+
|
683 |
+
if self.downsamplers is not None:
|
684 |
+
for downsampler in self.downsamplers:
|
685 |
+
hidden_states = downsampler(hidden_states)
|
686 |
+
|
687 |
+
output_states += (hidden_states,)
|
688 |
+
|
689 |
+
return hidden_states, output_states
|
690 |
+
|
691 |
+
|
692 |
+
class DownEncoderBlock2D(nn.Module):
|
693 |
+
def __init__(
|
694 |
+
self,
|
695 |
+
in_channels: int,
|
696 |
+
out_channels: int,
|
697 |
+
dropout: float = 0.0,
|
698 |
+
num_layers: int = 1,
|
699 |
+
resnet_eps: float = 1e-6,
|
700 |
+
resnet_time_scale_shift: str = "default",
|
701 |
+
resnet_act_fn: str = "swish",
|
702 |
+
resnet_groups: int = 32,
|
703 |
+
resnet_pre_norm: bool = True,
|
704 |
+
output_scale_factor=1.0,
|
705 |
+
add_downsample=True,
|
706 |
+
downsample_padding=1,
|
707 |
+
):
|
708 |
+
super().__init__()
|
709 |
+
resnets = []
|
710 |
+
|
711 |
+
for i in range(num_layers):
|
712 |
+
in_channels = in_channels if i == 0 else out_channels
|
713 |
+
resnets.append(
|
714 |
+
ResnetBlock2D(
|
715 |
+
in_channels=in_channels,
|
716 |
+
out_channels=out_channels,
|
717 |
+
temb_channels=None,
|
718 |
+
eps=resnet_eps,
|
719 |
+
groups=resnet_groups,
|
720 |
+
dropout=dropout,
|
721 |
+
time_embedding_norm=resnet_time_scale_shift,
|
722 |
+
non_linearity=resnet_act_fn,
|
723 |
+
output_scale_factor=output_scale_factor,
|
724 |
+
pre_norm=resnet_pre_norm,
|
725 |
+
)
|
726 |
+
)
|
727 |
+
|
728 |
+
self.resnets = nn.ModuleList(resnets)
|
729 |
+
|
730 |
+
if add_downsample:
|
731 |
+
self.downsamplers = nn.ModuleList(
|
732 |
+
[
|
733 |
+
Downsample2D(
|
734 |
+
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
735 |
+
)
|
736 |
+
]
|
737 |
+
)
|
738 |
+
else:
|
739 |
+
self.downsamplers = None
|
740 |
+
|
741 |
+
def forward(self, hidden_states):
|
742 |
+
for resnet in self.resnets:
|
743 |
+
hidden_states = resnet(hidden_states, temb=None)
|
744 |
+
|
745 |
+
if self.downsamplers is not None:
|
746 |
+
for downsampler in self.downsamplers:
|
747 |
+
hidden_states = downsampler(hidden_states)
|
748 |
+
|
749 |
+
return hidden_states
|
750 |
+
|
751 |
+
|
752 |
+
class AttnDownEncoderBlock2D(nn.Module):
|
753 |
+
def __init__(
|
754 |
+
self,
|
755 |
+
in_channels: int,
|
756 |
+
out_channels: int,
|
757 |
+
dropout: float = 0.0,
|
758 |
+
num_layers: int = 1,
|
759 |
+
resnet_eps: float = 1e-6,
|
760 |
+
resnet_time_scale_shift: str = "default",
|
761 |
+
resnet_act_fn: str = "swish",
|
762 |
+
resnet_groups: int = 32,
|
763 |
+
resnet_pre_norm: bool = True,
|
764 |
+
attn_num_head_channels=1,
|
765 |
+
output_scale_factor=1.0,
|
766 |
+
add_downsample=True,
|
767 |
+
downsample_padding=1,
|
768 |
+
):
|
769 |
+
super().__init__()
|
770 |
+
resnets = []
|
771 |
+
attentions = []
|
772 |
+
|
773 |
+
for i in range(num_layers):
|
774 |
+
in_channels = in_channels if i == 0 else out_channels
|
775 |
+
resnets.append(
|
776 |
+
ResnetBlock2D(
|
777 |
+
in_channels=in_channels,
|
778 |
+
out_channels=out_channels,
|
779 |
+
temb_channels=None,
|
780 |
+
eps=resnet_eps,
|
781 |
+
groups=resnet_groups,
|
782 |
+
dropout=dropout,
|
783 |
+
time_embedding_norm=resnet_time_scale_shift,
|
784 |
+
non_linearity=resnet_act_fn,
|
785 |
+
output_scale_factor=output_scale_factor,
|
786 |
+
pre_norm=resnet_pre_norm,
|
787 |
+
)
|
788 |
+
)
|
789 |
+
attentions.append(
|
790 |
+
AttentionBlock(
|
791 |
+
out_channels,
|
792 |
+
num_head_channels=attn_num_head_channels,
|
793 |
+
rescale_output_factor=output_scale_factor,
|
794 |
+
eps=resnet_eps,
|
795 |
+
norm_num_groups=resnet_groups,
|
796 |
+
)
|
797 |
+
)
|
798 |
+
|
799 |
+
self.attentions = nn.ModuleList(attentions)
|
800 |
+
self.resnets = nn.ModuleList(resnets)
|
801 |
+
|
802 |
+
if add_downsample:
|
803 |
+
self.downsamplers = nn.ModuleList(
|
804 |
+
[
|
805 |
+
Downsample2D(
|
806 |
+
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
807 |
+
)
|
808 |
+
]
|
809 |
+
)
|
810 |
+
else:
|
811 |
+
self.downsamplers = None
|
812 |
+
|
813 |
+
def forward(self, hidden_states):
|
814 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
815 |
+
hidden_states = resnet(hidden_states, temb=None)
|
816 |
+
hidden_states = attn(hidden_states)
|
817 |
+
|
818 |
+
if self.downsamplers is not None:
|
819 |
+
for downsampler in self.downsamplers:
|
820 |
+
hidden_states = downsampler(hidden_states)
|
821 |
+
|
822 |
+
return hidden_states
|
823 |
+
|
824 |
+
|
825 |
+
class AttnSkipDownBlock2D(nn.Module):
|
826 |
+
def __init__(
|
827 |
+
self,
|
828 |
+
in_channels: int,
|
829 |
+
out_channels: int,
|
830 |
+
temb_channels: int,
|
831 |
+
dropout: float = 0.0,
|
832 |
+
num_layers: int = 1,
|
833 |
+
resnet_eps: float = 1e-6,
|
834 |
+
resnet_time_scale_shift: str = "default",
|
835 |
+
resnet_act_fn: str = "swish",
|
836 |
+
resnet_pre_norm: bool = True,
|
837 |
+
attn_num_head_channels=1,
|
838 |
+
attention_type="default",
|
839 |
+
output_scale_factor=np.sqrt(2.0),
|
840 |
+
downsample_padding=1,
|
841 |
+
add_downsample=True,
|
842 |
+
):
|
843 |
+
super().__init__()
|
844 |
+
self.attentions = nn.ModuleList([])
|
845 |
+
self.resnets = nn.ModuleList([])
|
846 |
+
|
847 |
+
self.attention_type = attention_type
|
848 |
+
|
849 |
+
for i in range(num_layers):
|
850 |
+
in_channels = in_channels if i == 0 else out_channels
|
851 |
+
self.resnets.append(
|
852 |
+
ResnetBlock2D(
|
853 |
+
in_channels=in_channels,
|
854 |
+
out_channels=out_channels,
|
855 |
+
temb_channels=temb_channels,
|
856 |
+
eps=resnet_eps,
|
857 |
+
groups=min(in_channels // 4, 32),
|
858 |
+
groups_out=min(out_channels // 4, 32),
|
859 |
+
dropout=dropout,
|
860 |
+
time_embedding_norm=resnet_time_scale_shift,
|
861 |
+
non_linearity=resnet_act_fn,
|
862 |
+
output_scale_factor=output_scale_factor,
|
863 |
+
pre_norm=resnet_pre_norm,
|
864 |
+
)
|
865 |
+
)
|
866 |
+
self.attentions.append(
|
867 |
+
AttentionBlock(
|
868 |
+
out_channels,
|
869 |
+
num_head_channels=attn_num_head_channels,
|
870 |
+
rescale_output_factor=output_scale_factor,
|
871 |
+
eps=resnet_eps,
|
872 |
+
)
|
873 |
+
)
|
874 |
+
|
875 |
+
if add_downsample:
|
876 |
+
self.resnet_down = ResnetBlock2D(
|
877 |
+
in_channels=out_channels,
|
878 |
+
out_channels=out_channels,
|
879 |
+
temb_channels=temb_channels,
|
880 |
+
eps=resnet_eps,
|
881 |
+
groups=min(out_channels // 4, 32),
|
882 |
+
dropout=dropout,
|
883 |
+
time_embedding_norm=resnet_time_scale_shift,
|
884 |
+
non_linearity=resnet_act_fn,
|
885 |
+
output_scale_factor=output_scale_factor,
|
886 |
+
pre_norm=resnet_pre_norm,
|
887 |
+
use_in_shortcut=True,
|
888 |
+
down=True,
|
889 |
+
kernel="fir",
|
890 |
+
)
|
891 |
+
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
|
892 |
+
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
|
893 |
+
else:
|
894 |
+
self.resnet_down = None
|
895 |
+
self.downsamplers = None
|
896 |
+
self.skip_conv = None
|
897 |
+
|
898 |
+
def forward(self, hidden_states, temb=None, skip_sample=None):
|
899 |
+
output_states = ()
|
900 |
+
|
901 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
902 |
+
hidden_states = resnet(hidden_states, temb)
|
903 |
+
hidden_states = attn(hidden_states)
|
904 |
+
output_states += (hidden_states,)
|
905 |
+
|
906 |
+
if self.downsamplers is not None:
|
907 |
+
hidden_states = self.resnet_down(hidden_states, temb)
|
908 |
+
for downsampler in self.downsamplers:
|
909 |
+
skip_sample = downsampler(skip_sample)
|
910 |
+
|
911 |
+
hidden_states = self.skip_conv(skip_sample) + hidden_states
|
912 |
+
|
913 |
+
output_states += (hidden_states,)
|
914 |
+
|
915 |
+
return hidden_states, output_states, skip_sample
|
916 |
+
|
917 |
+
|
918 |
+
class SkipDownBlock2D(nn.Module):
|
919 |
+
def __init__(
|
920 |
+
self,
|
921 |
+
in_channels: int,
|
922 |
+
out_channels: int,
|
923 |
+
temb_channels: int,
|
924 |
+
dropout: float = 0.0,
|
925 |
+
num_layers: int = 1,
|
926 |
+
resnet_eps: float = 1e-6,
|
927 |
+
resnet_time_scale_shift: str = "default",
|
928 |
+
resnet_act_fn: str = "swish",
|
929 |
+
resnet_pre_norm: bool = True,
|
930 |
+
output_scale_factor=np.sqrt(2.0),
|
931 |
+
add_downsample=True,
|
932 |
+
downsample_padding=1,
|
933 |
+
):
|
934 |
+
super().__init__()
|
935 |
+
self.resnets = nn.ModuleList([])
|
936 |
+
|
937 |
+
for i in range(num_layers):
|
938 |
+
in_channels = in_channels if i == 0 else out_channels
|
939 |
+
self.resnets.append(
|
940 |
+
ResnetBlock2D(
|
941 |
+
in_channels=in_channels,
|
942 |
+
out_channels=out_channels,
|
943 |
+
temb_channels=temb_channels,
|
944 |
+
eps=resnet_eps,
|
945 |
+
groups=min(in_channels // 4, 32),
|
946 |
+
groups_out=min(out_channels // 4, 32),
|
947 |
+
dropout=dropout,
|
948 |
+
time_embedding_norm=resnet_time_scale_shift,
|
949 |
+
non_linearity=resnet_act_fn,
|
950 |
+
output_scale_factor=output_scale_factor,
|
951 |
+
pre_norm=resnet_pre_norm,
|
952 |
+
)
|
953 |
+
)
|
954 |
+
|
955 |
+
if add_downsample:
|
956 |
+
self.resnet_down = ResnetBlock2D(
|
957 |
+
in_channels=out_channels,
|
958 |
+
out_channels=out_channels,
|
959 |
+
temb_channels=temb_channels,
|
960 |
+
eps=resnet_eps,
|
961 |
+
groups=min(out_channels // 4, 32),
|
962 |
+
dropout=dropout,
|
963 |
+
time_embedding_norm=resnet_time_scale_shift,
|
964 |
+
non_linearity=resnet_act_fn,
|
965 |
+
output_scale_factor=output_scale_factor,
|
966 |
+
pre_norm=resnet_pre_norm,
|
967 |
+
use_in_shortcut=True,
|
968 |
+
down=True,
|
969 |
+
kernel="fir",
|
970 |
+
)
|
971 |
+
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
|
972 |
+
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
|
973 |
+
else:
|
974 |
+
self.resnet_down = None
|
975 |
+
self.downsamplers = None
|
976 |
+
self.skip_conv = None
|
977 |
+
|
978 |
+
def forward(self, hidden_states, temb=None, skip_sample=None):
|
979 |
+
output_states = ()
|
980 |
+
|
981 |
+
for resnet in self.resnets:
|
982 |
+
hidden_states = resnet(hidden_states, temb)
|
983 |
+
output_states += (hidden_states,)
|
984 |
+
|
985 |
+
if self.downsamplers is not None:
|
986 |
+
hidden_states = self.resnet_down(hidden_states, temb)
|
987 |
+
for downsampler in self.downsamplers:
|
988 |
+
skip_sample = downsampler(skip_sample)
|
989 |
+
|
990 |
+
hidden_states = self.skip_conv(skip_sample) + hidden_states
|
991 |
+
|
992 |
+
output_states += (hidden_states,)
|
993 |
+
|
994 |
+
return hidden_states, output_states, skip_sample
|
995 |
+
|
996 |
+
|
997 |
+
class AttnUpBlock2D(nn.Module):
|
998 |
+
def __init__(
|
999 |
+
self,
|
1000 |
+
in_channels: int,
|
1001 |
+
prev_output_channel: int,
|
1002 |
+
out_channels: int,
|
1003 |
+
temb_channels: int,
|
1004 |
+
dropout: float = 0.0,
|
1005 |
+
num_layers: int = 1,
|
1006 |
+
resnet_eps: float = 1e-6,
|
1007 |
+
resnet_time_scale_shift: str = "default",
|
1008 |
+
resnet_act_fn: str = "swish",
|
1009 |
+
resnet_groups: int = 32,
|
1010 |
+
resnet_pre_norm: bool = True,
|
1011 |
+
attention_type="default",
|
1012 |
+
attn_num_head_channels=1,
|
1013 |
+
output_scale_factor=1.0,
|
1014 |
+
add_upsample=True,
|
1015 |
+
):
|
1016 |
+
super().__init__()
|
1017 |
+
resnets = []
|
1018 |
+
attentions = []
|
1019 |
+
|
1020 |
+
self.attention_type = attention_type
|
1021 |
+
|
1022 |
+
for i in range(num_layers):
|
1023 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1024 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1025 |
+
|
1026 |
+
resnets.append(
|
1027 |
+
ResnetBlock2D(
|
1028 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1029 |
+
out_channels=out_channels,
|
1030 |
+
temb_channels=temb_channels,
|
1031 |
+
eps=resnet_eps,
|
1032 |
+
groups=resnet_groups,
|
1033 |
+
dropout=dropout,
|
1034 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1035 |
+
non_linearity=resnet_act_fn,
|
1036 |
+
output_scale_factor=output_scale_factor,
|
1037 |
+
pre_norm=resnet_pre_norm,
|
1038 |
+
)
|
1039 |
+
)
|
1040 |
+
attentions.append(
|
1041 |
+
AttentionBlock(
|
1042 |
+
out_channels,
|
1043 |
+
num_head_channels=attn_num_head_channels,
|
1044 |
+
rescale_output_factor=output_scale_factor,
|
1045 |
+
eps=resnet_eps,
|
1046 |
+
norm_num_groups=resnet_groups,
|
1047 |
+
)
|
1048 |
+
)
|
1049 |
+
|
1050 |
+
self.attentions = nn.ModuleList(attentions)
|
1051 |
+
self.resnets = nn.ModuleList(resnets)
|
1052 |
+
|
1053 |
+
if add_upsample:
|
1054 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1055 |
+
else:
|
1056 |
+
self.upsamplers = None
|
1057 |
+
|
1058 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
1059 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1060 |
+
# pop res hidden states
|
1061 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1062 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1063 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1064 |
+
|
1065 |
+
hidden_states = resnet(hidden_states, temb)
|
1066 |
+
hidden_states = attn(hidden_states)
|
1067 |
+
|
1068 |
+
if self.upsamplers is not None:
|
1069 |
+
for upsampler in self.upsamplers:
|
1070 |
+
hidden_states = upsampler(hidden_states)
|
1071 |
+
|
1072 |
+
return hidden_states
|
1073 |
+
|
1074 |
+
|
1075 |
+
class CrossAttnUpBlock2D(nn.Module):
|
1076 |
+
def __init__(
|
1077 |
+
self,
|
1078 |
+
in_channels: int,
|
1079 |
+
out_channels: int,
|
1080 |
+
prev_output_channel: int,
|
1081 |
+
temb_channels: int,
|
1082 |
+
dropout: float = 0.0,
|
1083 |
+
num_layers: int = 1,
|
1084 |
+
resnet_eps: float = 1e-6,
|
1085 |
+
resnet_time_scale_shift: str = "default",
|
1086 |
+
resnet_act_fn: str = "swish",
|
1087 |
+
resnet_groups: int = 32,
|
1088 |
+
resnet_pre_norm: bool = True,
|
1089 |
+
attn_num_head_channels=1,
|
1090 |
+
cross_attention_dim=1280,
|
1091 |
+
attention_type="default",
|
1092 |
+
output_scale_factor=1.0,
|
1093 |
+
add_upsample=True,
|
1094 |
+
):
|
1095 |
+
super().__init__()
|
1096 |
+
resnets = []
|
1097 |
+
attentions = []
|
1098 |
+
|
1099 |
+
self.attention_type = attention_type
|
1100 |
+
self.attn_num_head_channels = attn_num_head_channels
|
1101 |
+
|
1102 |
+
for i in range(num_layers):
|
1103 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1104 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1105 |
+
|
1106 |
+
resnets.append(
|
1107 |
+
ResnetBlock2D(
|
1108 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1109 |
+
out_channels=out_channels,
|
1110 |
+
temb_channels=temb_channels,
|
1111 |
+
eps=resnet_eps,
|
1112 |
+
groups=resnet_groups,
|
1113 |
+
dropout=dropout,
|
1114 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1115 |
+
non_linearity=resnet_act_fn,
|
1116 |
+
output_scale_factor=output_scale_factor,
|
1117 |
+
pre_norm=resnet_pre_norm,
|
1118 |
+
)
|
1119 |
+
)
|
1120 |
+
attentions.append(
|
1121 |
+
Transformer2DModel(
|
1122 |
+
attn_num_head_channels,
|
1123 |
+
out_channels // attn_num_head_channels,
|
1124 |
+
in_channels=out_channels,
|
1125 |
+
num_layers=1,
|
1126 |
+
cross_attention_dim=cross_attention_dim,
|
1127 |
+
norm_num_groups=resnet_groups,
|
1128 |
+
)
|
1129 |
+
)
|
1130 |
+
self.attentions = nn.ModuleList(attentions)
|
1131 |
+
self.resnets = nn.ModuleList(resnets)
|
1132 |
+
|
1133 |
+
if add_upsample:
|
1134 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1135 |
+
else:
|
1136 |
+
self.upsamplers = None
|
1137 |
+
|
1138 |
+
self.gradient_checkpointing = False
|
1139 |
+
|
1140 |
+
def set_attention_slice(self, slice_size):
|
1141 |
+
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
1142 |
+
raise ValueError(
|
1143 |
+
f"Make sure slice_size {slice_size} is a divisor of "
|
1144 |
+
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
1145 |
+
)
|
1146 |
+
if slice_size is not None and slice_size > self.attn_num_head_channels:
|
1147 |
+
raise ValueError(
|
1148 |
+
f"Chunk_size {slice_size} has to be smaller or equal to "
|
1149 |
+
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
1150 |
+
)
|
1151 |
+
|
1152 |
+
for attn in self.attentions:
|
1153 |
+
attn._set_attention_slice(slice_size)
|
1154 |
+
|
1155 |
+
self.gradient_checkpointing = False
|
1156 |
+
|
1157 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
1158 |
+
for attn in self.attentions:
|
1159 |
+
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
1160 |
+
|
1161 |
+
def forward(
|
1162 |
+
self,
|
1163 |
+
hidden_states,
|
1164 |
+
res_hidden_states_tuple,
|
1165 |
+
temb=None,
|
1166 |
+
encoder_hidden_states=None,
|
1167 |
+
encoder_attention_mask=None,
|
1168 |
+
upsample_size=None,
|
1169 |
+
):
|
1170 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1171 |
+
# pop res hidden states
|
1172 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1173 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1174 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1175 |
+
|
1176 |
+
if self.training and self.gradient_checkpointing:
|
1177 |
+
|
1178 |
+
def create_custom_forward(module, return_dict=None):
|
1179 |
+
def custom_forward(*inputs):
|
1180 |
+
if return_dict is not None:
|
1181 |
+
return module(*inputs, return_dict=return_dict)
|
1182 |
+
else:
|
1183 |
+
return module(*inputs)
|
1184 |
+
|
1185 |
+
return custom_forward
|
1186 |
+
|
1187 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1188 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1189 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states,
|
1190 |
+
encoder_attention_mask
|
1191 |
+
)[0]
|
1192 |
+
else:
|
1193 |
+
hidden_states = resnet(hidden_states, temb)
|
1194 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
|
1195 |
+
encoder_attention_mask=encoder_attention_mask).sample
|
1196 |
+
|
1197 |
+
if self.upsamplers is not None:
|
1198 |
+
for upsampler in self.upsamplers:
|
1199 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1200 |
+
|
1201 |
+
return hidden_states
|
1202 |
+
|
1203 |
+
|
1204 |
+
class UpBlock2D(nn.Module):
|
1205 |
+
def __init__(
|
1206 |
+
self,
|
1207 |
+
in_channels: int,
|
1208 |
+
prev_output_channel: int,
|
1209 |
+
out_channels: int,
|
1210 |
+
temb_channels: int,
|
1211 |
+
dropout: float = 0.0,
|
1212 |
+
num_layers: int = 1,
|
1213 |
+
resnet_eps: float = 1e-6,
|
1214 |
+
resnet_time_scale_shift: str = "default",
|
1215 |
+
resnet_act_fn: str = "swish",
|
1216 |
+
resnet_groups: int = 32,
|
1217 |
+
resnet_pre_norm: bool = True,
|
1218 |
+
output_scale_factor=1.0,
|
1219 |
+
add_upsample=True,
|
1220 |
+
):
|
1221 |
+
super().__init__()
|
1222 |
+
resnets = []
|
1223 |
+
|
1224 |
+
for i in range(num_layers):
|
1225 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1226 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1227 |
+
|
1228 |
+
resnets.append(
|
1229 |
+
ResnetBlock2D(
|
1230 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1231 |
+
out_channels=out_channels,
|
1232 |
+
temb_channels=temb_channels,
|
1233 |
+
eps=resnet_eps,
|
1234 |
+
groups=resnet_groups,
|
1235 |
+
dropout=dropout,
|
1236 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1237 |
+
non_linearity=resnet_act_fn,
|
1238 |
+
output_scale_factor=output_scale_factor,
|
1239 |
+
pre_norm=resnet_pre_norm,
|
1240 |
+
)
|
1241 |
+
)
|
1242 |
+
|
1243 |
+
self.resnets = nn.ModuleList(resnets)
|
1244 |
+
|
1245 |
+
if add_upsample:
|
1246 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1247 |
+
else:
|
1248 |
+
self.upsamplers = None
|
1249 |
+
|
1250 |
+
self.gradient_checkpointing = False
|
1251 |
+
|
1252 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
1253 |
+
for resnet in self.resnets:
|
1254 |
+
# pop res hidden states
|
1255 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1256 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1257 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1258 |
+
|
1259 |
+
if self.training and self.gradient_checkpointing:
|
1260 |
+
|
1261 |
+
def create_custom_forward(module):
|
1262 |
+
def custom_forward(*inputs):
|
1263 |
+
return module(*inputs)
|
1264 |
+
|
1265 |
+
return custom_forward
|
1266 |
+
|
1267 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1268 |
+
else:
|
1269 |
+
hidden_states = resnet(hidden_states, temb)
|
1270 |
+
|
1271 |
+
if self.upsamplers is not None:
|
1272 |
+
for upsampler in self.upsamplers:
|
1273 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1274 |
+
|
1275 |
+
return hidden_states
|
1276 |
+
|
1277 |
+
|
1278 |
+
class UpDecoderBlock2D(nn.Module):
|
1279 |
+
def __init__(
|
1280 |
+
self,
|
1281 |
+
in_channels: int,
|
1282 |
+
out_channels: int,
|
1283 |
+
dropout: float = 0.0,
|
1284 |
+
num_layers: int = 1,
|
1285 |
+
resnet_eps: float = 1e-6,
|
1286 |
+
resnet_time_scale_shift: str = "default",
|
1287 |
+
resnet_act_fn: str = "swish",
|
1288 |
+
resnet_groups: int = 32,
|
1289 |
+
resnet_pre_norm: bool = True,
|
1290 |
+
output_scale_factor=1.0,
|
1291 |
+
add_upsample=True,
|
1292 |
+
):
|
1293 |
+
super().__init__()
|
1294 |
+
resnets = []
|
1295 |
+
|
1296 |
+
for i in range(num_layers):
|
1297 |
+
input_channels = in_channels if i == 0 else out_channels
|
1298 |
+
|
1299 |
+
resnets.append(
|
1300 |
+
ResnetBlock2D(
|
1301 |
+
in_channels=input_channels,
|
1302 |
+
out_channels=out_channels,
|
1303 |
+
temb_channels=None,
|
1304 |
+
eps=resnet_eps,
|
1305 |
+
groups=resnet_groups,
|
1306 |
+
dropout=dropout,
|
1307 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1308 |
+
non_linearity=resnet_act_fn,
|
1309 |
+
output_scale_factor=output_scale_factor,
|
1310 |
+
pre_norm=resnet_pre_norm,
|
1311 |
+
)
|
1312 |
+
)
|
1313 |
+
|
1314 |
+
self.resnets = nn.ModuleList(resnets)
|
1315 |
+
|
1316 |
+
if add_upsample:
|
1317 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1318 |
+
else:
|
1319 |
+
self.upsamplers = None
|
1320 |
+
|
1321 |
+
def forward(self, hidden_states):
|
1322 |
+
for resnet in self.resnets:
|
1323 |
+
hidden_states = resnet(hidden_states, temb=None)
|
1324 |
+
|
1325 |
+
if self.upsamplers is not None:
|
1326 |
+
for upsampler in self.upsamplers:
|
1327 |
+
hidden_states = upsampler(hidden_states)
|
1328 |
+
|
1329 |
+
return hidden_states
|
1330 |
+
|
1331 |
+
|
1332 |
+
class AttnUpDecoderBlock2D(nn.Module):
|
1333 |
+
def __init__(
|
1334 |
+
self,
|
1335 |
+
in_channels: int,
|
1336 |
+
out_channels: int,
|
1337 |
+
dropout: float = 0.0,
|
1338 |
+
num_layers: int = 1,
|
1339 |
+
resnet_eps: float = 1e-6,
|
1340 |
+
resnet_time_scale_shift: str = "default",
|
1341 |
+
resnet_act_fn: str = "swish",
|
1342 |
+
resnet_groups: int = 32,
|
1343 |
+
resnet_pre_norm: bool = True,
|
1344 |
+
attn_num_head_channels=1,
|
1345 |
+
output_scale_factor=1.0,
|
1346 |
+
add_upsample=True,
|
1347 |
+
):
|
1348 |
+
super().__init__()
|
1349 |
+
resnets = []
|
1350 |
+
attentions = []
|
1351 |
+
|
1352 |
+
for i in range(num_layers):
|
1353 |
+
input_channels = in_channels if i == 0 else out_channels
|
1354 |
+
|
1355 |
+
resnets.append(
|
1356 |
+
ResnetBlock2D(
|
1357 |
+
in_channels=input_channels,
|
1358 |
+
out_channels=out_channels,
|
1359 |
+
temb_channels=None,
|
1360 |
+
eps=resnet_eps,
|
1361 |
+
groups=resnet_groups,
|
1362 |
+
dropout=dropout,
|
1363 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1364 |
+
non_linearity=resnet_act_fn,
|
1365 |
+
output_scale_factor=output_scale_factor,
|
1366 |
+
pre_norm=resnet_pre_norm,
|
1367 |
+
)
|
1368 |
+
)
|
1369 |
+
attentions.append(
|
1370 |
+
AttentionBlock(
|
1371 |
+
out_channels,
|
1372 |
+
num_head_channels=attn_num_head_channels,
|
1373 |
+
rescale_output_factor=output_scale_factor,
|
1374 |
+
eps=resnet_eps,
|
1375 |
+
norm_num_groups=resnet_groups,
|
1376 |
+
)
|
1377 |
+
)
|
1378 |
+
|
1379 |
+
self.attentions = nn.ModuleList(attentions)
|
1380 |
+
self.resnets = nn.ModuleList(resnets)
|
1381 |
+
|
1382 |
+
if add_upsample:
|
1383 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1384 |
+
else:
|
1385 |
+
self.upsamplers = None
|
1386 |
+
|
1387 |
+
def forward(self, hidden_states):
|
1388 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1389 |
+
hidden_states = resnet(hidden_states, temb=None)
|
1390 |
+
hidden_states = attn(hidden_states)
|
1391 |
+
|
1392 |
+
if self.upsamplers is not None:
|
1393 |
+
for upsampler in self.upsamplers:
|
1394 |
+
hidden_states = upsampler(hidden_states)
|
1395 |
+
|
1396 |
+
return hidden_states
|
1397 |
+
|
1398 |
+
|
1399 |
+
class AttnSkipUpBlock2D(nn.Module):
|
1400 |
+
def __init__(
|
1401 |
+
self,
|
1402 |
+
in_channels: int,
|
1403 |
+
prev_output_channel: int,
|
1404 |
+
out_channels: int,
|
1405 |
+
temb_channels: int,
|
1406 |
+
dropout: float = 0.0,
|
1407 |
+
num_layers: int = 1,
|
1408 |
+
resnet_eps: float = 1e-6,
|
1409 |
+
resnet_time_scale_shift: str = "default",
|
1410 |
+
resnet_act_fn: str = "swish",
|
1411 |
+
resnet_pre_norm: bool = True,
|
1412 |
+
attn_num_head_channels=1,
|
1413 |
+
attention_type="default",
|
1414 |
+
output_scale_factor=np.sqrt(2.0),
|
1415 |
+
upsample_padding=1,
|
1416 |
+
add_upsample=True,
|
1417 |
+
):
|
1418 |
+
super().__init__()
|
1419 |
+
self.attentions = nn.ModuleList([])
|
1420 |
+
self.resnets = nn.ModuleList([])
|
1421 |
+
|
1422 |
+
self.attention_type = attention_type
|
1423 |
+
|
1424 |
+
for i in range(num_layers):
|
1425 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1426 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1427 |
+
|
1428 |
+
self.resnets.append(
|
1429 |
+
ResnetBlock2D(
|
1430 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1431 |
+
out_channels=out_channels,
|
1432 |
+
temb_channels=temb_channels,
|
1433 |
+
eps=resnet_eps,
|
1434 |
+
groups=min(resnet_in_channels + res_skip_channels // 4, 32),
|
1435 |
+
groups_out=min(out_channels // 4, 32),
|
1436 |
+
dropout=dropout,
|
1437 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1438 |
+
non_linearity=resnet_act_fn,
|
1439 |
+
output_scale_factor=output_scale_factor,
|
1440 |
+
pre_norm=resnet_pre_norm,
|
1441 |
+
)
|
1442 |
+
)
|
1443 |
+
|
1444 |
+
self.attentions.append(
|
1445 |
+
AttentionBlock(
|
1446 |
+
out_channels,
|
1447 |
+
num_head_channels=attn_num_head_channels,
|
1448 |
+
rescale_output_factor=output_scale_factor,
|
1449 |
+
eps=resnet_eps,
|
1450 |
+
)
|
1451 |
+
)
|
1452 |
+
|
1453 |
+
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
|
1454 |
+
if add_upsample:
|
1455 |
+
self.resnet_up = ResnetBlock2D(
|
1456 |
+
in_channels=out_channels,
|
1457 |
+
out_channels=out_channels,
|
1458 |
+
temb_channels=temb_channels,
|
1459 |
+
eps=resnet_eps,
|
1460 |
+
groups=min(out_channels // 4, 32),
|
1461 |
+
groups_out=min(out_channels // 4, 32),
|
1462 |
+
dropout=dropout,
|
1463 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1464 |
+
non_linearity=resnet_act_fn,
|
1465 |
+
output_scale_factor=output_scale_factor,
|
1466 |
+
pre_norm=resnet_pre_norm,
|
1467 |
+
use_in_shortcut=True,
|
1468 |
+
up=True,
|
1469 |
+
kernel="fir",
|
1470 |
+
)
|
1471 |
+
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
1472 |
+
self.skip_norm = torch.nn.GroupNorm(
|
1473 |
+
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
|
1474 |
+
)
|
1475 |
+
self.act = nn.SiLU()
|
1476 |
+
else:
|
1477 |
+
self.resnet_up = None
|
1478 |
+
self.skip_conv = None
|
1479 |
+
self.skip_norm = None
|
1480 |
+
self.act = None
|
1481 |
+
|
1482 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
|
1483 |
+
for resnet in self.resnets:
|
1484 |
+
# pop res hidden states
|
1485 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1486 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1487 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1488 |
+
|
1489 |
+
hidden_states = resnet(hidden_states, temb)
|
1490 |
+
|
1491 |
+
hidden_states = self.attentions[0](hidden_states)
|
1492 |
+
|
1493 |
+
if skip_sample is not None:
|
1494 |
+
skip_sample = self.upsampler(skip_sample)
|
1495 |
+
else:
|
1496 |
+
skip_sample = 0
|
1497 |
+
|
1498 |
+
if self.resnet_up is not None:
|
1499 |
+
skip_sample_states = self.skip_norm(hidden_states)
|
1500 |
+
skip_sample_states = self.act(skip_sample_states)
|
1501 |
+
skip_sample_states = self.skip_conv(skip_sample_states)
|
1502 |
+
|
1503 |
+
skip_sample = skip_sample + skip_sample_states
|
1504 |
+
|
1505 |
+
hidden_states = self.resnet_up(hidden_states, temb)
|
1506 |
+
|
1507 |
+
return hidden_states, skip_sample
|
1508 |
+
|
1509 |
+
|
1510 |
+
class SkipUpBlock2D(nn.Module):
|
1511 |
+
def __init__(
|
1512 |
+
self,
|
1513 |
+
in_channels: int,
|
1514 |
+
prev_output_channel: int,
|
1515 |
+
out_channels: int,
|
1516 |
+
temb_channels: int,
|
1517 |
+
dropout: float = 0.0,
|
1518 |
+
num_layers: int = 1,
|
1519 |
+
resnet_eps: float = 1e-6,
|
1520 |
+
resnet_time_scale_shift: str = "default",
|
1521 |
+
resnet_act_fn: str = "swish",
|
1522 |
+
resnet_pre_norm: bool = True,
|
1523 |
+
output_scale_factor=np.sqrt(2.0),
|
1524 |
+
add_upsample=True,
|
1525 |
+
upsample_padding=1,
|
1526 |
+
):
|
1527 |
+
super().__init__()
|
1528 |
+
self.resnets = nn.ModuleList([])
|
1529 |
+
|
1530 |
+
for i in range(num_layers):
|
1531 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1532 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1533 |
+
|
1534 |
+
self.resnets.append(
|
1535 |
+
ResnetBlock2D(
|
1536 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1537 |
+
out_channels=out_channels,
|
1538 |
+
temb_channels=temb_channels,
|
1539 |
+
eps=resnet_eps,
|
1540 |
+
groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
|
1541 |
+
groups_out=min(out_channels // 4, 32),
|
1542 |
+
dropout=dropout,
|
1543 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1544 |
+
non_linearity=resnet_act_fn,
|
1545 |
+
output_scale_factor=output_scale_factor,
|
1546 |
+
pre_norm=resnet_pre_norm,
|
1547 |
+
)
|
1548 |
+
)
|
1549 |
+
|
1550 |
+
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
|
1551 |
+
if add_upsample:
|
1552 |
+
self.resnet_up = ResnetBlock2D(
|
1553 |
+
in_channels=out_channels,
|
1554 |
+
out_channels=out_channels,
|
1555 |
+
temb_channels=temb_channels,
|
1556 |
+
eps=resnet_eps,
|
1557 |
+
groups=min(out_channels // 4, 32),
|
1558 |
+
groups_out=min(out_channels // 4, 32),
|
1559 |
+
dropout=dropout,
|
1560 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1561 |
+
non_linearity=resnet_act_fn,
|
1562 |
+
output_scale_factor=output_scale_factor,
|
1563 |
+
pre_norm=resnet_pre_norm,
|
1564 |
+
use_in_shortcut=True,
|
1565 |
+
up=True,
|
1566 |
+
kernel="fir",
|
1567 |
+
)
|
1568 |
+
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
1569 |
+
self.skip_norm = torch.nn.GroupNorm(
|
1570 |
+
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
|
1571 |
+
)
|
1572 |
+
self.act = nn.SiLU()
|
1573 |
+
else:
|
1574 |
+
self.resnet_up = None
|
1575 |
+
self.skip_conv = None
|
1576 |
+
self.skip_norm = None
|
1577 |
+
self.act = None
|
1578 |
+
|
1579 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
|
1580 |
+
for resnet in self.resnets:
|
1581 |
+
# pop res hidden states
|
1582 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1583 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1584 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1585 |
+
|
1586 |
+
hidden_states = resnet(hidden_states, temb)
|
1587 |
+
|
1588 |
+
if skip_sample is not None:
|
1589 |
+
skip_sample = self.upsampler(skip_sample)
|
1590 |
+
else:
|
1591 |
+
skip_sample = 0
|
1592 |
+
|
1593 |
+
if self.resnet_up is not None:
|
1594 |
+
skip_sample_states = self.skip_norm(hidden_states)
|
1595 |
+
skip_sample_states = self.act(skip_sample_states)
|
1596 |
+
skip_sample_states = self.skip_conv(skip_sample_states)
|
1597 |
+
|
1598 |
+
skip_sample = skip_sample + skip_sample_states
|
1599 |
+
|
1600 |
+
hidden_states = self.resnet_up(hidden_states, temb)
|
1601 |
+
|
1602 |
+
return hidden_states, skip_sample
|
models/diffusers_override/unet_2d_condition.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.modeling_utils import ModelMixin
|
23 |
+
from diffusers.utils import BaseOutput, logging
|
24 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
25 |
+
from .unet_2d_blocks import (
|
26 |
+
CrossAttnDownBlock2D,
|
27 |
+
CrossAttnUpBlock2D,
|
28 |
+
DownBlock2D,
|
29 |
+
UNetMidBlock2DCrossAttn,
|
30 |
+
UpBlock2D,
|
31 |
+
get_down_block,
|
32 |
+
get_up_block,
|
33 |
+
)
|
34 |
+
|
35 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass
|
39 |
+
class UNet2DConditionOutput(BaseOutput):
|
40 |
+
"""
|
41 |
+
Args:
|
42 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
43 |
+
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
44 |
+
"""
|
45 |
+
|
46 |
+
sample: torch.FloatTensor
|
47 |
+
|
48 |
+
|
49 |
+
class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
50 |
+
r"""
|
51 |
+
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
|
52 |
+
and returns sample shaped output.
|
53 |
+
|
54 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
55 |
+
implements for all the models (such as downloading or saving, etc.)
|
56 |
+
|
57 |
+
Parameters:
|
58 |
+
sample_size (`int`, *optional*): The size of the input sample.
|
59 |
+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
60 |
+
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
61 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
62 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
63 |
+
Whether to flip the sin to cos in the time embedding.
|
64 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
65 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
66 |
+
The tuple of downsample blocks to use.
|
67 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
|
68 |
+
The tuple of upsample blocks to use.
|
69 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
70 |
+
The tuple of output channels for each block.
|
71 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
72 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
73 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
74 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
75 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
76 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
77 |
+
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
|
78 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
79 |
+
"""
|
80 |
+
|
81 |
+
_supports_gradient_checkpointing = True
|
82 |
+
|
83 |
+
@register_to_config
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
sample_size: Optional[int] = None,
|
87 |
+
in_channels: int = 4,
|
88 |
+
out_channels: int = 4,
|
89 |
+
center_input_sample: bool = False,
|
90 |
+
flip_sin_to_cos: bool = True,
|
91 |
+
freq_shift: int = 0,
|
92 |
+
down_block_types: Tuple[str] = (
|
93 |
+
"CrossAttnDownBlock2D",
|
94 |
+
"CrossAttnDownBlock2D",
|
95 |
+
"CrossAttnDownBlock2D",
|
96 |
+
"DownBlock2D",
|
97 |
+
),
|
98 |
+
up_block_types: Tuple[str] = (
|
99 |
+
"UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
100 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
101 |
+
layers_per_block: int = 2,
|
102 |
+
downsample_padding: int = 1,
|
103 |
+
mid_block_scale_factor: float = 1,
|
104 |
+
act_fn: str = "silu",
|
105 |
+
norm_num_groups: int = 32,
|
106 |
+
norm_eps: float = 1e-5,
|
107 |
+
cross_attention_dim: int = 1280,
|
108 |
+
attention_head_dim: int = 8,
|
109 |
+
):
|
110 |
+
super().__init__()
|
111 |
+
|
112 |
+
self.sample_size = sample_size
|
113 |
+
time_embed_dim = block_out_channels[0] * 4
|
114 |
+
|
115 |
+
# input
|
116 |
+
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
117 |
+
|
118 |
+
# time
|
119 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
120 |
+
timestep_input_dim = block_out_channels[0]
|
121 |
+
|
122 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
123 |
+
|
124 |
+
self.down_blocks = nn.ModuleList([])
|
125 |
+
self.mid_block = None
|
126 |
+
self.up_blocks = nn.ModuleList([])
|
127 |
+
|
128 |
+
# down
|
129 |
+
output_channel = block_out_channels[0]
|
130 |
+
for i, down_block_type in enumerate(down_block_types):
|
131 |
+
input_channel = output_channel
|
132 |
+
output_channel = block_out_channels[i]
|
133 |
+
is_final_block = i == len(block_out_channels) - 1
|
134 |
+
|
135 |
+
down_block = get_down_block(
|
136 |
+
down_block_type,
|
137 |
+
num_layers=layers_per_block,
|
138 |
+
in_channels=input_channel,
|
139 |
+
out_channels=output_channel,
|
140 |
+
temb_channels=time_embed_dim,
|
141 |
+
add_downsample=not is_final_block,
|
142 |
+
resnet_eps=norm_eps,
|
143 |
+
resnet_act_fn=act_fn,
|
144 |
+
resnet_groups=norm_num_groups,
|
145 |
+
cross_attention_dim=cross_attention_dim,
|
146 |
+
attn_num_head_channels=attention_head_dim,
|
147 |
+
downsample_padding=downsample_padding,
|
148 |
+
)
|
149 |
+
self.down_blocks.append(down_block)
|
150 |
+
|
151 |
+
# mid
|
152 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
153 |
+
in_channels=block_out_channels[-1],
|
154 |
+
temb_channels=time_embed_dim,
|
155 |
+
resnet_eps=norm_eps,
|
156 |
+
resnet_act_fn=act_fn,
|
157 |
+
output_scale_factor=mid_block_scale_factor,
|
158 |
+
resnet_time_scale_shift="default",
|
159 |
+
cross_attention_dim=cross_attention_dim,
|
160 |
+
attn_num_head_channels=attention_head_dim,
|
161 |
+
resnet_groups=norm_num_groups,
|
162 |
+
)
|
163 |
+
|
164 |
+
# count how many layers upsample the images
|
165 |
+
self.num_upsamplers = 0
|
166 |
+
|
167 |
+
# up
|
168 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
169 |
+
output_channel = reversed_block_out_channels[0]
|
170 |
+
for i, up_block_type in enumerate(up_block_types):
|
171 |
+
is_final_block = i == len(block_out_channels) - 1
|
172 |
+
|
173 |
+
prev_output_channel = output_channel
|
174 |
+
output_channel = reversed_block_out_channels[i]
|
175 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
176 |
+
|
177 |
+
# add upsample block for all BUT final layer
|
178 |
+
if not is_final_block:
|
179 |
+
add_upsample = True
|
180 |
+
self.num_upsamplers += 1
|
181 |
+
else:
|
182 |
+
add_upsample = False
|
183 |
+
|
184 |
+
up_block = get_up_block(
|
185 |
+
up_block_type,
|
186 |
+
num_layers=layers_per_block + 1,
|
187 |
+
in_channels=input_channel,
|
188 |
+
out_channels=output_channel,
|
189 |
+
prev_output_channel=prev_output_channel,
|
190 |
+
temb_channels=time_embed_dim,
|
191 |
+
add_upsample=add_upsample,
|
192 |
+
resnet_eps=norm_eps,
|
193 |
+
resnet_act_fn=act_fn,
|
194 |
+
resnet_groups=norm_num_groups,
|
195 |
+
cross_attention_dim=cross_attention_dim,
|
196 |
+
attn_num_head_channels=attention_head_dim,
|
197 |
+
)
|
198 |
+
self.up_blocks.append(up_block)
|
199 |
+
prev_output_channel = output_channel
|
200 |
+
|
201 |
+
# out
|
202 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
203 |
+
self.conv_act = nn.SiLU()
|
204 |
+
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
205 |
+
|
206 |
+
def set_attention_slice(self, slice_size):
|
207 |
+
if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
|
208 |
+
raise ValueError(
|
209 |
+
f"Make sure slice_size {slice_size} is a divisor of "
|
210 |
+
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
|
211 |
+
)
|
212 |
+
if slice_size is not None and slice_size > self.config.attention_head_dim:
|
213 |
+
raise ValueError(
|
214 |
+
f"Chunk_size {slice_size} has to be smaller or equal to "
|
215 |
+
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
|
216 |
+
)
|
217 |
+
|
218 |
+
for block in self.down_blocks:
|
219 |
+
if hasattr(block, "attentions") and block.attentions is not None:
|
220 |
+
block.set_attention_slice(slice_size)
|
221 |
+
|
222 |
+
self.mid_block.set_attention_slice(slice_size)
|
223 |
+
|
224 |
+
for block in self.up_blocks:
|
225 |
+
if hasattr(block, "attentions") and block.attentions is not None:
|
226 |
+
block.set_attention_slice(slice_size)
|
227 |
+
|
228 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
229 |
+
for block in self.down_blocks:
|
230 |
+
if hasattr(block, "attentions") and block.attentions is not None:
|
231 |
+
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
232 |
+
|
233 |
+
self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
234 |
+
|
235 |
+
for block in self.up_blocks:
|
236 |
+
if hasattr(block, "attentions") and block.attentions is not None:
|
237 |
+
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
238 |
+
|
239 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
240 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
|
241 |
+
module.gradient_checkpointing = value
|
242 |
+
|
243 |
+
def forward(
|
244 |
+
self,
|
245 |
+
sample: torch.FloatTensor,
|
246 |
+
timestep: Union[torch.Tensor, float, int],
|
247 |
+
encoder_hidden_states: torch.Tensor,
|
248 |
+
encoder_attention_mask: torch.Tensor,
|
249 |
+
return_dict: bool = True,
|
250 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
251 |
+
r"""
|
252 |
+
Args:
|
253 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
254 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
255 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
256 |
+
(batch_size, sequence_length, hidden_size) encoder hidden states
|
257 |
+
encoder_attention_mask (`torch.FloatTensor`):
|
258 |
+
(batch_size, sequence_length) encoder attention mask
|
259 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
260 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
264 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
265 |
+
returning a tuple, the first element is the sample tensor.
|
266 |
+
"""
|
267 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
268 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
269 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
270 |
+
# on the fly if necessary.
|
271 |
+
default_overall_up_factor = 2 ** self.num_upsamplers
|
272 |
+
|
273 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
274 |
+
forward_upsample_size = False
|
275 |
+
upsample_size = None
|
276 |
+
|
277 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
278 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
279 |
+
forward_upsample_size = True
|
280 |
+
|
281 |
+
# 0. center input if necessary
|
282 |
+
if self.config.center_input_sample:
|
283 |
+
sample = 2 * sample - 1.0
|
284 |
+
|
285 |
+
# 1. time
|
286 |
+
timesteps = timestep
|
287 |
+
if not torch.is_tensor(timesteps):
|
288 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
289 |
+
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
290 |
+
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
291 |
+
timesteps = timesteps[None].to(sample.device)
|
292 |
+
|
293 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
294 |
+
timesteps = timesteps.expand(sample.shape[0])
|
295 |
+
|
296 |
+
t_emb = self.time_proj(timesteps)
|
297 |
+
|
298 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
299 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
300 |
+
# there might be better ways to encapsulate this.
|
301 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
302 |
+
emb = self.time_embedding(t_emb)
|
303 |
+
|
304 |
+
# 2. pre-process
|
305 |
+
sample = self.conv_in(sample)
|
306 |
+
|
307 |
+
# 3. down
|
308 |
+
down_block_res_samples = (sample,)
|
309 |
+
for downsample_block in self.down_blocks:
|
310 |
+
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
|
311 |
+
sample, res_samples = downsample_block(
|
312 |
+
hidden_states=sample,
|
313 |
+
temb=emb,
|
314 |
+
encoder_hidden_states=encoder_hidden_states,
|
315 |
+
encoder_attention_mask=encoder_attention_mask,
|
316 |
+
)
|
317 |
+
else:
|
318 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
319 |
+
|
320 |
+
down_block_res_samples += res_samples
|
321 |
+
|
322 |
+
# 4. mid
|
323 |
+
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states,
|
324 |
+
encoder_attention_mask=encoder_attention_mask)
|
325 |
+
|
326 |
+
# 5. up
|
327 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
328 |
+
is_final_block = i == len(self.up_blocks) - 1
|
329 |
+
|
330 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
331 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
332 |
+
|
333 |
+
# if we have not reached the final block and need to forward the
|
334 |
+
# upsample size, we do it here
|
335 |
+
if not is_final_block and forward_upsample_size:
|
336 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
337 |
+
|
338 |
+
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
|
339 |
+
sample = upsample_block(
|
340 |
+
hidden_states=sample,
|
341 |
+
temb=emb,
|
342 |
+
res_hidden_states_tuple=res_samples,
|
343 |
+
encoder_hidden_states=encoder_hidden_states,
|
344 |
+
encoder_attention_mask=encoder_attention_mask,
|
345 |
+
upsample_size=upsample_size,
|
346 |
+
)
|
347 |
+
else:
|
348 |
+
sample = upsample_block(
|
349 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
350 |
+
)
|
351 |
+
# 6. post-process
|
352 |
+
sample = self.conv_norm_out(sample)
|
353 |
+
sample = self.conv_act(sample)
|
354 |
+
sample = self.conv_out(sample)
|
355 |
+
|
356 |
+
if not return_dict:
|
357 |
+
return (sample,)
|
358 |
+
|
359 |
+
return UNet2DConditionOutput(sample=sample)
|
models/inception.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchvision import models
|
5 |
+
|
6 |
+
try:
|
7 |
+
from torchvision.models.utils import load_state_dict_from_url
|
8 |
+
except ImportError:
|
9 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
10 |
+
|
11 |
+
# Inception weights ported to Pytorch from
|
12 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
13 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
|
14 |
+
|
15 |
+
|
16 |
+
class InceptionV3(nn.Module):
|
17 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
18 |
+
|
19 |
+
# Index of default block of inception to return,
|
20 |
+
# corresponds to output of final average pooling
|
21 |
+
DEFAULT_BLOCK_INDEX = 3
|
22 |
+
|
23 |
+
# Maps feature dimensionality to their output blocks indices
|
24 |
+
BLOCK_INDEX_BY_DIM = {
|
25 |
+
64: 0, # First max pooling features
|
26 |
+
192: 1, # Second max pooling featurs
|
27 |
+
768: 2, # Pre-aux classifier features
|
28 |
+
2048: 3 # Final average pooling features
|
29 |
+
}
|
30 |
+
|
31 |
+
def __init__(self,
|
32 |
+
output_blocks=[DEFAULT_BLOCK_INDEX],
|
33 |
+
resize_input=True,
|
34 |
+
normalize_input=True,
|
35 |
+
requires_grad=False,
|
36 |
+
use_fid_inception=True):
|
37 |
+
"""Build pretrained InceptionV3
|
38 |
+
|
39 |
+
Parameters
|
40 |
+
----------
|
41 |
+
output_blocks : list of int
|
42 |
+
Indices of blocks to return features of. Possible values are:
|
43 |
+
- 0: corresponds to output of first max pooling
|
44 |
+
- 1: corresponds to output of second max pooling
|
45 |
+
- 2: corresponds to output which is fed to aux classifier
|
46 |
+
- 3: corresponds to output of final average pooling
|
47 |
+
resize_input : bool
|
48 |
+
If true, bilinearly resizes input to width and height 299 before
|
49 |
+
feeding input to model. As the network without fully connected
|
50 |
+
layers is fully convolutional, it should be able to handle inputs
|
51 |
+
of arbitrary size, so resizing might not be strictly needed
|
52 |
+
normalize_input : bool
|
53 |
+
If true, scales the input from range (0, 1) to the range the
|
54 |
+
pretrained Inception network expects, namely (-1, 1)
|
55 |
+
requires_grad : bool
|
56 |
+
If true, parameters of the model require gradients. Possibly useful
|
57 |
+
for finetuning the network
|
58 |
+
use_fid_inception : bool
|
59 |
+
If true, uses the pretrained Inception model used in Tensorflow's
|
60 |
+
FID implementation. If false, uses the pretrained Inception model
|
61 |
+
available in torchvision. The FID Inception model has different
|
62 |
+
weights and a slightly different structure from torchvision's
|
63 |
+
Inception model. If you want to compute FID scores, you are
|
64 |
+
strongly advised to set this parameter to true to get comparable
|
65 |
+
results.
|
66 |
+
"""
|
67 |
+
super(InceptionV3, self).__init__()
|
68 |
+
|
69 |
+
self.resize_input = resize_input
|
70 |
+
self.normalize_input = normalize_input
|
71 |
+
self.output_blocks = sorted(output_blocks)
|
72 |
+
self.last_needed_block = max(output_blocks)
|
73 |
+
|
74 |
+
assert self.last_needed_block <= 3, \
|
75 |
+
'Last possible output block index is 3'
|
76 |
+
|
77 |
+
self.blocks = nn.ModuleList()
|
78 |
+
|
79 |
+
if use_fid_inception:
|
80 |
+
inception = fid_inception_v3()
|
81 |
+
else:
|
82 |
+
inception = models.inception_v3(pretrained=True)
|
83 |
+
|
84 |
+
# Block 0: input to maxpool1
|
85 |
+
block0 = [
|
86 |
+
inception.Conv2d_1a_3x3,
|
87 |
+
inception.Conv2d_2a_3x3,
|
88 |
+
inception.Conv2d_2b_3x3,
|
89 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
90 |
+
]
|
91 |
+
self.blocks.append(nn.Sequential(*block0))
|
92 |
+
|
93 |
+
# Block 1: maxpool1 to maxpool2
|
94 |
+
if self.last_needed_block >= 1:
|
95 |
+
block1 = [
|
96 |
+
inception.Conv2d_3b_1x1,
|
97 |
+
inception.Conv2d_4a_3x3,
|
98 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
99 |
+
]
|
100 |
+
self.blocks.append(nn.Sequential(*block1))
|
101 |
+
|
102 |
+
# Block 2: maxpool2 to aux classifier
|
103 |
+
if self.last_needed_block >= 2:
|
104 |
+
block2 = [
|
105 |
+
inception.Mixed_5b,
|
106 |
+
inception.Mixed_5c,
|
107 |
+
inception.Mixed_5d,
|
108 |
+
inception.Mixed_6a,
|
109 |
+
inception.Mixed_6b,
|
110 |
+
inception.Mixed_6c,
|
111 |
+
inception.Mixed_6d,
|
112 |
+
inception.Mixed_6e,
|
113 |
+
]
|
114 |
+
self.blocks.append(nn.Sequential(*block2))
|
115 |
+
|
116 |
+
# Block 3: aux classifier to final avgpool
|
117 |
+
if self.last_needed_block >= 3:
|
118 |
+
block3 = [
|
119 |
+
inception.Mixed_7a,
|
120 |
+
inception.Mixed_7b,
|
121 |
+
inception.Mixed_7c,
|
122 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
123 |
+
]
|
124 |
+
self.blocks.append(nn.Sequential(*block3))
|
125 |
+
|
126 |
+
for param in self.parameters():
|
127 |
+
param.requires_grad = requires_grad
|
128 |
+
|
129 |
+
def forward(self, inp):
|
130 |
+
"""Get Inception feature maps
|
131 |
+
|
132 |
+
Parameters
|
133 |
+
----------
|
134 |
+
inp : torch.autograd.Variable
|
135 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
136 |
+
range (0, 1)
|
137 |
+
|
138 |
+
Returns
|
139 |
+
-------
|
140 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
141 |
+
block, sorted ascending by index
|
142 |
+
"""
|
143 |
+
outp = []
|
144 |
+
x = inp
|
145 |
+
|
146 |
+
if self.resize_input:
|
147 |
+
x = F.interpolate(x,
|
148 |
+
size=(299, 299),
|
149 |
+
mode='bilinear',
|
150 |
+
align_corners=False)
|
151 |
+
|
152 |
+
if self.normalize_input:
|
153 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
154 |
+
|
155 |
+
for idx, block in enumerate(self.blocks):
|
156 |
+
x = block(x)
|
157 |
+
if idx in self.output_blocks:
|
158 |
+
outp.append(x)
|
159 |
+
|
160 |
+
if idx == self.last_needed_block:
|
161 |
+
break
|
162 |
+
|
163 |
+
return outp
|
164 |
+
|
165 |
+
|
166 |
+
def fid_inception_v3():
|
167 |
+
"""Build pretrained Inception model for FID computation
|
168 |
+
|
169 |
+
The Inception model for FID computation uses a different set of weights
|
170 |
+
and has a slightly different structure than torchvision's Inception.
|
171 |
+
|
172 |
+
This method first constructs torchvision's Inception and then patches the
|
173 |
+
necessary parts that are different in the FID Inception model.
|
174 |
+
"""
|
175 |
+
inception = models.inception_v3(num_classes=1008,
|
176 |
+
aux_logits=False,
|
177 |
+
pretrained=False)
|
178 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
179 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
180 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
181 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
182 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
183 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
184 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
185 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
186 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
187 |
+
|
188 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
|
189 |
+
inception.load_state_dict(state_dict)
|
190 |
+
return inception
|
191 |
+
|
192 |
+
|
193 |
+
class FIDInceptionA(models.inception.InceptionA):
|
194 |
+
"""InceptionA block patched for FID computation"""
|
195 |
+
|
196 |
+
def __init__(self, in_channels, pool_features):
|
197 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
198 |
+
|
199 |
+
def forward(self, x):
|
200 |
+
branch1x1 = self.branch1x1(x)
|
201 |
+
|
202 |
+
branch5x5 = self.branch5x5_1(x)
|
203 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
204 |
+
|
205 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
206 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
207 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
208 |
+
|
209 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
210 |
+
# its average calculation
|
211 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
212 |
+
count_include_pad=False)
|
213 |
+
branch_pool = self.branch_pool(branch_pool)
|
214 |
+
|
215 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
216 |
+
return torch.cat(outputs, 1)
|
217 |
+
|
218 |
+
|
219 |
+
class FIDInceptionC(models.inception.InceptionC):
|
220 |
+
"""InceptionC block patched for FID computation"""
|
221 |
+
|
222 |
+
def __init__(self, in_channels, channels_7x7):
|
223 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
224 |
+
|
225 |
+
def forward(self, x):
|
226 |
+
branch1x1 = self.branch1x1(x)
|
227 |
+
|
228 |
+
branch7x7 = self.branch7x7_1(x)
|
229 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
230 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
231 |
+
|
232 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
233 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
234 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
235 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
236 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
237 |
+
|
238 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
239 |
+
# its average calculation
|
240 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
241 |
+
count_include_pad=False)
|
242 |
+
branch_pool = self.branch_pool(branch_pool)
|
243 |
+
|
244 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
245 |
+
return torch.cat(outputs, 1)
|
246 |
+
|
247 |
+
|
248 |
+
class FIDInceptionE_1(models.inception.InceptionE):
|
249 |
+
"""First InceptionE block patched for FID computation"""
|
250 |
+
|
251 |
+
def __init__(self, in_channels):
|
252 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
253 |
+
|
254 |
+
def forward(self, x):
|
255 |
+
branch1x1 = self.branch1x1(x)
|
256 |
+
|
257 |
+
branch3x3 = self.branch3x3_1(x)
|
258 |
+
branch3x3 = [
|
259 |
+
self.branch3x3_2a(branch3x3),
|
260 |
+
self.branch3x3_2b(branch3x3),
|
261 |
+
]
|
262 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
263 |
+
|
264 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
265 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
266 |
+
branch3x3dbl = [
|
267 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
268 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
269 |
+
]
|
270 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
271 |
+
|
272 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
273 |
+
# its average calculation
|
274 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
275 |
+
count_include_pad=False)
|
276 |
+
branch_pool = self.branch_pool(branch_pool)
|
277 |
+
|
278 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
279 |
+
return torch.cat(outputs, 1)
|
280 |
+
|
281 |
+
|
282 |
+
class FIDInceptionE_2(models.inception.InceptionE):
|
283 |
+
"""Second InceptionE block patched for FID computation"""
|
284 |
+
|
285 |
+
def __init__(self, in_channels):
|
286 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
287 |
+
|
288 |
+
def forward(self, x):
|
289 |
+
branch1x1 = self.branch1x1(x)
|
290 |
+
|
291 |
+
branch3x3 = self.branch3x3_1(x)
|
292 |
+
branch3x3 = [
|
293 |
+
self.branch3x3_2a(branch3x3),
|
294 |
+
self.branch3x3_2b(branch3x3),
|
295 |
+
]
|
296 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
297 |
+
|
298 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
299 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
300 |
+
branch3x3dbl = [
|
301 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
302 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
303 |
+
]
|
304 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
305 |
+
|
306 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
307 |
+
# pooling. This is likely an error in this specific Inception
|
308 |
+
# implementation, as other Inception models use average pooling here
|
309 |
+
# (which matches the description in the paper).
|
310 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
311 |
+
branch_pool = self.branch_pool(branch_pool)
|
312 |
+
|
313 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
314 |
+
return torch.cat(outputs, 1)
|
v1-5-pruned-emaonly.ckpt → pororo_100.h5
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6b5d47440de7abbbbb2265e1d5ecbc1c5d4d3188434db3988cb13e7ec5fa7549
|
3 |
+
size 69568
|
readme-storyvisualization.md
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 一、基于叙事文本的跨模态序列图像生成模型
|
2 |
+
|
3 |
+
## 安装环境
|
4 |
+
conda create -n arldm python=3.8
|
5 |
+
conda activate arldm
|
6 |
+
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch-lts
|
7 |
+
cd /root/lihui/StoryVisualization
|
8 |
+
pip install -r requirements.txt
|
9 |
+
## 数据准备
|
10 |
+
Download the PororoSV dataset here.
|
11 |
+
To accelerate I/O, using the following scrips to convert your downloaded data to HDF5
|
12 |
+
python data_script/pororo_hdf5.py
|
13 |
+
--data_dir /path/to/pororo_data
|
14 |
+
--save_path /path/to/save_hdf5_file
|
15 |
+
## 配置文件config.yaml
|
16 |
+
|
17 |
+
#device
|
18 |
+
mode: sample # train sample
|
19 |
+
ckpt_dir: /root/lihui/StoryVisualization/save_ckpt_epoch5_new # checkpoint directory
|
20 |
+
run_name: ARLDM # name for this run
|
21 |
+
|
22 |
+
#train
|
23 |
+
train_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for resume, none for train from scratch
|
24 |
+
|
25 |
+
#sample
|
26 |
+
test_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for test
|
27 |
+
sample_output_dir: /root/lihui/StoryVisualization/save_samples_128_epoch50 # output directory
|
28 |
+
## 训练
|
29 |
+
在 config.yaml 中指定您的目录和设备配置并运行:
|
30 |
+
python main.py
|
31 |
+
## 采样
|
32 |
+
在 config.yaml 中指定您的目录和设备配置并运行:
|
33 |
+
python main.py
|
34 |
+
## 引用
|
35 |
+
@article{pan2022synthesizing,
|
36 |
+
title={Synthesizing Coherent Story with Auto-Regressive Latent Diffusion Models},
|
37 |
+
author={Pan, Xichen and Qin, Pengda and Li, Yuhong and Xue, Hui and Chen, Wenhu},
|
38 |
+
journal={arXiv preprint arXiv:2211.10950},
|
39 |
+
year={2022}
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
+
### 二、基于Real-ESRGAN的超分算法
|
44 |
+
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
|
45 |
+
[论文] [项目主页] [YouTube 视频] [B站视频] [Poster] [PPT]
|
46 |
+
Xintao Wang, Liangbin Xie, Chao Dong, Ying Shan
|
47 |
+
Tencent ARC Lab; Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences
|
48 |
+
## 环境
|
49 |
+
Python >= 3.7 (推荐使用Anaconda或Miniconda)
|
50 |
+
PyTorch >= 1.7
|
51 |
+
## 安装
|
52 |
+
1、直接进入已配好的文件夹
|
53 |
+
cd /root/lihui/StoryVisualization/Real-ESRGAN
|
54 |
+
2、或 把项目克隆到本地
|
55 |
+
bash git clone https://github.com/xinntao/Real-ESRGAN.git cd Real-ESRGAN
|
56 |
+
3、 安装各种依赖
|
57 |
+
```bash
|
58 |
+
安装 basicsr - https://github.com/xinntao/BasicSR
|
59 |
+
#我们使用BasicSR来训练以及推断
|
60 |
+
pip install basicsr
|
61 |
+
#facexlib和gfpgan是用来增强人脸的
|
62 |
+
pip install facexlib pip install gfpgan pip install -r requirements.txt python setup.py develop
|
63 |
+
```
|
64 |
+
## 训练
|
65 |
+
训练好的模型: RealESRGAN_x4plus_anime_6B
|
66 |
+
有关waifu2x的更多信息和对比在anime_model.md中。
|
67 |
+
## 下载模型
|
68 |
+
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P weights
|
69 |
+
## 推断
|
70 |
+
python inference_realesrgan.py -n RealESRGAN_x4plus_anime_6B -i inputs
|
71 |
+
结果在results文件夹
|
72 |
+
## BibTeX 引用
|
73 |
+
@Article{wang2021realesrgan,
|
74 |
+
title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data},
|
75 |
+
author={Xintao Wang and Liangbin Xie and Chao Dong and Ying Shan},
|
76 |
+
journal={arXiv:2107.10833},
|
77 |
+
year={2021}
|
78 |
+
}
|
79 |
+
|
80 |
+
|
81 |
+
### 三、基于YOLOv5的目标角色检测算法
|
82 |
+
## 安装
|
83 |
+
克隆 repo,并要求在 Python>=3.7.0 环境中安装 requirements.txt ,且要求 PyTorch>=1.7 。
|
84 |
+
git clone https://github.com/ultralytics/yolov5 # clone
|
85 |
+
cd /root/lihui/StoryVisualization
|
86 |
+
cd yolov5
|
87 |
+
pip install -r requirements.txt # install
|
88 |
+
## 转换图片
|
89 |
+
cd /root/lihui/StoryVisualization
|
90 |
+
python transtoyolo.py
|
91 |
+
## 使用 detect.py 推理
|
92 |
+
detect.py 在各种来源上运行推理, 模型 自动从 最新的YOLOv5 release 中下载,并将结果保存到 runs/detect 。
|
93 |
+
python detect.py --weights yolov5s.pt --source 0 # webcam
|
94 |
+
img.jpg # image
|
95 |
+
vid.mp4 # video
|
96 |
+
screen # screenshot
|
97 |
+
path/ # directory
|
98 |
+
list.txt # list of images
|
99 |
+
list.streams # list of streams
|
100 |
+
'path/*.jpg' # glob
|
101 |
+
'https://youtu.be/Zgi9g1ksQHc' # YouTube
|
102 |
+
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
|
103 |
+
## 训练
|
104 |
+
最新的 模型 和 数据集 将自动的从 YOLOv5 release 中下载。 YOLOv5n/s/m/l/x 在 V100 GPU 的训练时间为 1/2/4/6/8 天( 多GPU 训练速度更快)。 尽可能使用更大的 --batch-size ,或通过 --batch-size -1 实现 YOLOv5 自动批处理 。下方显示的 batchsize 适用于 V100-16GB。
|
105 |
+
python train.py --data xxx.yaml --epochs 500 --weights '' --cfg yolov5l --batch-size 64
|
106 |
+
# xx.yaml文件为转换后的数据
|
107 |
+
|
108 |
+
## 许可
|
109 |
+
YOLOv5 在两种不同的 License 下可用:
|
110 |
+
AGPL-3.0 License: 查看 License 文件的详细信息。
|
111 |
+
企业License:在没有 AGPL-3.0 开源要求的情况下为商业产品开发提供更大的灵活性。典型用例是将 Ultralytics 软件和 AI 模型嵌入到商业产品和应用程序中。在以下位置申请企业许可证 Ultralytics 许可 。
|
112 |
+
|
113 |
+
|
114 |
+
### 四、演示系统
|
115 |
+
|
116 |
+
## 指定文件目录并运行:
|
117 |
+
cd /root/lihui/StoryVisualization/visualsystem
|
118 |
+
python main.py
|
119 |
+
|
120 |
+
|
121 |
+
#
|
122 |
+
Your identification has been saved in .
|
123 |
+
Your public key has been saved in C:\Users\30254/.ssh/id_ed25519.pub.
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pytorch_lightning<1.7.0
|
2 |
+
lightning-bolts
|
3 |
+
transformers==4.24.0
|
4 |
+
diffusers==0.7.2
|
5 |
+
timm
|
6 |
+
ftfy
|
7 |
+
hydra-core
|
8 |
+
opencv-python
|
9 |
+
h5py
|
10 |
+
scipy
|
run.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python main.py
|
test.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import h5py
|
3 |
+
import copy
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
|
7 |
+
import numpy
|
8 |
+
import numpy as np
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
|
12 |
+
def gettext(index):
|
13 |
+
with h5py.File('/root/lihui/StoryVisualization/pororo.h5', 'r') as h5:
|
14 |
+
story = list()
|
15 |
+
h5 = h5['test']
|
16 |
+
# 读取当前索引处的文本,并使用decode方法将其解码为UTF-8
|
17 |
+
texts = h5['text'][index].decode('utf-8').split('|')
|
18 |
+
symbol = '\n'
|
19 |
+
texts = symbol.join(texts)
|
20 |
+
texts = 'Story<' + str(index) + '> :' + '\n' + texts
|
21 |
+
print(texts)
|
22 |
+
return texts
|
23 |
+
|
24 |
+
|
25 |
+
# for i in range(1000):
|
26 |
+
# gettext(i)
|
27 |
+
|
28 |
+
# 截取前100的数据集
|
29 |
+
# ###正确的##############
|
30 |
+
# # import h5py
|
31 |
+
# # import numpy as np
|
32 |
+
# # from PIL import Image
|
33 |
+
# #
|
34 |
+
# #
|
35 |
+
# # # 创建名为“images”的子目录来保存图像
|
36 |
+
# # os.makedirs("train_images", exist_ok=True)
|
37 |
+
# #
|
38 |
+
# # 创建一个h5文件
|
39 |
+
# nf = h5py.File('/root/lihui/StoryVisualization/pororo_100.h5', "w")
|
40 |
+
# with h5py.File('/root/lihui/StoryVisualization/pororo.h5', 'r') as f:
|
41 |
+
# test_group = f['test']
|
42 |
+
# texts = np.array(test_group['text'][()])
|
43 |
+
# ngroup = nf.create_group('test')
|
44 |
+
# ntext = ngroup.create_dataset('text', (100,), dtype=h5py.string_dtype(encoding='utf-8'))
|
45 |
+
# for i in range(100):
|
46 |
+
# ntext[i]=texts[i]
|
47 |
+
# print(f"样本 {i}:")
|
48 |
+
# # for j in range(5):
|
49 |
+
# # # 创建一个固定的文件名来保存图像
|
50 |
+
# # # filename = os.path.join("images", f"image_{i}_{j}.png")
|
51 |
+
# # # # 将HDF5文件中的图像数据保存到文件中
|
52 |
+
# # # with open(filename, "wb") as img_file:
|
53 |
+
# # # img_file.write(test_group[f'image{j}'][i])
|
54 |
+
# # # 打印文本信息和文件名
|
55 |
+
# # ntext[i]='|'.join(texts[i].decode('utf-8').split('|')[j])
|
56 |
+
# # print(f"图像{j}已保存到文件:{filename}")
|
57 |
+
# print(ntext[i])
|
58 |
+
# nf.close()
|
59 |
+
|
60 |
+
#保存测试集图像,随机截取视频帧
|
61 |
+
with h5py.File(r'C:\Users\zjlab\Desktop\StoryVisualization\pororo.h5', 'r') as h5:
|
62 |
+
h5 = h5['test']
|
63 |
+
|
64 |
+
for index in range(len(h5['text'])): #len(h5['text'])
|
65 |
+
# index = int(index + 1)
|
66 |
+
# print(index)
|
67 |
+
images = list()
|
68 |
+
for i in range(5):
|
69 |
+
# 从h5文件中读取一组图像和对应的文本。
|
70 |
+
im = h5['image{}'.format(i)][index]
|
71 |
+
# print(im)
|
72 |
+
# pil_img = Image.fromarray(im)
|
73 |
+
# # 保存图像
|
74 |
+
# pil_img.save(os.path.join('/root/lihui/StoryVisualization/ori_test_images', '{:04d}.png'.format(i)))
|
75 |
+
# 对每个图像解码
|
76 |
+
im = cv2.imdecode(im, cv2.IMREAD_COLOR)
|
77 |
+
# 随机选择一个128像素的图像切片
|
78 |
+
idx = random.randint(0, im.shape[0] / 128 - 1)
|
79 |
+
# 将切片后的图像加到images列表中
|
80 |
+
images.append(im[idx * 128: (idx + 1) * 128])
|
81 |
+
# 深拷贝,后续不随images变化
|
82 |
+
# ori_images = copy.deepcopy(images)
|
83 |
+
# 保存test原始图像
|
84 |
+
|
85 |
+
# for i, im in enumerate(images):
|
86 |
+
# file_path = 'C:/Users/zjlab/Desktop/StoryVisualization/test_images/group{:02d}_image{:02d}.png'.format(
|
87 |
+
# index + 1,
|
88 |
+
# i + 1)
|
89 |
+
# cv2.imwrite(file_path, im)
|
90 |
+
|
91 |
+
ori_images_pil = Image.fromarray(images[i])#numpy.uint8(images[i].detach().cpu().squeeze().float().numpy())).convert("RGB")
|
92 |
+
ori_images_pil.save(
|
93 |
+
os.path.join('C:/Users/zjlab/Desktop/StoryVisualization/test_images',
|
94 |
+
'group{:02d}_image{:02d}.png'.format(index + 1,i + 1)))
|
transtoyolo.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import json
|
6 |
+
from glob import glob
|
7 |
+
import cv2
|
8 |
+
import shutil
|
9 |
+
import yaml
|
10 |
+
from sklearn.model_selection import train_test_split
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
|
14 |
+
# 获取当前路径
|
15 |
+
ROOT_DIR = os.getcwd()
|
16 |
+
|
17 |
+
'''
|
18 |
+
统一图像格式
|
19 |
+
'''
|
20 |
+
def change_image_format(label_path=ROOT_DIR, suffix='.png'):
|
21 |
+
"""
|
22 |
+
统一当前文件夹下所有图像的格式,如'.jpg'
|
23 |
+
:param suffix: 图像文件后缀
|
24 |
+
:param label_path:当前文件路径
|
25 |
+
:return:
|
26 |
+
"""
|
27 |
+
externs = ['png', 'jpg', 'JPEG', 'BMP', 'bmp']
|
28 |
+
files = list()
|
29 |
+
# 获取尾缀在ecterns中的所有图像
|
30 |
+
for extern in externs:
|
31 |
+
files.extend(glob(label_path + "\\*." + extern))
|
32 |
+
# 遍历所有图像,转换图像格式
|
33 |
+
for file in files:
|
34 |
+
name = ''.join(file.split('.')[:-1])
|
35 |
+
file_suffix = file.split('.')[-1]
|
36 |
+
if file_suffix != suffix.split('.')[-1]:
|
37 |
+
# 重命名为jpg
|
38 |
+
new_name = name + suffix
|
39 |
+
# 读取图像
|
40 |
+
image = cv2.imread(file)
|
41 |
+
# 重新存图为jpg格式
|
42 |
+
cv2.imwrite(new_name, image)
|
43 |
+
# 删除旧图像
|
44 |
+
os.remove(file)
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
'''
|
49 |
+
读取所有json文件,获取所有的类别
|
50 |
+
'''
|
51 |
+
def get_all_class(file_list, label_path=ROOT_DIR):
|
52 |
+
"""
|
53 |
+
从json文件中获取当前数据的所有类别
|
54 |
+
:param file_list:当前路径下的所有文件名
|
55 |
+
:param label_path:当前文件路径
|
56 |
+
:return:
|
57 |
+
"""
|
58 |
+
# 初始化类别列表
|
59 |
+
classes = list()
|
60 |
+
# 遍历所有json,读取shape中的label值内容,添加到classes
|
61 |
+
for filename in tqdm(file_list):
|
62 |
+
json_path = os.path.join(label_path, filename + '.json')
|
63 |
+
json_file = json.load(open(json_path, "r", encoding="utf-8"))
|
64 |
+
for item in json_file["shapes"]:
|
65 |
+
label_class = item['label']
|
66 |
+
if label_class not in classes:
|
67 |
+
classes.append(label_class)
|
68 |
+
print('read file done')
|
69 |
+
return classes
|
70 |
+
|
71 |
+
|
72 |
+
'''
|
73 |
+
划分训练集、验证机、测试集
|
74 |
+
'''
|
75 |
+
def split_dataset(label_path, test_size=0.3, isUseTest=False, useNumpyShuffle=False):
|
76 |
+
"""
|
77 |
+
将文件分为训练集,测试集和验证集
|
78 |
+
:param useNumpyShuffle: 使用numpy方法分割数据集
|
79 |
+
:param test_size: 分割测试集或验证集的比例
|
80 |
+
:param isUseTest: 是否使用测试集,默认为False
|
81 |
+
:param label_path:当前文件路径
|
82 |
+
:return:
|
83 |
+
"""
|
84 |
+
# 获取所有json
|
85 |
+
files = glob(label_path + "\\*.json")
|
86 |
+
files = [i.replace("\\", "/").split("/")[-1].split(".json")[0] for i in files]
|
87 |
+
|
88 |
+
if useNumpyShuffle:
|
89 |
+
file_length = len(files)
|
90 |
+
index = np.arange(file_length)
|
91 |
+
np.random.seed(32)
|
92 |
+
np.random.shuffle(index) # 随机划分
|
93 |
+
|
94 |
+
test_files = None
|
95 |
+
# 是否有测试集
|
96 |
+
if isUseTest:
|
97 |
+
trainval_files, test_files = np.array(files)[index[:int(file_length * (1 - test_size))]], np.array(files)[
|
98 |
+
index[int(file_length * (1 - test_size)):]]
|
99 |
+
else:
|
100 |
+
trainval_files = files
|
101 |
+
# 划分训练集和测试集
|
102 |
+
train_files, val_files = np.array(trainval_files)[index[:int(len(trainval_files) * (1 - test_size))]], \
|
103 |
+
np.array(trainval_files)[index[int(len(trainval_files) * (1 - test_size)):]]
|
104 |
+
else:
|
105 |
+
test_files = None
|
106 |
+
if isUseTest:
|
107 |
+
trainval_files, test_files = train_test_split(files, test_size=test_size, random_state=55)
|
108 |
+
else:
|
109 |
+
trainval_files = files
|
110 |
+
train_files, val_files = train_test_split(trainval_files, test_size=test_size, random_state=55)
|
111 |
+
|
112 |
+
return train_files, val_files, test_files, files
|
113 |
+
|
114 |
+
|
115 |
+
'''
|
116 |
+
生成yolov5的训练、验证、测试集的文件夹
|
117 |
+
'''
|
118 |
+
def create_save_file(label_path=ROOT_DIR):
|
119 |
+
"""
|
120 |
+
按照训练时的图像和标注路径创建文件夹
|
121 |
+
:param label_path:当前文件路径
|
122 |
+
:return:
|
123 |
+
"""
|
124 |
+
# 生成训练集
|
125 |
+
train_image = os.path.join(label_path, 'train', 'images')
|
126 |
+
if not os.path.exists(train_image):
|
127 |
+
os.makedirs(train_image)
|
128 |
+
train_label = os.path.join(label_path, 'train', 'labels')
|
129 |
+
if not os.path.exists(train_label):
|
130 |
+
os.makedirs(train_label)
|
131 |
+
# 生成验证集
|
132 |
+
val_image = os.path.join(label_path, 'valid', 'images')
|
133 |
+
if not os.path.exists(val_image):
|
134 |
+
os.makedirs(val_image)
|
135 |
+
val_label = os.path.join(label_path, 'valid', 'labels')
|
136 |
+
if not os.path.exists(val_label):
|
137 |
+
os.makedirs(val_label)
|
138 |
+
# 生成测试集
|
139 |
+
test_image = os.path.join(label_path, 'test', 'images')
|
140 |
+
if not os.path.exists(test_image):
|
141 |
+
os.makedirs(test_image)
|
142 |
+
test_label = os.path.join(label_path, 'test', 'labels')
|
143 |
+
if not os.path.exists(test_label):
|
144 |
+
os.makedirs(test_label)
|
145 |
+
return train_image, train_label, val_image, val_label, test_image, test_label
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
'''
|
150 |
+
转换,根据图像大小,返回box框的中点和高宽信息
|
151 |
+
'''
|
152 |
+
def convert(size, box):
|
153 |
+
# 宽
|
154 |
+
dw = 1. / (size[0])
|
155 |
+
# 高
|
156 |
+
dh = 1. / (size[1])
|
157 |
+
|
158 |
+
x = (box[0] + box[1]) / 2.0 - 1
|
159 |
+
y = (box[2] + box[3]) / 2.0 - 1
|
160 |
+
# 宽
|
161 |
+
w = box[1] - box[0]
|
162 |
+
# 高
|
163 |
+
h = box[3] - box[2]
|
164 |
+
|
165 |
+
x = x * dw
|
166 |
+
w = w * dw
|
167 |
+
y = y * dh
|
168 |
+
h = h * dh
|
169 |
+
return x, y, w, h
|
170 |
+
|
171 |
+
|
172 |
+
'''
|
173 |
+
移动图像和标注文件到指定的训练集、验证集和测试集中
|
174 |
+
'''
|
175 |
+
def push_into_file(file, images, labels, label_path=ROOT_DIR, suffix='.jpg'):
|
176 |
+
"""
|
177 |
+
最终生成在当前文件夹下的所有文件按image和label分别存在到训练集/验证集/测试集路径的文件夹下
|
178 |
+
:param file: 文件名列表
|
179 |
+
:param images: 存放images的路径
|
180 |
+
:param labels: 存放labels的路径
|
181 |
+
:param label_path: 当前文件路径
|
182 |
+
:param suffix: 图像文件后缀
|
183 |
+
:return:
|
184 |
+
"""
|
185 |
+
# 遍历所有文件
|
186 |
+
for filename in file:
|
187 |
+
# 图像文件
|
188 |
+
image_file = os.path.join(label_path, filename + suffix)
|
189 |
+
# 标注文件
|
190 |
+
label_file = os.path.join(label_path, filename + '.txt')
|
191 |
+
# yolov5存放图像文件夹
|
192 |
+
if not os.path.exists(os.path.join(images, filename + suffix)):
|
193 |
+
try:
|
194 |
+
shutil.move(image_file, images)
|
195 |
+
except OSError:
|
196 |
+
pass
|
197 |
+
# yolov5存放标注文件夹
|
198 |
+
if not os.path.exists(os.path.join(labels, filename + suffix)):
|
199 |
+
try:
|
200 |
+
shutil.move(label_file, labels)
|
201 |
+
except OSError:
|
202 |
+
pass
|
203 |
+
|
204 |
+
'''
|
205 |
+
|
206 |
+
'''
|
207 |
+
def json2txt(classes, txt_Name='allfiles', label_path=ROOT_DIR, suffix='.png'):
|
208 |
+
"""
|
209 |
+
将json文件转化为txt文件,并将json文件存放到指定文件夹
|
210 |
+
:param classes: 类别名
|
211 |
+
:param txt_Name:txt文件,用来存放所有文件的路径
|
212 |
+
:param label_path:当前文件路径
|
213 |
+
:param suffix:图像文件后缀
|
214 |
+
:return:
|
215 |
+
"""
|
216 |
+
store_json = os.path.join(label_path, 'json')
|
217 |
+
if not os.path.exists(store_json):
|
218 |
+
os.makedirs(store_json)
|
219 |
+
|
220 |
+
_, _, _, files = split_dataset(label_path)
|
221 |
+
if not os.path.exists(os.path.join(label_path, 'tmp')):
|
222 |
+
os.makedirs(os.path.join(label_path, 'tmp'))
|
223 |
+
|
224 |
+
list_file = open('tmp/%s.txt' % txt_Name, 'w')
|
225 |
+
for json_file_ in tqdm(files):
|
226 |
+
json_filename = os.path.join(label_path, json_file_ + ".json")
|
227 |
+
imagePath = os.path.join(label_path, json_file_ + suffix)
|
228 |
+
list_file.write('%s\n' % imagePath)
|
229 |
+
out_file = open('%s/%s.txt' % (label_path, json_file_), 'w')
|
230 |
+
json_file = json.load(open(json_filename, "r", encoding="utf-8"))
|
231 |
+
if os.path.exists(imagePath):
|
232 |
+
height, width, channels = cv2.imread(imagePath).shape
|
233 |
+
for multi in json_file["shapes"]:
|
234 |
+
if len(multi["points"][0]) == 0:
|
235 |
+
out_file.write('')
|
236 |
+
continue
|
237 |
+
points = np.array(multi["points"])
|
238 |
+
xmin = min(points[:, 0]) if min(points[:, 0]) > 0 else 0
|
239 |
+
xmax = max(points[:, 0]) if max(points[:, 0]) > 0 else 0
|
240 |
+
ymin = min(points[:, 1]) if min(points[:, 1]) > 0 else 0
|
241 |
+
ymax = max(points[:, 1]) if max(points[:, 1]) > 0 else 0
|
242 |
+
label = multi["label"]
|
243 |
+
if xmax <= xmin:
|
244 |
+
pass
|
245 |
+
elif ymax <= ymin:
|
246 |
+
pass
|
247 |
+
else:
|
248 |
+
cls_id = classes.index(label)
|
249 |
+
b = (float(xmin), float(xmax), float(ymin), float(ymax))
|
250 |
+
bb = convert((width, height), b)
|
251 |
+
out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
|
252 |
+
# print(json_filename, xmin, ymin, xmax, ymax, cls_id)
|
253 |
+
if not os.path.exists(os.path.join(store_json, json_file_ + '.json')):
|
254 |
+
try:
|
255 |
+
shutil.move(json_filename, store_json)
|
256 |
+
except OSError:
|
257 |
+
pass
|
258 |
+
|
259 |
+
'''
|
260 |
+
创建yaml文件
|
261 |
+
'''
|
262 |
+
def create_yaml(classes, label_path, isUseTest=False):
|
263 |
+
nc = len(classes)
|
264 |
+
if not isUseTest:
|
265 |
+
desired_caps = {
|
266 |
+
'path': label_path,
|
267 |
+
'train': 'train/images',
|
268 |
+
'val': 'valid/images',
|
269 |
+
'nc': nc,
|
270 |
+
'names': classes
|
271 |
+
}
|
272 |
+
else:
|
273 |
+
desired_caps = {
|
274 |
+
'path': label_path,
|
275 |
+
'train': 'train/images',
|
276 |
+
'val': 'valid/images',
|
277 |
+
'test': 'test/images',
|
278 |
+
'nc': nc,
|
279 |
+
'names': classes
|
280 |
+
}
|
281 |
+
yamlpath = os.path.join(label_path, "data" + ".yaml")
|
282 |
+
|
283 |
+
# 写入到yaml文件
|
284 |
+
with open(yamlpath, "w+", encoding="utf-8") as f:
|
285 |
+
for key, val in desired_caps.items():
|
286 |
+
yaml.dump({key: val}, f, default_flow_style=False)
|
287 |
+
|
288 |
+
|
289 |
+
# 首先确保当前文件夹下的所有图片统一后缀,如.jpg,如果为其他后缀,将suffix改为对应的后缀,如.png
|
290 |
+
def ChangeToYolo5(label_path=r"D:\storydata", suffix='.png', test_size=0.1, isUseTest=False):
|
291 |
+
"""
|
292 |
+
生成最终标准格式的文件
|
293 |
+
:param test_size: 分割测试集或验证集的比例
|
294 |
+
:param label_path:当前文件路径
|
295 |
+
:param suffix: 文件后缀名
|
296 |
+
:param isUseTest: 是否使用测试集
|
297 |
+
:return:
|
298 |
+
"""
|
299 |
+
# step1:统一图像格式
|
300 |
+
change_image_format(label_path)
|
301 |
+
# step2:根据json文件划分训练集、验证集、测试集
|
302 |
+
train_files, val_files, test_file, files = split_dataset(label_path, test_size=test_size, isUseTest=isUseTest)
|
303 |
+
# step3:根据json文件,获取所有类别
|
304 |
+
classes = get_all_class(files)
|
305 |
+
# step4:将json文件转化为txt文件,并将json文件存放到指定文件夹
|
306 |
+
json2txt(classes)
|
307 |
+
# step5:创建yolov5训练所需的yaml文件
|
308 |
+
create_yaml(classes, label_path, isUseTest=isUseTest)
|
309 |
+
# step6:生成yolov5的训练、验证、测试集的文件夹
|
310 |
+
train_image, train_label, val_image, val_label, test_image, test_label = create_save_file(label_path)
|
311 |
+
# step7:将所有图像和标注文件,移动到对应的训练集、验证集、测试集
|
312 |
+
push_into_file(train_files, train_image, train_label, suffix=suffix) # 将文件移动到训练集文件中
|
313 |
+
push_into_file(val_files, val_image, val_label, suffix=suffix) # 将文件移动到验证集文件夹中
|
314 |
+
if test_file is not None: # 如果测试集存在,则将文件移动到测试集文件中
|
315 |
+
push_into_file(test_file, test_image, test_label, suffix=suffix)
|
316 |
+
print('create dataset done')
|
317 |
+
|
318 |
+
|
319 |
+
if __name__ == "__main__":
|
320 |
+
ChangeToYolo5()
|
v1-5-pruned-emaonly.safetensors
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:6ce0161689b3853acaa03779ec93eafe75a02f4ced659bee03f50797806fa2fa
|
3 |
-
size 4265146304
|
|
|
|
|
|
|
|
v1-5-pruned.safetensors
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:1a189f0be69d6106a48548e7626207dddd7042a418dbf372cefd05e0cdba61b6
|
3 |
-
size 7703324286
|
|
|
|
|
|
|
|