wondervictor committed on
Commit
2422035
·
1 Parent(s): f6bd4fa

update README

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +163 -0
  2. app.py +29 -4
  3. app_canny.py +100 -0
  4. app_depth.py +92 -0
  5. autoregressive/models/README.md +6 -0
  6. autoregressive/models/dinov2_adapter.py +36 -0
  7. autoregressive/models/generate.py +204 -0
  8. autoregressive/models/gpt_t2i.py +561 -0
  9. autoregressive/sample/sample_c2i.py +151 -0
  10. autoregressive/sample/sample_c2i_ddp.py +188 -0
  11. autoregressive/sample/sample_t2i.py +215 -0
  12. autoregressive/sample/sample_t2i_MR.py +237 -0
  13. autoregressive/sample/sample_t2i_ddp.py +229 -0
  14. checkpoints/vq_ds16_t2i.pt +3 -0
  15. condition/README.md +23 -0
  16. condition/canny.py +25 -0
  17. condition/depth.py +47 -0
  18. condition/example/t2i/multi_resolution/bird.jpg +0 -0
  19. condition/example/t2i/multi_resolution/car.jpg +0 -0
  20. condition/example/t2i/multigen/doll.jpg +0 -0
  21. condition/example/t2i/multigen/girl.jpg +0 -0
  22. condition/example/t2i/multigen/house.jpg +0 -0
  23. condition/example/t2i/multigen/sofa.png +0 -0
  24. condition/hed.py +117 -0
  25. condition/lineart.py +98 -0
  26. condition/midas/depth.py +223 -0
  27. condition/midas/midas/__init__.py +0 -0
  28. condition/midas/midas/base_model.py +16 -0
  29. condition/midas/midas/blocks.py +341 -0
  30. condition/midas/midas/dpt_depth.py +108 -0
  31. condition/midas/midas/midas_net.py +76 -0
  32. condition/midas/midas/midas_net_custom.py +128 -0
  33. condition/midas/midas/transforms.py +234 -0
  34. condition/midas/midas/vit.py +491 -0
  35. condition/utils.py +38 -0
  36. language/README.md +14 -0
  37. language/extract_t5_feature.py +129 -0
  38. language/t5.py +201 -0
  39. model.py +242 -0
  40. style.css +10 -0
  41. tokenizer/consistencydecoder/README.md +14 -0
  42. tokenizer/consistencydecoder/cd_demo.py +57 -0
  43. tokenizer/consistencydecoder/reconstruction_cd_ddp.py +208 -0
  44. tokenizer/tokenizer_image/cache/vgg.pth +3 -0
  45. tokenizer/tokenizer_image/discriminator.py +255 -0
  46. tokenizer/tokenizer_image/discriminator_patchgan.py +152 -0
  47. tokenizer/tokenizer_image/discriminator_stylegan.py +101 -0
  48. tokenizer/tokenizer_image/lpips.py +164 -0
  49. tokenizer/tokenizer_image/reconstruction_vq_ddp.py +207 -0
  50. tokenizer/tokenizer_image/vq_demo.py +84 -0
.gitignore ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
163
+
app.py CHANGED
@@ -1,7 +1,32 @@
 
1
  import gradio as gr
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
  import gradio as gr
3
+ from huggingface_hub import hf_hub_download
4
+ from model import Model
5
+ from app_canny import create_demo as create_demo_canny
6
+ from app_depth import create_demo as create_demo_depth
7
+ import os
8
 
 
 
9
 
10
# Fetch the multi-resolution ControlAR checkpoints into ./checkpoints/ at import time.
for checkpoint_name in ('canny_MR.safetensors', 'depth_MR.safetensors'):
    hf_hub_download('wondervictor/ControlAR', filename=checkpoint_name, cache_dir='./checkpoints/')


# Markdown header rendered at the top of the page.
DESCRIPTION = (
    "# [ControlAR: Controllable Image Generation with Autoregressive Models](https://arxiv.org/abs/2410.02705) \n"
    " ### The first row in outputs is the input image and condition. The second row is the images generated by ControlAR. \n"
    " ### You can run locally by following the instruction on our [Github Repo](https://github.com/hustvl/ControlAR)."
)
# The duplicate button is only shown when the environment opts in with SHOW_DUPLICATE_BUTTON=1.
SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
model = Model()
device = "cuda"

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=SHOW_DUPLICATE_BUTTON,
    )
    # One tab per condition type; each tab hosts the demo built by its own module.
    with gr.Tabs():
        with gr.TabItem("Depth"):
            create_demo_depth(model.process_depth)
        with gr.TabItem("Canny"):
            create_demo_canny(model.process_canny)

if __name__ == "__main__":
    demo.queue().launch(share=False, server_name="0.0.0.0")
app_canny.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import random
3
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for the next generation run.

    Args:
        seed: Seed currently selected in the UI.
        randomize_seed: When truthy, ignore ``seed`` and draw a new one.

    Returns:
        A random integer in ``[0, 100000000]`` when ``randomize_seed`` is
        truthy, otherwise ``seed`` unchanged.
    """
    return random.randint(0, 100000000) if randomize_seed else seed
7
# Example rows shown under the canny demo: [image path, prompt, resolution].
# The multigen images are shipped as .jpg files in this commit
# (condition/example/t2i/multigen/doll.jpg, girl.jpg), so the paths must use
# that extension or the examples cannot be loaded.
# NOTE(review): the third column is wired to the numeric "(H, W)" slider in
# create_demo, but the values here are strings - confirm the intended format.
examples = [
    [
        "condition/example/t2i/multigen/doll.jpg",
        "A stuffed animal wearing a mask and a leash, sitting on a blanket",
        "(512, 512)",
    ],
    [
        "condition/example/t2i/multigen/girl.jpg",
        "An anime style girl with blue hair",
        "(512, 512)",
    ],
    [
        "condition/example/t2i/multi_resolution/bird.jpg",
        "colorful bird",
        "(921, 564)",
    ],
]
24
def create_demo(process):
    """Build the Gradio Blocks UI for canny-conditioned generation.

    Args:
        process: Callable invoked with the current values of ``image``,
            ``prompt``, ``cfg_scale``, ``temperature``, ``top_k``, ``top_p``,
            ``seed``, ``canny_low_threshold`` and ``canny_high_threshold``
            (in that order). Its return value is shown in the output gallery.

    Returns:
        The assembled ``gr.Blocks`` instance (not launched).
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image = gr.Image()
                prompt = gr.Textbox(label="Prompt")
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
                    canny_low_threshold = gr.Slider(
                        label="Canny low threshold", minimum=0, maximum=1000, value=100, step=50
                    )
                    canny_high_threshold = gr.Slider(
                        label="Canny high threshold", minimum=0, maximum=1000, value=200, step=50
                    )
                    cfg_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=4, step=0.1)
                    # NOTE(review): this slider is only used as an Examples
                    # input; it is not part of `inputs` below, so its value
                    # never reaches `process` - confirm whether that is intended.
                    resolution = gr.Slider(label="(H, W)", minimum=384, maximum=768, value=512, step=16)
                    top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=2000, label='Top-K')
                    top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P")
                    temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature')
                    seed = gr.Slider(label="Seed", minimum=0, maximum=100000000, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Column():
                result = gr.Gallery(label="Output", show_label=False, height='800px', columns=2, object_fit="scale-down")
        gr.Examples(
            examples=examples,
            inputs=[
                image,
                prompt,
                resolution,
            ],
            outputs=result,
            fn=process,
        )
        inputs = [
            image,
            prompt,
            cfg_scale,
            temperature,
            top_k,
            top_p,
            seed,
            canny_low_threshold,
            canny_high_threshold,
        ]
        # Both triggers first refresh the seed (unqueued), then run `process`.
        # Only the button click is exposed as a named API endpoint.
        for trigger, endpoint in ((prompt.submit, False), (run_button.click, "canny")):
            trigger(
                fn=randomize_seed_fn,
                inputs=[seed, randomize_seed],
                outputs=seed,
                queue=False,
                api_name=False,
            ).then(
                fn=process,
                inputs=inputs,
                outputs=result,
                api_name=endpoint,
            )
    return demo
93
if __name__ == "__main__":
    # Standalone entry point: build the model locally and serve only the canny demo.
    from model import Model

    demo = create_demo(Model().process_canny)
    demo.queue().launch(share=False, server_name="0.0.0.0")
app_depth.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import random
3
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed for the next run.

    Args:
        seed: Seed currently selected in the UI.
        randomize_seed: When truthy, a new seed replaces ``seed``.

    Returns:
        ``seed`` unchanged, or a random integer in ``[0, 100000000]`` when
        ``randomize_seed`` is truthy.
    """
    if not randomize_seed:
        return seed
    return random.randint(0, 100000000)
7
# Example rows shown under the depth demo: [image path, prompt, resolution].
# house is shipped as a .jpg in this commit
# (condition/example/t2i/multigen/house.jpg); sofa.png and car.jpg already match.
# NOTE(review): the third column is wired to the numeric "(H, W)" slider in
# create_demo, but the values here are strings - confirm the intended format.
examples = [
    [
        "condition/example/t2i/multigen/sofa.png",
        "The red sofa in the living room has several pillows on it",
        "(512, 512)",
    ],
    [
        "condition/example/t2i/multigen/house.jpg",
        "A brick house with a chimney under a starry sky.",
        "(512, 512)",
    ],
    [
        "condition/example/t2i/multi_resolution/car.jpg",
        "a sport car",
        "(448, 768)",
    ],
]
24
def create_demo(process):
    """Build the Gradio Blocks UI for depth-conditioned generation.

    Args:
        process: Callable invoked with the current values of ``image``,
            ``prompt``, ``cfg_scale``, ``temperature``, ``top_k``, ``top_p``
            and ``seed`` (in that order). Its return value is shown in the
            output gallery.

    Returns:
        The assembled ``gr.Blocks`` instance (not launched).
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image = gr.Image()
                prompt = gr.Textbox(label="Prompt")
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
                    cfg_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=4, step=0.1)
                    # NOTE(review): this slider is only used as an Examples
                    # input; it is not part of `inputs` below, so its value
                    # never reaches `process` - confirm whether that is intended.
                    resolution = gr.Slider(label="(H, W)", minimum=384, maximum=768, value=512, step=16)
                    top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=2000, label='Top-K')
                    top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P")
                    temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature')
                    seed = gr.Slider(label="Seed", minimum=0, maximum=100000000, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Column():
                result = gr.Gallery(label="Output", show_label=False, height='800px', columns=2, object_fit="scale-down")
        gr.Examples(
            examples=examples,
            inputs=[
                image,
                prompt,
                resolution,
            ],
            outputs=result,
            fn=process,
        )
        inputs = [
            image,
            prompt,
            cfg_scale,
            temperature,
            top_k,
            top_p,
            seed,
        ]
        # Both triggers first refresh the seed (unqueued), then run `process`.
        # Only the button click is exposed as a named API endpoint. It is
        # named "depth" so it does not collide with the canny demo's "canny"
        # endpoint when both demos are mounted in the same app.
        for trigger, endpoint in ((prompt.submit, False), (run_button.click, "depth")):
            trigger(
                fn=randomize_seed_fn,
                inputs=[seed, randomize_seed],
                outputs=seed,
                queue=False,
                api_name=False,
            ).then(
                fn=process,
                inputs=inputs,
                outputs=result,
                api_name=endpoint,
            )
    return demo
85
if __name__ == "__main__":
    # Standalone entry point: build the model locally and serve only the depth demo.
    from model import Model

    demo = create_demo(Model().process_depth)
    demo.queue().launch(share=False, server_name="0.0.0.0")
autoregressive/models/README.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Download the ViT and DINOv2 weights first:
2
+
3
+ ViT-small: https://huggingface.co/WinKawaks/vit-small-patch16-224 \
4
+ Dinov2-small: https://huggingface.co/facebook/dinov2-small
5
+
6
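+ Put the downloaded model folders in this directory (`autoregressive/models/`), e.g. `autoregressive/models/dinov2-small`.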
autoregressive/models/dinov2_adapter.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoImageProcessor, AutoModel
2
+ from PIL import Image
3
+ import requests
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
class Dinov2_Adapter(nn.Module):
    """Wraps a locally stored DINOv2 checkpoint as a condition-image encoder.

    The backbone is loaded from ``autoregressive/models/dinov2-{adapter_size}``
    (a path relative to the working directory). ``input_dim``, ``output_dim``,
    ``attention``, ``pool``, ``nheads`` and ``dropout`` are accepted but not
    used by this class.
    """

    def __init__(self, input_dim=1, output_dim=768, attention=False, pool=False, nheads=8, dropout=0.1, adapter_size='small', condition_type='canny'):
        super().__init__()
        print(f"Choose adapter size: {adapter_size}")
        print(f"condition type: {condition_type}")
        self.model = AutoModel.from_pretrained(f'autoregressive/models/dinov2-{adapter_size}')
        self.condition_type = condition_type

    def to_patch14(self, input):
        """Resize ``input`` so each 16-pixel cell becomes a 14-pixel cell.

        The spatial size (dims 2 and 3) is mapped from ``(H, W)`` to
        ``((H // 16) * 14, (W // 16) * 14)``. 'canny' and 'seg' conditions are
        resized with nearest-neighbour interpolation; every other condition
        type uses bicubic interpolation with ``align_corners=True``.
        """
        height, width = input.shape[2:]
        target_size = ((height // 16) * 14, (width // 16) * 14)
        if self.condition_type in ['canny', 'seg']:
            return torch.nn.functional.interpolate(input, size=target_size, mode='nearest')
        # depth, lineart, hed
        return torch.nn.functional.interpolate(input, size=target_size, mode='bicubic', align_corners=True)

    def forward(self, x):
        """Encode ``x`` and return ``last_hidden_state`` without its first token."""
        backbone_output = self.model(self.to_patch14(x))
        return backbone_output.last_hidden_state[:, 1:]
30
+
31
+
32
if __name__ == '__main__':
    # Smoke test: push a random 4x3x512x512 batch through the adapter on CUDA
    # and print the shape of the returned token features.
    adapter = Dinov2_Adapter().cuda()
    dummy_batch = torch.randn(4,3,512,512).cuda()
    print(adapter(dummy_batch).shape)
autoregressive/models/generate.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
3
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import functional as F
7
+ import torch._dynamo.config
8
+ import torch._inductor.config
9
+ import copy
10
+ import time
11
+ # torch._inductor.config.coordinate_descent_tuning = True
12
+ # torch._inductor.config.triton.unique_kernel_names = True
13
+ # torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
14
+
15
+
16
+ ### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
17
def top_k_top_p_filtering(
    logits,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The vocabulary is taken to be the last dimension of ``logits``; any
    leading dimensions are treated as batch dimensions. ``logits`` is
    modified in place and also returned.

    Args:
        logits: logits distribution, shape (..., vocabulary size).
        top_k: if > 0, keep only the ``top_k`` tokens with the highest logit.
        top_p: if < 1.0, keep the smallest set of top tokens whose cumulative
            probability reaches ``top_p`` (nucleus filtering, Holtzman et al.,
            http://arxiv.org/abs/1904.09751).
        filter_value: value written over every filtered-out logit.
        min_tokens_to_keep: lower bound on the number of tokens kept.

    Returns:
        The same ``logits`` tensor with filtered entries set to ``filter_value``.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a logit below the k-th largest one.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark tokens whose cumulative probability exceeds the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Never remove the leading min_tokens_to_keep entries (one more is
            # protected by the shift below).
            sorted_indices_to_remove[..., :min_tokens_to_keep] = False
        # Shift right so the first token that crosses the threshold is kept too.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Scatter the mask back to the original token order. Using the last
        # dimension (instead of a hard-coded dim 1) keeps this consistent with
        # the top-k branch for inputs that are not exactly 2-D.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
57
+
58
+
59
def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sample_logits=True):
    """Choose the next token from the logits of the last sequence position.

    Args:
        logits: tensor indexed as ``logits[:, -1, :]``; only that last slice
            along dim 1 is used.
        temperature: divisor applied to the logits (clamped to at least 1e-5).
        top_k: forwarded to ``top_k_top_p_filtering`` when filtering is active.
        top_p: forwarded to ``top_k_top_p_filtering`` when filtering is active.
        sample_logits: draw from the distribution when True, otherwise take
            the most probable token.

    Returns:
        ``(idx, probs)``: the chosen token indices (one per row) and the
        softmax probabilities they were chosen from.
    """
    last_step_logits = logits[:, -1, :] / max(temperature, 1e-5)
    if top_k > 0 or top_p < 1.0:
        last_step_logits = top_k_top_p_filtering(last_step_logits, top_k=top_k, top_p=top_p)
    probs = F.softmax(last_step_logits, dim=-1)
    if not sample_logits:
        # Greedy choice: index of the single largest probability.
        return torch.topk(probs, k=1, dim=-1)[1], probs
    return torch.multinomial(probs, num_samples=1), probs
75
+
76
+
77
def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: int = None, **kwargs):
    """Turn logits into probabilities with optional top-k / top-p filtering.

    Args:
        logits: logits tensor; softmax is taken over the last dimension.
        temperature: divisor applied to the logits (clamped to at least 1e-5).
        top_p: nucleus threshold; values below 1.0 enable top-p filtering.
        top_k: number of tokens to keep; ``None`` or 0 disables top-k filtering.
        **kwargs: accepted and ignored.

    Returns:
        Tensor of probabilities with the same shape as ``logits``.
    """
    logits = logits / max(temperature, 1e-5)
    # The declared default is None; comparing None > 0 raises TypeError, so
    # treat None the same as 0 ("no top-k filtering").
    if top_k is None:
        top_k = 0
    if top_k > 0 or top_p < 1.0:
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
83
+
84
+
85
def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs):
    """Run the model on the conditioning prefix and sample the first token.

    The model is called as ``model(None, cond_idx, input_pos, condition=condition)``.
    When ``cfg_scale > 1.0`` the returned logits are split in half along dim 0
    (first half conditional, second half unconditional) and combined as
    ``uncond + (cond - uncond) * cfg_scale`` before sampling.

    Returns:
        The token indices returned by ``sample`` (its first return value).
    """
    logits, _ = model(None, cond_idx, input_pos, condition=condition)
    if cfg_scale > 1.0:
        cond_logits, uncond_logits = torch.split(logits, len(logits) // 2, dim=0)
        logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale

    return sample(logits, **sampling_kwargs)[0]
95
+
96
+
97
def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, condition: torch.Tensor, **sampling_kwargs):
    """Run one decoding step and sample the next token.

    With ``cfg_scale > 1.0`` the current token batch is duplicated, the model
    output is split in half along dim 0 (conditional / unconditional), and the
    halves are blended with ``cfg_scale`` when ``cfg_flag`` is set; otherwise
    only the conditional half is used.

    Returns:
        ``(idx, probs)`` exactly as returned by ``sample``.
    """
    assert input_pos.shape[-1] == 1
    if cfg_scale > 1.0:
        doubled_tokens = torch.cat([x, x])
        logits, _ = model(doubled_tokens, cond_idx=None, input_pos=input_pos, condition=condition)
        cond_logits, uncond_logits = torch.split(logits, len(logits) // 2, dim=0)
        if cfg_flag:
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
        else:
            logits = cond_logits
    else:
        # NOTE(review): this branch passes condition=None while the guided
        # branch (and prefill) forward `condition` - confirm this is intended.
        logits, _ = model(x, cond_idx=None, input_pos=input_pos, condition=None)
    return sample(logits, **sampling_kwargs)
111
+
112
+
113
def decode_n_tokens(
    model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
    cfg_scale: float, cfg_interval: int, condition: torch.Tensor,
    **sampling_kwargs):
    """Decode ``num_new_tokens`` tokens one at a time.

    ``input_pos`` is advanced in place by one after every step. Guidance
    blending is requested from ``decode_one_token`` until the step index
    exceeds ``cfg_interval`` (a ``cfg_interval`` of -1 or lower never turns
    it off).

    Returns:
        ``(new_tokens, new_probs)``: two lists with one cloned tensor per step.
    """
    new_tokens, new_probs = [], []
    for step in range(num_new_tokens):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
            # Equivalent to a sticky flag: `step` only grows, so once the
            # interval is exceeded the flag stays off.
            cfg_flag = not (cfg_interval > -1 and step > cfg_interval)
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, cfg_scale, cfg_flag, condition=condition, **sampling_kwargs
            )
            input_pos += 1
            new_tokens.append(next_token.clone())
            new_probs.append(next_prob.clone())
            cur_token = next_token.view(-1, 1)

    return new_tokens, new_probs
132
+
133
+
134
@torch.no_grad()
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
    """Autoregressively generate ``max_new_tokens`` token indices.

    Args:
        model: model exposing ``model_type`` ('c2i' or 't2i'), ``adapter``,
            ``adapter_mlp``, ``setup_caches``, ``causal_mask`` and
            ``tok_embeddings`` (plus ``num_classes`` for 'c2i' and
            ``cls_embedding.uncond_embedding`` for 't2i').
        cond: conditioning input; its first dimension is the batch size.
        max_new_tokens: number of tokens to generate.
        emb_masks: optional mask multiplied into the first ``T`` columns of
            ``model.causal_mask``.
        cfg_scale: classifier-free guidance scale; guidance is active when > 1.0.
        cfg_interval: forwarded to ``decode_n_tokens``.
        condition: optional control input, passed through ``model.adapter``
            and ``model.adapter_mlp`` before use.
        condition_null: accepted for compatibility; the unconditional control
            input is always rebuilt as zeros here.
        condition_token_nums: used to compute the prefix length for 'c2i'.
        **sampling_kwargs: forwarded to the sampling helpers.

    Returns:
        Integer tensor of shape ``(batch, max_new_tokens)`` with the generated tokens.

    Raises:
        Exception: if ``model.model_type`` is neither 'c2i' nor 't2i'.
    """
    use_cfg = cfg_scale > 1.0

    if condition is not None:
        condition = model.adapter(condition)
        condition = model.adapter_mlp(condition)

    # Per model type: the "null" conditioning used for guidance and the prefix length T.
    cond_null = None
    if model.model_type == 'c2i':
        if use_cfg:
            cond_null = torch.ones_like(cond) * model.num_classes
        T = 1+condition_token_nums
    elif model.model_type == 't2i':
        if use_cfg:
            cond_null = torch.zeros_like(cond) + model.cls_embedding.uncond_embedding
        T = cond.shape[1]
    else:
        raise Exception("please check model type")

    # With guidance the batch is doubled: conditional half first, null half second.
    if use_cfg:
        cond_combined = torch.cat([cond, cond_null])
        if condition is not None:
            condition_null = torch.zeros_like(condition)
            condition_combined = torch.cat((condition, condition_null), dim=0)
        else:
            condition_combined = None
    else:
        cond_combined = cond
        condition_combined = condition

    T_new = T + max_new_tokens
    max_seq_length = T_new
    max_batch_size = cond.shape[0]

    device = cond.device
    with torch.device(device):
        max_batch_size_cfg = max_batch_size * 2 if use_cfg else max_batch_size
        model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=model.tok_embeddings.weight.dtype)

    if emb_masks is not None:
        assert emb_masks.shape[0] == max_batch_size
        assert emb_masks.shape[-1] == T
        if use_cfg:
            model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
        else:
            model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)

        # Force the diagonal back to 1 after the masking above.
        eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
        model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix

    # create an empty tensor of the expected final shape and fill in the current tokens
    seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)
    input_pos = torch.arange(0, T, device=device)
    next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, **sampling_kwargs)
    seq[:, T:T+1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, condition=condition_combined, **sampling_kwargs)
    # With max_new_tokens == 1 nothing is decoded after prefill; torch.cat on
    # an empty list would raise, so only fill the tail when there is one.
    if generated_tokens:
        seq[:, T+1:] = torch.cat(generated_tokens, dim=1)
    return seq[:, T:]
autoregressive/models/gpt_t2i.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # VQGAN: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py
3
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
4
+ # nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py
5
+ # llama: https://github.com/facebookresearch/llama/blob/main/llama/model.py
6
+ # gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
7
+ # PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
8
+ from dataclasses import dataclass
9
+ from typing import Optional, List
10
+
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.nn import functional as F
15
+ from utils.drop_path import DropPath
16
+ # from autoregressive.models.vit_adapter import ViT_Adapter
17
+ from autoregressive.models.dinov2_adapter import Dinov2_Adapter
18
+
19
+
20
def get_causal_mask(seq_length):
    """Build an additive causal mask of shape ``(seq_length, seq_length)``.

    Entries strictly above the main diagonal are ``-inf``; entries on or
    below it are ``0.0``.

    The previous version called ``masked_fill`` on the boolean tensor itself,
    so the ``-inf`` / ``0.0`` fill values could never be stored in it. The
    fill is now applied to a float tensor.

    Returns:
        A float tensor holding ``0.0`` / ``-inf`` as described above.
    """
    future_positions = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).type(torch.bool)
    return torch.zeros(seq_length, seq_length).masked_fill(future_positions, float('-inf'))
25
+
26
def find_multiple(n: int, k: int):
    """Return ``n`` if it is a multiple of ``k``, else ``n`` rounded up to the next multiple of ``k``."""
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
30
+
31
@dataclass
class ModelArgs:
    """Default configuration values for the model defined in this file."""

    # Architecture hyper-parameters
    dim: int = 4096
    n_layer: int = 32
    n_head: int = 32
    n_kv_head: Optional[int] = None
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    rope_base: float = 10000
    norm_eps: float = 1e-5
    initializer_range: float = 0.02

    # Dropout / drop-path settings
    token_dropout_p: float = 0.1
    attn_dropout_p: float = 0.0
    resid_dropout_p: float = 0.1
    ffn_dropout_p: float = 0.1
    drop_path_rate: float = 0.0

    # Conditioning settings
    num_classes: int = 1000
    caption_dim: int = 2048
    class_dropout_prob: float = 0.1
    model_type: str = 'c2i'

    # Vocabulary, sequence and batch sizing
    vocab_size: int = 16384
    cls_token_num: int = 1
    block_size: int = 256
    max_batch_size: int = 32
    max_seq_len: int = 2048

    # Adapter selection (matches the arguments of Dinov2_Adapter)
    adapter_size: str = 'small'
    condition_type: str = 'canny'
61
+
62
+
63
+
64
+ #################################################################################
65
+ # Embedding Layers for Class Labels #
66
+ #################################################################################
67
class LabelEmbedder(nn.Module):
    """
    Looks up an embedding vector per class label and optionally replaces
    labels with the extra "dropped" index ``num_classes`` for classifier-free
    guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # One extra row is reserved for the dropped label when dropout is enabled.
        extra_rows = 1 if dropout_prob > 0 else 0
        self.embedding_table = nn.Embedding(num_classes + extra_rows, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Replace selected labels with ``num_classes``.

        Rows are selected at random with probability ``dropout_prob`` unless
        ``force_drop_ids`` is given, in which case rows where it equals 1 are
        selected. Returns ``(labels, drop_ids)``.
        """
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        else:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_ids, self.num_classes, labels), drop_ids

    def forward(self, labels, train, force_drop_ids=None):
        """
        Embed ``labels`` with a singleton dimension inserted at dim 1.

        Returns ``(embeddings, drop_ids)`` when label dropping was applied
        (training with a positive dropout probability, or ``force_drop_ids``
        given), otherwise just ``embeddings``.
        """
        dropping = (train and self.dropout_prob > 0) or (force_drop_ids is not None)
        if not dropping:
            return self.embedding_table(labels).unsqueeze(1)
        labels, drop_ids = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels).unsqueeze(1), drop_ids
98
+
99
+
100
class ConditionEmbedder(nn.Module):
    """
    Projects condition tokens through an MLP and optionally replaces whole
    samples with the (all-zero) unconditional embedding for classifier-free
    guidance. ``in_channels`` and ``vocab_size`` are accepted but not used.
    """
    def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120, vocab_size=16384):
        super().__init__()
        self.cap_proj = MLP(in_features=hidden_size, hidden_features=hidden_size, out_features=hidden_size)
        # Built from torch.zeros, so the scale factor leaves it all zeros.
        self.register_buffer("uncond_embedding", torch.zeros(token_num, hidden_size) / hidden_size ** 0.5)
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None, drop_ids=None):
        """
        Replace selected samples of ``caption`` with the unconditional embedding.

        Selection priority: ``force_drop_ids == 1`` if given, else the supplied
        ``drop_ids``, else a random draw with probability ``uncond_prob``.
        The unconditional embedding is truncated to ``caption.shape[1]`` rows.
        """
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        elif drop_ids is None:
            drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob

        null_tokens = self.uncond_embedding[:caption.shape[1]]
        return torch.where(drop_ids[:, None, None], null_tokens, caption)

    def forward(self, caption, train, force_drop_ids=None, drop_ids=None):
        """
        Apply token dropping when training with a positive ``uncond_prob`` or
        when ``force_drop_ids`` is given, then project with ``cap_proj``.
        """
        dropping = (train and self.uncond_prob > 0) or (force_drop_ids is not None)
        if dropping:
            caption = self.token_drop(caption, force_drop_ids, drop_ids)
        return self.cap_proj(caption)
129
+
130
+ #################################################################################
131
+ # Embedding Layers for Text Feature #
132
+ #################################################################################
133
class CaptionEmbedder(nn.Module):
    """
    Embeds text caption features into vector representations. Also handles
    caption dropout for classifier-free guidance.

    Args:
        in_channels: feature width of the incoming caption features.
        hidden_size: output width of the projection MLP.
        uncond_prob: per-sample probability of swapping the caption for the
            unconditional embedding during training.
        token_num: number of rows stored in the unconditional embedding.
    """
    def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120):
        super().__init__()
        self.cap_proj = MLP(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size)
        # Registered as a plain tensor: a buffer is not trained by the
        # optimizer, so wrapping it in nn.Parameter only made it track
        # gradients needlessly. The state-dict key is unchanged.
        self.register_buffer("uncond_embedding", torch.randn(token_num, in_channels) / in_channels ** 0.5)
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None):
        """
        Drops captions to enable classifier-free guidance.

        Args:
            caption: caption features, indexed as (batch, tokens, channels).
            force_drop_ids: optional per-sample tensor; entries equal to 1 are
                dropped. When None, samples are dropped at random with
                probability `self.uncond_prob`.

        Returns:
            `(caption, drop_ids)`: the mixed features and the boolean mask.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        # Slice to the caption's token count so captions shorter than
        # `token_num` broadcast correctly (same handling as ConditionEmbedder).
        uncond = self.uncond_embedding[:caption.shape[1]]
        caption = torch.where(drop_ids[:, None, None], uncond, caption)
        return caption, drop_ids

    def forward(self, caption, train, force_drop_ids=None):
        """
        Optionally drop captions, then project them with `cap_proj`.

        Returns:
            `(embeddings, drop_ids)` when dropout ran (training with
            `uncond_prob > 0`, or `force_drop_ids` given); otherwise just
            `embeddings`.
        """
        use_dropout = self.uncond_prob > 0
        apply_drop = (train and use_dropout) or (force_drop_ids is not None)
        if apply_drop:
            caption, drop_ids = self.token_drop(caption, force_drop_ids)
        embeddings = self.cap_proj(caption)
        if apply_drop:
            return embeddings, drop_ids
        return embeddings
163
+
164
+
165
class MLP(nn.Module):
    """
    Two-layer bias-free perceptron with a tanh-approximated GELU in between.

    Falsy `hidden_features` / `out_features` fall back to `in_features`.
    Both weight matrices are zero-initialised here.
    """
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)

        # Start from all-zero weights, so a freshly built MLP maps every input to zeros.
        for linear in (self.fc1, self.fc2):
            nn.init.zeros_(linear.weight)

    def forward(self, x):
        """Apply fc1 -> GELU -> fc2."""
        return self.fc2(self.act(self.fc1(x)))
182
+
183
+
184
+ #################################################################################
185
+ # GPT Model #
186
+ #################################################################################
187
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last axis with a learned per-channel scale."""
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the RMS of the last axis; eps keeps rsqrt finite for all-zero rows.
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalise in float32, cast back to the input dtype, then scale.
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
199
+
200
+
201
class FeedForward(nn.Module):
    """
    Gated feed-forward block: `w2(silu(w1(x)) * w3(x))` followed by dropout.

    The hidden width is 2/3 of `4 * config.dim`, optionally scaled by
    `config.ffn_dim_multiplier`, then passed through `find_multiple` with
    `config.multiple_of`.
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        width = int(2 * (4 * config.dim) / 3)
        multiplier = config.ffn_dim_multiplier
        if multiplier is not None:
            # custom dim factor multiplier
            width = int(multiplier * width)
        width = find_multiple(width, config.multiple_of)

        self.w1 = nn.Linear(config.dim, width, bias=False)
        self.w3 = nn.Linear(config.dim, width, bias=False)
        self.w2 = nn.Linear(width, config.dim, bias=False)
        self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        gated = gate * self.w3(x)
        return self.ffn_dropout(self.w2(gated))
218
+
219
+
220
class KVCache(nn.Module):
    """
    Pre-allocated key/value buffers for incremental decoding.

    Both buffers have shape (max_batch_size, n_head, max_seq_length, head_dim).
    """
    def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
        super().__init__()
        shape = (max_batch_size, n_head, max_seq_length, head_dim)
        for buffer_name in ('k_cache', 'v_cache'):
            self.register_buffer(buffer_name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """
        Write new keys/values at `input_pos` and return the full caches.

        input_pos: [S], k_val: [B, H, S, D]. The returned tensors are the
        cache buffers themselves (updated in place), not copies.
        """
        assert input_pos.shape[0] == k_val.shape[2]
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache
236
+
237
+
238
class Attention(nn.Module):
    """
    Multi-head self-attention with a fused QKV projection, rotary position
    embedding, and an optional KV cache used during incremental decoding.
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        self.dim = config.dim
        self.head_dim = config.dim // config.n_head
        self.n_head = config.n_head
        # Fall back to one KV head per query head when n_kv_head is unset.
        self.n_kv_head = config.n_head if config.n_kv_head is None else config.n_kv_head
        fused_width = (self.n_head + 2 * self.n_kv_head) * self.head_dim

        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, fused_width, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        # Stays None until a cache object is attached from outside this class.
        self.kv_cache = None

        # regularization
        self.attn_dropout_p = config.attn_dropout_p
        self.resid_dropout = nn.Dropout(config.resid_dropout_p)

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor = None,
        input_pos: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ):
        """
        Attend over `x` (indexed as batch, sequence, channels).

        When a KV cache is attached, the new keys/values are written at
        `input_pos` and attention runs over the whole cache. With `mask` left
        as None the attention is causal; a supplied mask is used as-is.
        """
        bsz, seqlen, _ = x.shape
        kv_width = self.n_kv_head * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_width, kv_width], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_kv_head, self.head_dim)
        v = v.view(bsz, seqlen, self.n_kv_head, self.head_dim)

        # Rotary embedding is applied to queries and keys only.
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        # (B, S, H, D) -> (B, H, S, D)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if self.kv_cache is None:
            keys, values = k, v
        else:
            keys, values = self.kv_cache.update(input_pos, k, v)

        # Repeat KV heads so every query head has a matching key/value head.
        n_rep = self.n_head // self.n_kv_head
        keys = keys.repeat_interleave(n_rep, dim=1)
        values = values.repeat_interleave(n_rep, dim=1)

        attn = F.scaled_dot_product_attention(
            q, keys, values,
            attn_mask=mask,
            is_causal=mask is None,  # is_causal=False is for KV cache
            dropout_p=self.attn_dropout_p if self.training else 0)

        attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        return self.resid_dropout(self.wo(attn))
292
+
293
+
294
class TransformerBlock(nn.Module):
    """
    Pre-norm transformer block: attention and feed-forward sub-layers, each
    wrapped in a residual connection with optional stochastic depth.
    """
    def __init__(self, config: ModelArgs, drop_path: float):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        # A non-positive rate disables stochastic depth entirely.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
        # `start_pos` is forwarded unchanged as the attention's third positional argument.
        attn_out = self.attention(self.attention_norm(x), freqs_cis, start_pos, mask)
        h = x + self.drop_path(attn_out)
        ffn_out = self.feed_forward(self.ffn_norm(h))
        return h + self.drop_path(ffn_out)
308
+
309
+
310
class Transformer(nn.Module):
    """
    Autoregressive image-token transformer with class or text conditioning
    plus an additional control-condition stream.

    Control-condition tokens are produced by `adapter` -> `adapter_mlp` ->
    `condition_mlp`, cached on `self.condition_token`, and added to the hidden
    states before selected transformer blocks via `condition_layers`.
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.n_layer = config.n_layer
        self.block_size = config.block_size
        self.num_classes = config.num_classes
        self.model_type = config.model_type
        self.cls_token_num = config.cls_token_num
        # Spacing (in blocks) between condition injections. Clamped to >= 1 so
        # the modulo in forward() never divides by zero for very shallow models.
        self.layer_internal = max(1, config.n_layer // 3)
        self.adapter = Dinov2_Adapter(adapter_size=config.adapter_size, condition_type=config.condition_type)
        # Project adapter features (384 for "small", 768 for "base") to the model width.
        if config.adapter_size == "small":
            self.adapter_mlp = MLP(384, config.dim, config.dim)
        elif config.adapter_size == 'base':
            self.adapter_mlp = MLP(768, config.dim, config.dim)

        if self.model_type == 'c2i':
            self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
        elif self.model_type == 't2i':
            self.cls_embedding = CaptionEmbedder(config.caption_dim, config.dim, config.class_dropout_prob)
        else:
            raise Exception("please check model type")
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.tok_dropout = nn.Dropout(config.token_dropout_p)

        self.condition_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.condition_mlp = ConditionEmbedder(self.block_size, config.dim, config.class_dropout_prob, self.block_size, config.vocab_size)
        # One projection per injection point; forward() uses at most len(condition_layers) of them.
        self.condition_layers = torch.nn.ModuleList()
        for layer_id in range(3):
            self.condition_layers.append(MLP(config.dim,config.dim,config.dim))

        # transformer blocks
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
        self.layers = torch.nn.ModuleList()
        for layer_id in range(config.n_layer):
            self.layers.append(TransformerBlock(config, dpr[layer_id]))

        # output layer
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # 2d rotary pos embedding
        grid_size = int(self.block_size ** 0.5)
        assert grid_size * grid_size == self.block_size
        self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)

        # KVCache
        self.max_batch_size = -1
        self.max_seq_length = -1

        self.initialize_weights()
        # Condition tokens cached by forward() and reused across decoding steps.
        self.condition_token = None
        self.mask = get_causal_mask(256)
        self.global_token = None

    def initialize_weights(self):
        """Apply `_init_weights` to every submodule, then zero the output head."""
        # Initialize nn.Linear and nn.Embedding
        self.apply(self._init_weights)

        # Zero-out output layers:
        nn.init.constant_(self.output.weight, 0)

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding weights; zero biases."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def setup_caches(self, max_batch_size, max_seq_length, dtype):
        """
        Allocate per-layer KV caches, the batched causal mask and the rotary
        table. Everything is rebuilt unconditionally on every call.
        """
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype)

        causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
        self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
        grid_size = int(self.config.block_size ** 0.5)
        assert grid_size * grid_size == self.block_size
        self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)

    def forward(
        self,
        idx: torch.Tensor,
        cond_idx: torch.Tensor,  # cond_idx_or_embed
        input_pos: Optional[torch.Tensor] = None,
        targets: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        valid: Optional[torch.Tensor] = None,
        condition: Optional[torch.Tensor] = None
    ):
        """
        Run the model in one of three modes, selected by which inputs are given:

        * `idx` and `cond_idx`: training / naive inference over a full sequence.
        * only `cond_idx`: prefill step of cached inference.
        * only `idx`: single-token decoding step using the KV cache.

        Returns:
            `(logits, loss)`; `loss` is None unless `valid` or `targets` is given.
        """
        if idx is not None and cond_idx is not None:  # training or naive inference
            cls_out = self.cls_embedding(cond_idx, train=self.training)
            # The embedder returns (embeddings, drop_ids) only when dropout ran;
            # in eval mode or with dropout disabled it returns a bare tensor,
            # which must not be tuple-unpacked.
            if isinstance(cls_out, tuple):
                cond_embeddings, drop_ids = cls_out
            else:
                cond_embeddings, drop_ids = cls_out, None
            cond_embeddings = cond_embeddings[:,:self.cls_token_num]
            token_embeddings = self.tok_embeddings(idx)
            if condition is not None:
                condition_embeddings = self.adapter(condition)
                condition_embeddings = self.adapter_mlp(condition_embeddings)
                # Reuse the caption's drop mask so text and control are dropped together.
                self.condition_token = self.condition_mlp(condition_embeddings,train=self.training, drop_ids=drop_ids)
            token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)

            h = self.tok_dropout(token_embeddings)
            self.freqs_cis = self.freqs_cis.to(h.device)
        else:
            if cond_idx is not None:  # prefill in inference
                token_embeddings = self.cls_embedding(cond_idx, train=self.training)
                token_embeddings = token_embeddings[:,:self.cls_token_num]
                if condition is not None:
                    # Match the projection's own weight dtype instead of
                    # hard-coding bfloat16, so fp16/fp32 models also work.
                    proj_dtype = self.condition_mlp.cap_proj.fc1.weight.dtype
                    condition_embeddings = self.condition_mlp(condition.to(proj_dtype),train=self.training)
                    self.condition_token = condition_embeddings

            else:  # decode_n_tokens(kv cache) in inference
                token_embeddings = self.tok_embeddings(idx)
            bs = token_embeddings.shape[0]
            mask = self.causal_mask[:bs, None, input_pos]
            h = self.tok_dropout(token_embeddings)

        if self.training:
            freqs_cis = self.freqs_cis[:token_embeddings.shape[1]]
        else:
            freqs_cis = self.freqs_cis[input_pos]
        # transformer blocks
        num_injections = len(self.condition_layers)
        for i, layer in enumerate(self.layers):
            inject_idx = i // self.layer_internal
            # Inject control features every `layer_internal` blocks, but never
            # past the available projections: when n_layer is not a multiple of
            # 3 (e.g. 22 or 32) the index would otherwise run out of range.
            if i % self.layer_internal == 0 and inject_idx < num_injections:
                cond_proj = self.condition_layers[inject_idx]
                if self.training:
                    h[:, self.cls_token_num-1:] = h[:, self.cls_token_num-1:] + cond_proj(self.condition_token)
                else:
                    if len(input_pos)>1:
                        # Prefill: only the last position receives the first condition token.
                        h[:, -1:] = h[:, -1:] + cond_proj(self.condition_token[:,0:1])
                    else:
                        # Decode: pick the condition token aligned with the current position.
                        h = h + cond_proj(self.condition_token[:,input_pos-self.cls_token_num+1])
            h = layer(h, freqs_cis, input_pos, mask)
        # output layers
        h = self.norm(h)
        logits = self.output(h).float()

        if self.training:
            logits = logits[:, self.cls_token_num - 1:].contiguous()
        # if we are given some desired targets also calculate the loss
        loss = None
        if valid is not None:
            # Per-token loss, masked per sample by `valid`, averaged over valid tokens.
            loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
            valid_all = valid[:,None].repeat(1, targets.shape[1]).view(-1)
            loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
        elif targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
        """Return the transformer blocks as a list."""
        return list(self.layers)
481
+
482
+
483
+
484
+ #################################################################################
485
+ # Rotary Positional Embedding Functions #
486
+ #################################################################################
487
+ # https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
488
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
    """
    Build a 1D rotary-embedding table of (cos, sin) pairs.

    Returns a tensor of shape (cls_token_num + seq_len, n_elem // 2, 2). The
    first `cls_token_num` rows are all zeros; the rest hold cos/sin of
    position * inverse-frequency.
    """
    exponents = torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)  # (seq_len, n_elem // 2)
    unit = torch.polar(torch.ones_like(angles), angles)
    rope = torch.stack([unit.real, unit.imag], dim=-1)  # (seq_len, n_elem // 2, 2)
    cond_pad = torch.zeros(cls_token_num, n_elem // 2, 2)
    return torch.cat([cond_pad, rope])  # (cls_token_num+seq_len, n_elem // 2, 2)
496
+
497
+
498
def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
    """
    Build a 2D rotary-embedding table for a `grid_size` x `grid_size` grid.

    Half of the rotary pairs encode the row index and the other half the
    column index. Returns shape (cls_token_num + grid_size**2, n_elem // 2, 2);
    the first `cls_token_num` rows are all zeros.
    """
    # split the dimension into half, one for x and one for y
    half_dim = n_elem // 2
    exponents = torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim
    inv_freq = 1.0 / (base ** exponents)
    coords = torch.arange(grid_size, device=inv_freq.device)
    angles = torch.outer(coords, inv_freq)  # (grid_size, half_dim // 2)
    by_row = angles[:, None, :].expand(-1, grid_size, -1)
    by_col = angles[None, :, :].expand(grid_size, -1, -1)
    angle_grid = torch.concat([by_row, by_col], dim=-1)  # (grid_size, grid_size, n_elem // 2)
    rope_grid = torch.stack([torch.cos(angle_grid), torch.sin(angle_grid)], dim=-1)
    rope = rope_grid.flatten(0, 1)  # (grid_size**2, n_elem // 2, 2)
    cond_pad = torch.zeros(cls_token_num, n_elem // 2, 2)
    return torch.cat([cond_pad, rope])  # (cls_token_num+grid_size**2, n_elem // 2, 2)
512
+
513
+
514
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
    """
    Rotate consecutive channel pairs of `x` by the (cos, sin) pairs in `freqs_cis`.

    x: (bs, seq_len, n_head, head_dim)
    freqs_cis: (seq_len, head_dim // 2, 2)
    The rotation is computed in float32 and cast back to x's dtype.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)  # (bs, seq_len, n_head, head_dim//2, 2)
    table = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)  # (1, seq_len, 1, head_dim//2, 2)
    x_even, x_odd = pairs[..., 0], pairs[..., 1]
    cos, sin = table[..., 0], table[..., 1]
    # Standard 2D rotation applied to each (even, odd) channel pair.
    rot_even = x_even * cos - x_odd * sin
    rot_odd = x_odd * cos + x_even * sin
    rotated = torch.stack([rot_even, rot_odd], dim=-1).flatten(3)
    return rotated.type_as(x)
525
+
526
+
527
+
528
+ #################################################################################
529
+ # GPT Configs #
530
+ #################################################################################
531
+ ### text-conditional
532
def GPT_7B(**kwargs):
    """Text-conditional preset: 32 layers, 32 heads, width 4096 (6.6B)."""
    config = ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)
    return Transformer(config)


def GPT_3B(**kwargs):
    """Text-conditional preset: 24 layers, 32 heads, width 3200 (3.1B)."""
    config = ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)
    return Transformer(config)


def GPT_1B(**kwargs):
    """Text-conditional preset: 22 layers, 32 heads, width 2048 (1.2B)."""
    config = ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)
    return Transformer(config)


### class-conditional
def GPT_XXXL(**kwargs):
    """Class-conditional preset: 48 layers, 40 heads, width 2560 (3.9B)."""
    config = ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)
    return Transformer(config)


def GPT_XXL(**kwargs):
    """Class-conditional preset: 48 layers, 24 heads, width 1536 (1.4B)."""
    config = ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)
    return Transformer(config)


def GPT_XL(**kwargs):
    """Class-conditional preset: 36 layers, 20 heads, width 1280 (775M)."""
    config = ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)
    return Transformer(config)


def GPT_L(**kwargs):
    """Class-conditional preset: 24 layers, 16 heads, width 1024 (343M)."""
    config = ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)
    return Transformer(config)


def GPT_B(**kwargs):
    """Class-conditional preset: 12 layers, 12 heads, width 768 (111M)."""
    config = ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)
    return Transformer(config)


# Registry mapping model names to their factory functions.
GPT_models = {
    'GPT-B': GPT_B,
    'GPT-L': GPT_L,
    'GPT-XL': GPT_XL,
    'GPT-XXL': GPT_XXL,
    'GPT-XXXL': GPT_XXXL,
    'GPT-1B': GPT_1B,
    'GPT-3B': GPT_3B,
    'GPT-7B': GPT_7B,
}
autoregressive/sample/sample_c2i.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/sample.py
3
+ import torch
4
+ torch.backends.cuda.matmul.allow_tf32 = True
5
+ torch.backends.cudnn.allow_tf32 = True
6
+ torch.set_float32_matmul_precision('high')
7
+ setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
8
+ setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
9
+ from torchvision.utils import save_image
10
+ import os
11
+ import sys
12
+ current_directory = os.getcwd()
13
+ sys.path.append(current_directory)
14
+
15
+ from PIL import Image
16
+ import time
17
+ import argparse
18
+ from tokenizer.tokenizer_image.vq_model import VQ_models
19
+ from autoregressive.models.gpt import GPT_models
20
+ from autoregressive.models.generate import generate
21
+ from functools import partial
22
+ import torch.nn.functional as F
23
+ import numpy as np
24
+ import cv2
25
+
26
+
27
def main(args):
    """
    Sample class-conditional images guided by canny or depth control images.

    Loads the VQ tokenizer and the GPT model, reads four example condition
    images with their class labels from `condition/example/c2i/<type>/`,
    generates image tokens, decodes them, and saves a grid (conditions on the
    first row, samples on the second) under `sample/example/`.

    Raises:
        ValueError: if `args.condition_type` has no example list.
        Exception: if the GPT checkpoint has no recognised weight key.
    """
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_grad_enabled(False)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint
    print("image tokenizer is loaded")

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        vocab_size=args.codebook_size,
        block_size=latent_size ** 2,
        num_classes=args.num_classes,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
        condition_token_num=args.condition_token_nums,
        image_size=args.image_size
    ).to(device=device, dtype=precision)

    _, file_extension = os.path.splitext(args.gpt_ckpt)
    if file_extension.lower() == '.safetensors':
        from safetensors.torch import load_file
        model_weight = load_file(args.gpt_ckpt)
    else:
        checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
        if "model" in checkpoint:  # ddp
            model_weight = checkpoint["model"]
        elif "module" in checkpoint:  # deepspeed
            model_weight = checkpoint["module"]
        elif "state_dict" in checkpoint:
            model_weight = checkpoint["state_dict"]
        else:
            raise Exception("please check model weight")
        del checkpoint
    # Both checkpoint formats are loaded non-strictly, so missing or extra keys are tolerated.
    gpt_model.load_state_dict(model_weight, strict=False)
    gpt_model.eval()
    print("gpt model is loaded")

    if args.compile:
        print("compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        )  # requires PyTorch 2.0 (optional)
    else:
        print("no need to compile model in demo")

    condition_null = None
    if args.condition_type == 'canny':
        sample_list = [650, 2312, 15000, 48850]  # canny
    elif args.condition_type == 'depth':
        sample_list = [101, 4351, 10601, 48901]
    else:
        # Fail with a clear message rather than an unbound-variable error below.
        raise ValueError(f"unsupported condition type: {args.condition_type}")

    class_labels = [np.load(f"condition/example/c2i/{args.condition_type}/{i}.npy")[0] for i in sample_list]
    condition_imgs = [np.array(Image.open((f"condition/example/c2i/{args.condition_type}/{i}.png")))[None,None,...] for i in sample_list]
    condition_imgs = torch.from_numpy(np.concatenate(condition_imgs, axis=0)).to(device).to(torch.float32)/255
    # Rescale from [0, 1] to [-1, 1].
    condition_imgs = 2*(condition_imgs-0.5)
    print(condition_imgs.shape)
    c_indices = torch.tensor(class_labels, device=device)
    qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
    t1 = time.time()

    index_sample = generate(
        gpt_model, c_indices, latent_size ** 2, condition=condition_imgs.repeat(1,3,1,1).to(precision), condition_null=condition_null, condition_token_nums=args.condition_token_nums,
        cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval,
        temperature=args.temperature, top_k=args.top_k,
        top_p=args.top_p, sample_logits=True,
    )

    sampling_time = time.time() - t1
    print(f"gpt sampling takes about {sampling_time:.2f} seconds.")

    t2 = time.time()
    samples = vq_model.decode_code(index_sample, qzshape)  # output value is between [-1, 1]
    decoder_time = time.time() - t2
    print(f"decoder takes about {decoder_time:.2f} seconds.")
    # Save and display images:
    condition_imgs = condition_imgs.repeat(1,3,1,1)
    samples = torch.cat((condition_imgs[:4], samples[:4]),dim=0)
    out_path = f"sample/example/sample_{args.gpt_type}_{args.condition_type}.png"
    # save_image does not create parent directories, so make sure the target exists.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    save_image(samples, out_path, nrow=4, normalize=True, value_range=(-1, 1))
123
+
124
+
125
+
126
# Command-line entry point: declare the sampling options, parse them, and run main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # --- GPT model selection and checkpoint ---
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
    # NOTE(review): --from-fsdp is parsed here, but main() in this script does not read it.
    parser.add_argument("--from-fsdp", action='store_true')
    parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--compile", action='store_true', default=False)
    # --- VQ tokenizer ---
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    # image-size // downsample-size gives the latent grid side used by main().
    parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    parser.add_argument("--num-classes", type=int, default=1000)
    # --- Sampling controls ---
    parser.add_argument("--cfg-scale", type=float, default=4.0)
    parser.add_argument("--cfg-interval", type=float, default=-1)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
    # --- Control condition ---
    parser.add_argument("--condition-token-nums", type=int, default=0)
    parser.add_argument("--condition-type", type=str, default='canny', choices=['canny', 'depth'])
    args = parser.parse_args()
    main(args)
autoregressive/sample/sample_c2i_ddp.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py
3
+ import torch
4
+ torch.backends.cuda.matmul.allow_tf32 = True
5
+ torch.backends.cudnn.allow_tf32 = True
6
+ import torch.nn.functional as F
7
+ import torch.distributed as dist
8
+
9
+ from tqdm import tqdm
10
+ import os
11
+ from PIL import Image
12
+ import numpy as np
13
+ import math
14
+ import argparse
15
+
16
+ from tokenizer.tokenizer_image.vq_model import VQ_models
17
+ from autoregressive.models.gpt import GPT_models
18
+ from autoregressive.models.generate import generate
19
+
20
+
21
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.

    Reads `{sample_dir}/000000.png` through `{sample_dir}/{num-1:06d}.png`,
    stacks them into one uint8 array stored under the key `arr_0`, and writes
    it to `{sample_dir}.npz`.

    Returns:
        The path of the written .npz file.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        # Close each image file as soon as its pixels are copied, so a large
        # `num` does not pile up open file handles.
        with Image.open(f"{sample_dir}/{i:06d}.png") as sample_pil:
            sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    # Expect `num` three-channel images of identical height and width.
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
36
+
37
+
38
def main(args):
    """
    Sample class-conditional images across all ranks of a distributed job.

    Each rank generates its share of images and writes them as numbered .png
    files into a shared folder; rank 0 then packs the first
    `args.num_fid_samples` images into a single .npz file.
    """
    # Setup PyTorch:
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # Give every rank its own seed derived from the global seed.
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        vocab_size=args.codebook_size,
        block_size=latent_size ** 2,
        num_classes=args.num_classes,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
    ).to(device=device, dtype=precision)
    checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
    # With --from-fsdp the loaded object is used directly as the state dict;
    # otherwise the weights are looked up under a known key.
    if args.from_fsdp: # fsdp
        model_weight = checkpoint
    elif "model" in checkpoint:  # ddp
        model_weight = checkpoint["model"]
    elif "module" in checkpoint: # deepspeed
        model_weight = checkpoint["module"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight, maybe add --from-fsdp to run command")
    # if 'freqs_cis' in model_weight:
    #     model_weight.pop('freqs_cis')
    # strict=False: missing or unexpected keys do not raise.
    gpt_model.load_state_dict(model_weight, strict=False)
    gpt_model.eval()
    del checkpoint

    if args.compile:
        print(f"compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        ) # requires PyTorch 2.0 (optional)
    else:
        print(f"no model compile")

    # Create folder to save samples:
    model_string_name = args.gpt_model.replace("/", "-")
    if args.from_fsdp:
        # For FSDP checkpoints the run name is taken from the parent directory of the checkpoint path.
        ckpt_string_name = args.gpt_ckpt.split('/')[-2]
    else:
        ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "")
    folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-size-{args.image_size_eval}-{args.vq_model}-" \
                  f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    # All ranks wait here until rank 0 has created the output folder.
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    # Only rank 0 shows a progress bar.
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0
    for _ in pbar:
        # Sample inputs:
        c_indices = torch.randint(0, args.num_classes, (n,), device=device)
        qzshape = [len(c_indices), args.codebook_embed_dim, latent_size, latent_size]

        index_sample = generate(
            gpt_model, c_indices, latent_size ** 2,
            cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval,
            temperature=args.temperature, top_k=args.top_k,
            top_p=args.top_p, sample_logits=True,
            )

        samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
        if args.image_size_eval != args.image_size:
            samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
        # Map [-1, 1] to [0, 255] and reorder to channels-last uint8 for PIL.
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files
        for i, sample in enumerate(samples):
            # Interleave indices across ranks so every file name is globally unique.
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
158
+
159
+
160
+
161
# Command-line entry point: declare the distributed sampling options, parse them, and run main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # --- GPT model selection and checkpoint ---
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
    parser.add_argument("--from-fsdp", action='store_true')
    parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--compile", action='store_true', default=True)
    # `--compile` is store_true with default=True, so on its own it could never
    # be switched off. `--no-compile` writes False to the same destination.
    parser.add_argument("--no-compile", dest="compile", action='store_false', help="disable torch.compile")
    # --- VQ tokenizer ---
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    # Images are generated at --image-size and resized to --image-size-eval when the two differ.
    parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
    parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    parser.add_argument("--num-classes", type=int, default=1000)
    # --- Sampling controls ---
    parser.add_argument("--cfg-scale",  type=float, default=1.5)
    parser.add_argument("--cfg-interval", type=float, default=-1)
    parser.add_argument("--sample-dir", type=str, default="samples")
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=5000)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--top-k", type=int, default=0,help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
    args = parser.parse_args()
    main(args)
autoregressive/sample/sample_t2i.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ torch.backends.cuda.matmul.allow_tf32 = True
3
+ torch.backends.cudnn.allow_tf32 = True
4
+ torch.set_float32_matmul_precision('high')
5
+ setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
6
+ setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
7
+ from torchvision.utils import save_image
8
+
9
+ import os
10
+ import sys
11
+ current_directory = os.getcwd()
12
+ sys.path.append(current_directory)
13
+ import time
14
+ import argparse
15
+ from tokenizer.tokenizer_image.vq_model import VQ_models
16
+ from language.t5 import T5Embedder
17
+ from autoregressive.models.gpt import GPT_models
18
+ from autoregressive.models.gpt_t2i import GPT_models
19
+ from autoregressive.models.generate import generate
20
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
+ from dataset.t2i_control import build_t2i_control_code
22
+ from accelerate import Accelerator
23
+ from dataset.build import build_dataset
24
+ from pathlib import Path
25
+ from accelerate.utils import ProjectConfiguration, set_seed
26
+ import torch.nn.functional as F
27
+ from condition.canny import CannyDetector
28
+ from condition.hed import HEDdetector
29
+ import numpy as np
30
+ from PIL import Image
31
+ from condition.lineart import LineArt
32
+ import cv2
33
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
34
def main(args):
    """Sample images from a text prompt plus one spatial control image.

    Pipeline, in order: load the VQ image tokenizer and the GPT model from
    ``args.vq_ckpt`` / ``args.gpt_ckpt``, build the T5 text embedder, turn the
    image at ``args.condition_path`` into a control tensor according to
    ``args.condition_type``, run ``generate`` and decode the sampled codes
    with the VQ model. The control image and the decoded samples are written
    side by side to ``sample/example/sample_t2i_<condition_type>.png``.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block of
            this script for the available options).

    Raises:
        Exception: if a non-safetensors GPT checkpoint contains none of the
            keys ``"model"``, ``"module"`` or ``"state_dict"``.
    """
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint
    print(f"image tokenizer is loaded")

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        block_size=latent_size ** 2,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
        condition_type=args.condition_type,
    ).to(device=device, dtype=precision)

    # Pick the weight dict out of the checkpoint; only the way it is read
    # differs between formats, so loading into the model happens once below.
    _, file_extension = os.path.splitext(args.gpt_ckpt)
    if file_extension.lower() == '.safetensors':
        from safetensors.torch import load_file
        model_weight = load_file(args.gpt_ckpt)
    else:
        checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
        if "model" in checkpoint:  # ddp
            model_weight = checkpoint["model"]
        elif "module" in checkpoint:  # deepspeed
            model_weight = checkpoint["module"]
        elif "state_dict" in checkpoint:
            model_weight = checkpoint["state_dict"]
        else:
            raise Exception("please check model weight")
        del checkpoint
    # strict=False: missing/unexpected keys are tolerated, not reported.
    gpt_model.load_state_dict(model_weight, strict=False)
    gpt_model.eval()
    print(f"gpt model is loaded")

    if args.compile:
        print(f"compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        )  # requires PyTorch 2.0 (optional)
    else:
        print(f"no need to compile model in demo")

    assert os.path.exists(args.t5_path)
    t5_model = T5Embedder(
        device=device,
        local_cache=True,
        cache_dir=args.t5_path,
        dir_or_name=args.t5_model_type,
        torch_dtype=precision,
        model_max_length=args.t5_feature_max_len,
    )

    # Build the condition extractor; 'seg' uses the input image as-is and
    # therefore needs none.
    if args.condition_type == 'canny':
        get_control = CannyDetector()
    elif args.condition_type == 'hed':
        get_control = HEDdetector().to(device).eval()
    elif args.condition_type == 'lineart':
        get_control = LineArt()
        get_control.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu')))
        get_control.to(device)
    elif args.condition_type == 'depth':
        processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large")
        model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large").to(device)

    # Create the output directory up front so a missing folder is not
    # discovered only after the (slow) sampling has finished.
    output_path = f"sample/example/sample_t2i_{args.condition_type}.png"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with torch.no_grad():

        condition_path = args.condition_path
        # Every branch ends with a tensor repeated to batch size 2, matching
        # the two prompts built below.
        # NOTE(review): the image is not resized here, while the latent grid
        # below is derived from --image-H/--image-W; presumably the condition
        # image must already have that resolution -- confirm.
        if args.condition_type == 'seg':
            condition_img = torch.from_numpy(np.array(Image.open(condition_path)))
            condition_img = condition_img.permute(2,0,1).unsqueeze(0).repeat(2,1,1,1)
        elif args.condition_type == 'canny':
            condition_img = get_control(np.array(Image.open(condition_path)))
            condition_img = torch.from_numpy(condition_img[None,None,...]).repeat(2,3,1,1)
        elif args.condition_type == 'hed':
            condition_img = get_control(torch.from_numpy(np.array(Image.open(condition_path))).permute(2,0,1).unsqueeze(0).to(device))
            condition_img = condition_img.unsqueeze(1).repeat(2,3,1,1)
        elif args.condition_type == 'lineart':
            condition_img = get_control(torch.from_numpy(np.array(Image.open(condition_path))).permute(2,0,1).unsqueeze(0).to(device).float())
            condition_img = condition_img.repeat(2,3,1,1) * 255
        elif args.condition_type == 'depth':
            images = Image.open(condition_path)
            inputs = processor(images=images, return_tensors="pt", size=(512,512)).to(device)
            outputs = model(**inputs)
            condition_img = outputs.predicted_depth
            condition_img = condition_img.unsqueeze(0).repeat(2,3,1,1)
            # Rescale so the largest depth value becomes 255.
            condition_img = (condition_img * 255 / condition_img.max())
        condition_img = condition_img.to(device)
        # Map values from the 0..255 scale to -1..1.
        condition_img = 2*(condition_img/255 - 0.5)
        prompts = [args.prompt if args.prompt is not None else "a high-quality image"]
        prompts = prompts * 2
        caption_embs, emb_masks = t5_model.get_text_embeddings(prompts)

        if not args.no_left_padding:
            print(f"processing left-padding...")
            # a naive way to implement left-padding: mirror the mask and
            # rotate each embedding so its first `valid_num` rows end up last
            # (assumes the embedder pads on the right -- TODO confirm).
            new_emb_masks = torch.flip(emb_masks, dims=[-1])
            new_caption_embs = []
            for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
                valid_num = int(emb_mask.sum().item())
                print(f'  prompt {idx} token len: {valid_num}')
                new_caption_emb = torch.cat([caption_emb[valid_num:],caption_emb[:valid_num]])
                new_caption_embs.append(new_caption_emb)
            new_caption_embs = torch.stack(new_caption_embs)
        else:
            new_caption_embs, new_emb_masks = caption_embs, emb_masks
        # Zero out the embeddings at masked (padding) positions.
        c_indices = new_caption_embs * new_emb_masks[:,:, None]
        c_emb_masks = new_emb_masks
        qzshape = [len(c_indices), args.codebook_embed_dim, args.image_H//args.downsample_size, args.image_W//args.downsample_size]
        t1 = time.time()
        index_sample = generate(
            gpt_model, c_indices, (args.image_H//args.downsample_size)*(args.image_W//args.downsample_size),#latent_size ** 2,
            c_emb_masks, condition=condition_img.to(precision),
            cfg_scale=args.cfg_scale,
            temperature=args.temperature, top_k=args.top_k,
            top_p=args.top_p, sample_logits=True,
        )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

        t2 = time.time()
        print(index_sample.shape)
        samples = vq_model.decode_code(index_sample, qzshape)  # output value is between [-1, 1]
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")

        # Put the control image first so it is saved next to the samples.
        samples = torch.cat((condition_img[0:1], samples), dim=0)
        save_image(samples, output_path, nrow=4, normalize=True, value_range=(-1, 1))
        print(f"image is saved to {output_path}")
        print(prompts)
181
+
182
+
183
# Command-line entry point: declare the sampling options and run main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # T5 text encoder: cache directory, model name and max token length.
    parser.add_argument("--t5-path", type=str, default='checkpoints/t5-ckpt')
    parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
    parser.add_argument("--t5-feature-max-len", type=int, default=120)
    # NOTE(review): --t5-feature-dim is not read by main() in this script.
    parser.add_argument("--t5-feature-dim", type=int, default=2048)
    parser.add_argument("--no-left-padding", action='store_true', default=False)
    # GPT model selection and checkpoint.
    # NOTE(review): --gpt-ckpt and --vq-ckpt default to None, but main() passes
    # them straight to os.path.splitext / torch.load, so both must be supplied.
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image")
    parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--compile", action='store_true', default=False)
    # VQ image tokenizer.
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    # In main(), --image-size only sets the GPT block_size; the number of
    # sampled tokens and the decoded grid come from --image-H / --image-W.
    parser.add_argument("--image-size", type=int, choices=[256, 320, 384, 400, 448, 512, 576, 640, 704, 768], default=768)
    parser.add_argument("--image-H", type=int, default=512)
    parser.add_argument("--image-W", type=int, default=512)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    # Sampling hyper-parameters forwarded to generate().
    parser.add_argument("--cfg-scale", type=float, default=4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")

    # NOTE(review): --mixed-precision is not read by main(); model dtype is
    # taken from --precision.
    parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    # Control signal: which condition to derive and from which image.
    parser.add_argument("--condition-type", type=str, choices=['seg', 'canny', 'hed', 'lineart', 'depth'], default="canny")
    parser.add_argument("--prompt", type=str, default='a high-quality image')
    parser.add_argument("--condition-path", type=str, default='condition/example/t2i/multigen/landscape.png')
    args = parser.parse_args()
    main(args)
autoregressive/sample/sample_t2i_MR.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ torch.backends.cuda.matmul.allow_tf32 = True
3
+ torch.backends.cudnn.allow_tf32 = True
4
+ torch.set_float32_matmul_precision('high')
5
+ setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
6
+ setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
7
+ from torchvision.utils import save_image
8
+
9
+ import os
10
+ import sys
11
+ current_directory = os.getcwd()
12
+ sys.path.append(current_directory)
13
+ import time
14
+ import argparse
15
+ from tokenizer.tokenizer_image.vq_model import VQ_models
16
+ from language.t5 import T5Embedder
17
+ from autoregressive.models.gpt_t2i import GPT_models
18
+ from autoregressive.models.generate import generate
19
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
20
+ from dataset.t2i_control import build_t2i_control_code
21
+ from accelerate import Accelerator
22
+ from dataset.build import build_dataset
23
+ from pathlib import Path
24
+ from accelerate.utils import ProjectConfiguration, set_seed
25
+ import torch.nn.functional as F
26
+ from condition.canny import CannyDetector
27
+ from condition.hed import HEDdetector
28
+ import numpy as np
29
+ from PIL import Image
30
+ from condition.lineart import LineArt
31
+ import cv2
32
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
33
+ from condition.midas.depth import MidasDetector
34
+
35
+
36
def resize_image_to_16_multiple(image_path, condition_type='seg'):
    """Open an image and resize it so both sides are a multiple of 16.

    Each side is rounded *up* to the next multiple; for
    ``condition_type == 'depth'`` the multiple is 32 instead of 16.

    Args:
        image_path: path handed to ``Image.open``.
        condition_type: condition name; only ``'depth'`` changes the multiple.

    Returns:
        The resized image returned by ``Image.resize``.
    """
    source = Image.open(image_path)

    # The depth model requires a side length that is a multiple of 32
    multiple = 32 if condition_type == 'depth' else 16

    # Image.size is (width, height); round each side up independently.
    target_size = tuple(
        (side + multiple - 1) // multiple * multiple for side in source.size
    )
    return source.resize(target_size)
49
+
50
def main(args):
    """Sample images at the (rounded) resolution of the control image.

    Pipeline, in order: load the VQ image tokenizer and the GPT model from
    ``args.vq_ckpt`` / ``args.gpt_ckpt``, build the T5 text embedder, resize
    the image at ``args.condition_path`` with ``resize_image_to_16_multiple``
    and turn it into a control tensor according to ``args.condition_type``,
    run ``generate`` and decode the sampled codes with the VQ model. The
    latent grid is derived from the resized control image, not from
    ``--image-H`` / ``--image-W``. The control image and the decoded samples
    are written to ``sample/example/sample_t2i_MR_<condition_type>.png``.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block of
            this script for the available options).

    Raises:
        Exception: if a non-safetensors GPT checkpoint contains none of the
            keys ``"model"``, ``"module"`` or ``"state_dict"``.
    """
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint
    print(f"image tokenizer is loaded")

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        block_size=latent_size ** 2,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
        condition_type=args.condition_type,
    ).to(device=device, dtype=precision)

    # Pick the weight dict out of the checkpoint; only the way it is read
    # differs between formats, so loading into the model happens once below.
    _, file_extension = os.path.splitext(args.gpt_ckpt)
    if file_extension.lower() == '.safetensors':
        from safetensors.torch import load_file
        model_weight = load_file(args.gpt_ckpt)
    else:
        checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
        if "model" in checkpoint:  # ddp
            model_weight = checkpoint["model"]
        elif "module" in checkpoint:  # deepspeed
            model_weight = checkpoint["module"]
        elif "state_dict" in checkpoint:
            model_weight = checkpoint["state_dict"]
        else:
            raise Exception("please check model weight")
        del checkpoint
    # strict=False: missing/unexpected keys are tolerated, not reported.
    gpt_model.load_state_dict(model_weight, strict=False)
    gpt_model.eval()
    print(f"gpt model is loaded")

    if args.compile:
        print(f"compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        )  # requires PyTorch 2.0 (optional)
    else:
        print(f"no need to compile model in demo")

    assert os.path.exists(args.t5_path)
    t5_model = T5Embedder(
        device=device,
        local_cache=True,
        cache_dir=args.t5_path,
        dir_or_name=args.t5_model_type,
        torch_dtype=precision,
        model_max_length=args.t5_feature_max_len,
    )

    # Build the condition extractor; 'seg' uses the input image as-is and
    # therefore needs none.
    if args.condition_type == 'canny':
        get_control = CannyDetector()
    elif args.condition_type == 'hed':
        get_control = HEDdetector().to(device).eval()
    elif args.condition_type == 'lineart':
        get_control = LineArt()
        get_control.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu')))
        get_control.to(device)
    elif args.condition_type == 'depth':
        # Two depth estimators: the DPT model is used for square inputs, the
        # Midas detector for everything else (see the depth branch below).
        processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large")
        model_large = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large").to(device)
        model = MidasDetector(device=device)

    # Create the output directory up front so a missing folder is not
    # discovered only after the (slow) sampling has finished.
    output_path = f"sample/example/sample_t2i_MR_{args.condition_type}.png"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with torch.no_grad():

        condition_img = resize_image_to_16_multiple(args.condition_path, args.condition_type)
        W, H = condition_img.size
        print(H,W)
        # Every branch ends with a tensor repeated to batch size 2, matching
        # the two prompts built below.
        if args.condition_type == 'seg':
            condition_img = torch.from_numpy(np.array(condition_img))
            condition_img = condition_img.permute(2,0,1).unsqueeze(0).repeat(2,1,1,1)
        elif args.condition_type == 'canny':
            condition_img = get_control(np.array(condition_img))
            condition_img = torch.from_numpy(condition_img[None,None,...]).repeat(2,3,1,1)
        elif args.condition_type == 'hed':
            condition_img = get_control(torch.from_numpy(np.array(condition_img)).permute(2,0,1).unsqueeze(0).to(device))
            condition_img = condition_img.unsqueeze(1).repeat(2,3,1,1)
        elif args.condition_type == 'lineart':
            condition_img = get_control(torch.from_numpy(np.array(condition_img)).permute(2,0,1).unsqueeze(0).to(device).float())
            condition_img = condition_img.repeat(2,3,1,1) * 255
        elif args.condition_type == 'depth':
            images = condition_img
            if H == W:
                inputs = processor(images=images, return_tensors="pt", size=(H,W)).to(device)
                outputs = model_large(**inputs)
                condition_img = outputs.predicted_depth
                # Rescale so the largest depth value becomes 255.
                condition_img = (condition_img * 255 / condition_img.max())
            else:
                condition_img = torch.from_numpy(model(torch.from_numpy(np.array(condition_img)).to(device))).unsqueeze(0)
            condition_img = condition_img.unsqueeze(0).repeat(2,3,1,1)
        condition_img = condition_img.to(device)
        # Map values from the 0..255 scale to -1..1.
        condition_img = 2*(condition_img/255 - 0.5)
        prompts = [args.prompt if args.prompt is not None else "a high-quality image"]
        prompts = prompts * 2
        caption_embs, emb_masks = t5_model.get_text_embeddings(prompts)

        if not args.no_left_padding:
            print(f"processing left-padding...")
            # a naive way to implement left-padding: mirror the mask and
            # rotate each embedding so its first `valid_num` rows end up last
            # (assumes the embedder pads on the right -- TODO confirm).
            new_emb_masks = torch.flip(emb_masks, dims=[-1])
            new_caption_embs = []
            for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
                valid_num = int(emb_mask.sum().item())
                print(f'  prompt {idx} token len: {valid_num}')
                new_caption_emb = torch.cat([caption_emb[valid_num:],caption_emb[:valid_num]])
                new_caption_embs.append(new_caption_emb)
            new_caption_embs = torch.stack(new_caption_embs)
        else:
            new_caption_embs, new_emb_masks = caption_embs, emb_masks
        # Zero out the embeddings at masked (padding) positions.
        c_indices = new_caption_embs * new_emb_masks[:,:, None]
        c_emb_masks = new_emb_masks
        qzshape = [len(c_indices), args.codebook_embed_dim, H//args.downsample_size, W//args.downsample_size]
        t1 = time.time()
        index_sample = generate(
            gpt_model, c_indices, (H//args.downsample_size)*(W//args.downsample_size),#latent_size ** 2,
            c_emb_masks, condition=condition_img.to(precision),
            cfg_scale=args.cfg_scale,
            temperature=args.temperature, top_k=args.top_k,
            top_p=args.top_p, sample_logits=True,
        )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

        t2 = time.time()
        print(index_sample.shape)
        samples = vq_model.decode_code(index_sample, qzshape)  # output value is between [-1, 1]
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")

        # Put the control image first so it is saved next to the samples.
        samples = torch.cat((condition_img[0:1], samples), dim=0)
        save_image(samples, output_path, nrow=4, normalize=True, value_range=(-1, 1))
        print(f"image is saved to {output_path}")
        print(prompts)
203
+
204
+
205
# Command-line entry point: declare the sampling options and run main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # T5 text encoder: cache directory, model name and max token length.
    parser.add_argument("--t5-path", type=str, default='checkpoints/t5-ckpt')
    parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
    parser.add_argument("--t5-feature-max-len", type=int, default=120)
    # NOTE(review): --t5-feature-dim is not read by main() in this script.
    parser.add_argument("--t5-feature-dim", type=int, default=2048)
    parser.add_argument("--no-left-padding", action='store_true', default=False)
    # GPT model selection and checkpoint.
    # NOTE(review): --gpt-ckpt and --vq-ckpt default to None, but main() passes
    # them straight to os.path.splitext / torch.load, so both must be supplied.
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image")
    parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--compile", action='store_true', default=False)
    # VQ image tokenizer.
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    # In main(), --image-size only sets the GPT block_size.
    parser.add_argument("--image-size", type=int, choices=[256, 320, 384, 400, 448, 512, 576, 640, 704, 768], default=768)
    # NOTE(review): --image-H / --image-W are not read by main() in this
    # script; the output size follows the resized condition image instead.
    parser.add_argument("--image-H", type=int, default=512)
    parser.add_argument("--image-W", type=int, default=512)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    # Sampling hyper-parameters forwarded to generate().
    parser.add_argument("--cfg-scale", type=float, default=4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")

    # NOTE(review): --mixed-precision is not read by main(); model dtype is
    # taken from --precision.
    parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    # Control signal: which condition to derive and from which image.
    parser.add_argument("--condition-type", type=str, choices=['seg', 'canny', 'hed', 'lineart', 'depth'], default="canny")
    parser.add_argument("--prompt", type=str, default='a high-quality image')
    parser.add_argument("--condition-path", type=str, default='condition/example/t2i/multigen/landscape.png')
    args = parser.parse_args()
    main(args)
autoregressive/sample/sample_t2i_ddp.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ torch.backends.cuda.matmul.allow_tf32 = True
3
+ torch.backends.cudnn.allow_tf32 = True
4
+ torch.set_float32_matmul_precision('high')
5
+ setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
6
+ setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
7
+ import torch.nn.functional as F
8
+ import torch.distributed as dist
9
+
10
+ import os
11
+ import math
12
+ import json
13
+ import argparse
14
+ import pandas as pd
15
+ from tqdm import tqdm
16
+ from PIL import Image
17
+
18
+ from tokenizer.tokenizer_image.vq_model import VQ_models
19
+ from language.t5 import T5Embedder
20
+ from autoregressive.models.gpt import GPT_models
21
+ from autoregressive.models.generate import generate
22
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
+
24
+
25
+
26
def main(args):
    """Sample one image per prompt across all ranks of a distributed job.

    Loads the VQ tokenizer, the GPT model and the T5 embedder, reads prompts
    from the ``Prompt`` column of the tab-separated file ``args.prompt_csv``,
    and lets every rank generate and save its share of images as
    ``<sample_dir>/<folder_name>/images/<index>.png``. Rank 0 additionally
    writes ``result.jsonl`` and ``captions.txt`` into the same folder.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block of
            this script for the available options).

    Raises:
        AssertionError: if CUDA is unavailable, ``args.t5_path`` does not
            exist, or the sample counts are not evenly divisible.
        Exception: if the GPT checkpoint contains none of the keys
            ``"model"``, ``"module"`` or ``"state_dict"``.
    """
    # Setup PyTorch:
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # Each rank gets a distinct seed derived from the global seed.
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint
    print(f"image tokenizer is loaded")

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        block_size=latent_size ** 2,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
    ).to(device=device, dtype=precision)

    checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")

    if "model" in checkpoint:  # ddp
        model_weight = checkpoint["model"]
    elif "module" in checkpoint:  # deepspeed
        model_weight = checkpoint["module"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight")
    # strict=False: missing/unexpected keys are tolerated, not reported.
    gpt_model.load_state_dict(model_weight, strict=False)
    gpt_model.eval()
    del checkpoint
    print(f"gpt model is loaded")

    if args.compile:
        print(f"compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        )  # requires PyTorch 2.0 (optional)
    else:
        print(f"no need to compile model in demo")

    assert os.path.exists(args.t5_path)
    t5_model = T5Embedder(
        device=device,
        local_cache=True,
        cache_dir=args.t5_path,
        dir_or_name=args.t5_model_type,
        torch_dtype=precision,
        model_max_length=args.t5_feature_max_len,
    )
    print(f"t5 model is loaded")

    # Create folder to save samples:
    model_string_name = args.gpt_model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "")
    prompt_name = args.prompt_csv.split('/')[-1].split('.')[0].lower()
    # NOTE(review): "size-{image_size}" appears twice in the folder name;
    # looks unintentional, but renaming would change existing output paths.
    folder_name = f"{model_string_name}-{ckpt_string_name}-{prompt_name}-size-{args.image_size}-size-{args.image_size}-{args.vq_model}-" \
                  f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(f"{sample_folder_dir}/images", exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}/images")
    # Other ranks wait here until rank 0 has created the folder.
    dist.barrier()

    df = pd.read_csv(args.prompt_csv, delimiter='\t')
    prompt_list = df['Prompt'].tolist()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    num_fid_samples = min(args.num_fid_samples, len(prompt_list))
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    # `total` is the number of images covered by all ranks in earlier iterations.
    total = 0
    for _ in pbar:
        # Select text prompt
        prompt_batch = []
        for i in range(n):
            # Global indices are interleaved across ranks within an iteration.
            index = i * dist.get_world_size() + rank + total
            # Indices past the end of the prompt list get a placeholder prompt.
            prompt_batch.append(prompt_list[index] if index < len(prompt_list) else "a cute dog")

        # Sample inputs:
        caption_embs, emb_masks = t5_model.get_text_embeddings(prompt_batch)

        if not args.no_left_padding:
            # Mirror the mask and rotate each embedding so its first
            # `valid_num` rows end up last (assumes the embedder pads on the
            # right -- TODO confirm).
            new_emb_masks = torch.flip(emb_masks, dims=[-1])
            new_caption_embs = []
            for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
                valid_num = int(emb_mask.sum().item())
                # prompt_cur = prompt_batch[idx]
                # print(f'  prompt {idx} token len: {valid_num} : {prompt_cur}')
                new_caption_emb = torch.cat([caption_emb[valid_num:], caption_emb[:valid_num]])
                new_caption_embs.append(new_caption_emb)
            new_caption_embs = torch.stack(new_caption_embs)

        else:
            new_caption_embs, new_emb_masks = caption_embs, emb_masks

        # Zero out the embeddings at masked (padding) positions.
        c_indices = new_caption_embs * new_emb_masks[:,:, None]
        c_emb_masks = new_emb_masks

        qzshape = [len(c_indices), args.codebook_embed_dim, latent_size, latent_size]
        index_sample = generate(
            gpt_model, c_indices, latent_size ** 2,
            c_emb_masks,
            cfg_scale=args.cfg_scale,
            temperature=args.temperature, top_k=args.top_k,
            top_p=args.top_p, sample_logits=True,
        )

        samples = vq_model.decode_code(index_sample, qzshape)  # output value is between [-1, 1]
        # Convert to channel-last uint8 arrays in the 0..255 range for PIL.
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files
        for i, sample in enumerate(samples):
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/images/{index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        # Save infer result in a jsonl file
        # NOTE(review): this lists every prompt in prompt_list, while images
        # are only written for indices below total_samples; with a small
        # --num-fid-samples some image_path entries may not exist -- confirm.
        json_items = []
        for idx, prompt in enumerate(prompt_list):
            image_path = os.path.join(sample_folder_dir, "images", f"{idx:06d}.png")
            json_items.append({"text": prompt, "image_path": image_path})
        res_jsonl_path = os.path.join(sample_folder_dir, "result.jsonl")
        print(f"Save jsonl to {res_jsonl_path}...")
        with open(res_jsonl_path, "w") as f:
            for item in json_items:
                f.write(json.dumps(item) + "\n")

        # Save captions to txt
        caption_path = os.path.join(sample_folder_dir, "captions.txt")
        print(f"Save captions to {caption_path}...")
        with open(caption_path, "w") as f:
            for item in prompt_list:
                f.write(f"{item}\n")
        print("Done.")

    dist.barrier()
    dist.destroy_process_group()
196
+
197
+
198
+
199
# Command-line entry point: declare the sampling options and run main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Tab-separated prompt file; main() reads its 'Prompt' column.
    parser.add_argument("--prompt-csv", type=str, default='evaluations/t2i/PartiPrompts.tsv')
    # T5 text encoder: cache directory, model name and max token length.
    parser.add_argument("--t5-path", type=str, default='pretrained_models/t5-ckpt')
    parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
    parser.add_argument("--t5-feature-max-len", type=int, default=120)
    # NOTE(review): --t5-feature-dim is not read by main() in this script.
    parser.add_argument("--t5-feature-dim", type=int, default=2048)
    parser.add_argument("--no-left-padding", action='store_true', default=False)
    # GPT model selection and checkpoint.
    # NOTE(review): --gpt-ckpt and --vq-ckpt default to None, but main() passes
    # them straight to torch.load / os.path.basename, so both must be supplied.
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image")
    parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--compile", action='store_true', default=False)
    # VQ image tokenizer.
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    # image-size // downsample-size gives the side of the latent grid.
    parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=512)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    # NOTE(review): --num-classes is not read by main() in this script.
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=7.5)
    # Output location and workload split across processes.
    parser.add_argument("--sample-dir", type=str, default="samples_parti", help="samples_coco or samples_parti")
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    # main() caps this at the number of prompts in the prompt file.
    parser.add_argument("--num-fid-samples", type=int, default=30000)
    parser.add_argument("--global-seed", type=int, default=0)
    # Sampling hyper-parameters forwarded to generate().
    parser.add_argument("--top-k", type=int, default=1000, help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
    args = parser.parse_args()
    main(args)
checkpoints/vq_ds16_t2i.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e21fc1318e2e9ee641a07bdad0e20675e9ec35e6e3eb911d58b5d7a2cd8d4cb
3
+ size 287920306
condition/README.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Prepare the preprocessing models
2
+
3
+ Hed: https://huggingface.co/lllyasviel/Annotators/blob/main/ControlNetHED.pth\
4
+ Lineart: https://huggingface.co/spaces/awacke1/Image-to-Line-Drawings/resolve/main/model.pth\
5
+ depth: https://huggingface.co/lllyasviel/Annotators/blob/main/dpt_hybrid-midas-501f0c75.pt (hybrid for inference)\
6
+ https://huggingface.co/Intel/dpt-large (large, used to evaluate conditional consistency and FID)\
7
+
8
+ We recommend storing them in the following directory layout:
9
+
10
+ |---condition
11
+ |---ckpts
12
+ |---dpt_large
13
+ |---config.json
14
+ |---preprocessor_config.json
15
+ |---pytorch_model.bin
16
+ |---ControlNetHED.pth
17
+ |---dpt_hybrid-midas-501f0c75.pt
18
+ |---model.pth
19
+ |---example
20
+ |---midas
21
+ .
22
+ .
23
+ .
condition/canny.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+
5
+
6
class CannyDetector:
    """Callable wrapper around OpenCV's Canny edge detector."""

    def __call__(self, img, low_threshold=100, high_threshold=200):
        """Run Canny edge detection on one image.

        Args:
            img: array or tensor (H,W,3). A tensor is moved to the CPU,
                detached and cast to ``uint8`` before detection.
            low_threshold: lower hysteresis threshold given to ``cv2.Canny``.
            high_threshold: upper hysteresis threshold given to ``cv2.Canny``.

        Returns:
            array (H,W): the edge map returned by ``cv2.Canny``.
        """
        image = img
        if torch.is_tensor(image):
            image = image.cpu().detach().numpy().astype(np.uint8)
        edges = cv2.Canny(image, low_threshold, high_threshold)
        return edges
15
+
16
+
17
if __name__ == '__main__':
    # Smoke test: run Canny on a sample image and dump the edge map.
    apply_canny = CannyDetector()
    image_path = 'condition/dragon_resize.png'
    img = cv2.imread(image_path)
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    if img is None:
        raise FileNotFoundError(f"could not read image: {image_path}")
    print(img.max())
    detected_map = apply_canny(img, 100, 200)
    print(detected_map.shape, detected_map.max(), detected_map.min())
    cv2.imwrite('condition/example_canny.jpg', detected_map)
    # Two leading singleton axes are added before saving.
    np.save('condition/example_canny.npy', detected_map[None, None])
condition/depth.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from controlnet_aux import LineartDetector
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
6
class Depth:
    """Depth estimator backed by the local ``condition/ckpts/dpt_large`` checkpoint."""

    def __init__(self, device):
        """Load the DPT model and place it on ``device``.

        Args:
            device: torch device (or device string) to run on. ``None`` leaves
                the model where ``from_pretrained`` put it.
        """
        self.device = device
        self.model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large")
        if device is not None:
            self.model = self.model.to(device)

    def __call__(self, input_image):
        """Predict a depth map for ``input_image``.

        Args:
            input_image: tensor handed to the model as its first argument.
                NOTE(review): presumably a preprocessed float batch
                (B, 3, H, W) -- confirm against callers.

        Returns:
            numpy array holding the model's ``predicted_depth``.
        """
        if self.device is not None:
            input_image = input_image.to(self.device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(input_image)
        # The model returns an output object; the depth tensor lives in
        # ``predicted_depth`` (np.array() on the whole object is not a depth map).
        return outputs.predicted_depth.cpu().numpy()
16
+
17
if __name__ == '__main__':
    # Smoke test: estimate depth for one example image and save it as 8-bit.
    from PIL import Image

    image = Image.open('condition/example/t2i/depth/depth.png')
    processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large")
    model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large")

    # The processor builds the model inputs; the hand-rolled cv2/torch
    # preprocessing that used to precede this call was overwritten unused.
    inputs = processor(images=image, return_tensors="pt", size=(512,512))
    print(inputs)
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_depth = outputs.predicted_depth
    print(predicted_depth.shape)
    # PIL reports size as (W, H); interpolate expects (H, W), hence the reversal.
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )

    output = prediction.squeeze().cpu().numpy()
    # Scale so the largest depth value maps to 255.
    formatted = (output * 255 / np.max(output)).astype("uint8")

    depth = Image.fromarray(formatted)
    depth.save('condition/example/t2i/depth/example_depth.jpg')
condition/example/t2i/multi_resolution/bird.jpg ADDED
condition/example/t2i/multi_resolution/car.jpg ADDED
condition/example/t2i/multigen/doll.jpg ADDED
condition/example/t2i/multigen/girl.jpg ADDED
condition/example/t2i/multigen/house.jpg ADDED
condition/example/t2i/multigen/sofa.png ADDED
condition/hed.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is an improved version and model of HED edge detection with Apache License, Version 2.0.
2
+ # Please use this implementation in your products
3
+ # This implementation may produce slightly different results from Saining Xie's official implementations,
4
+ # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
5
+ # Different from official models and other implementations, this is an RGB-input model (rather than BGR)
6
+ # and in this way it works better for gradio's RGB protocol
7
+
8
+ import os
9
+ import cv2
10
+ import torch
11
+ import numpy as np
12
+ from torch.nn.parallel import DataParallel
13
+ from einops import rearrange
14
+ from condition.utils import annotator_ckpts_path
15
+ import torch.nn.functional as F
16
+
17
class DoubleConvBlock(torch.nn.Module):
    """Stack of 3x3 conv + ReLU layers followed by a 1x1 single-channel projection."""

    def __init__(self, input_channel, output_channel, layer_number):
        """Build the block.

        Args:
            input_channel: channels of the incoming feature map.
            output_channel: channels produced by every 3x3 convolution.
            layer_number: total number of 3x3 convolutions in the stack.
        """
        super().__init__()
        self.convs = torch.nn.Sequential()
        self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        for _ in range(1, layer_number):
            self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)

    # Implemented as ``forward`` (not ``__call__``) so nn.Module's own
    # ``__call__`` still runs registered hooks; ``block(x, down_sampling=...)``
    # keeps working unchanged.
    def forward(self, x, down_sampling=False):
        """Run the block.

        Args:
            x: input feature map.
            down_sampling: when true, apply a 2x2 stride-2 max-pool first.

        Returns:
            tuple ``(features, projection)``: the ReLU-activated features and
            their 1x1 projection to a single channel.
        """
        h = x
        if down_sampling:
            h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
        for conv in self.convs:
            h = torch.nn.functional.relu(conv(h))
        return h, self.projection(h)
34
+
35
+
36
class ControlNetHED_Apache2(torch.nn.Module):
    """Five-stage HED network that emits one side-output projection per stage."""

    def __init__(self):
        super().__init__()
        # Learnable per-channel offset subtracted from the input.
        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)

    # Implemented as ``forward`` (not ``__call__``) so nn.Module's own
    # ``__call__`` still runs registered hooks.
    def forward(self, x):
        """Return the five side-output projections for ``x``.

        Args:
            x: 3-channel input batch.

        Returns:
            tuple of five tensors, one projection per block; blocks 2-5 are
            called with ``down_sampling=True``.
        """
        h = x - self.norm
        h, projection1 = self.block1(h)
        h, projection2 = self.block2(h, down_sampling=True)
        h, projection3 = self.block3(h, down_sampling=True)
        h, projection4 = self.block4(h, down_sampling=True)
        h, projection5 = self.block5(h, down_sampling=True)
        return projection1, projection2, projection3, projection4, projection5
54
+
55
+
56
class HEDdetector(torch.nn.Module):
    """HED edge detector that loads (and if needed downloads) ControlNetHED weights."""

    def __init__(self):
        super().__init__()
        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
        modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
        if not os.path.exists(modelpath):
            # Imported lazily: basicsr is only needed when the checkpoint is missing.
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
        self.netNetwork = ControlNetHED_Apache2().float()
        # Deserialize onto the CPU so loading does not depend on the device the
        # checkpoint was saved from; callers move the module with ``.to(device)``.
        self.netNetwork.load_state_dict(torch.load(modelpath, map_location="cpu"))

    # Implemented as ``forward`` (not ``__call__``) so nn.Module's own
    # ``__call__`` still runs registered hooks.
    def forward(self, input_image):
        """Compute an edge map.

        input: tensor (B,C,H,W)
        output: tensor (B,H,W), values clamped to [0, 255]
        """
        _, _, H, W = input_image.shape

        side_outputs = self.netNetwork(input_image)
        # Bring every side output back to the input resolution, then stack them.
        resized = [
            F.interpolate(e, size=(H, W), mode='bilinear', align_corners=False).squeeze(1)
            for e in side_outputs
        ]
        stacked = torch.stack(resized, dim=1)
        # Logistic squashing of the mean side output, scaled to 0..255.
        edge = 1 / (1 + torch.exp(-torch.mean(stacked, dim=1)))
        edge = (edge * 255.0).clamp(0, 255)

        return edge
82
+
83
+
84
def nms(x, t, s):
    """Thin an edge map by directional non-maximum suppression, then binarize.

    Args:
        x: edge-strength array; it is cast to float32 and Gaussian-blurred.
        t: threshold a surviving local maximum must exceed.
        s: sigma of the Gaussian blur.

    Returns:
        uint8 array: 255 where a retained maximum is greater than ``t``, else 0.
    """
    blurred = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    # 3x3 line kernels: horizontal, vertical and the two diagonals.
    kernels = [
        np.array(rows, dtype=np.uint8)
        for rows in (
            [[0, 0, 0], [1, 1, 1], [0, 0, 0]],
            [[0, 1, 0], [0, 1, 0], [0, 1, 0]],
            [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
            [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
        )
    ]

    maxima = np.zeros_like(blurred)
    for kernel in kernels:
        # A pixel equal to its dilation is a maximum along that kernel's line.
        is_peak = cv2.dilate(blurred, kernel=kernel) == blurred
        np.putmask(maxima, is_peak, blurred)

    binary = np.zeros_like(maxima, dtype=np.uint8)
    binary[maxima > t] = 255
    return binary
100
+
101
if __name__ == '__main__':
    # Smoke test: compare HED output on an image against HED on a resized copy.
    device = torch.device('cuda')
    apply_hed = HEDdetector().to(device).eval()
    image_path = 'condition/dragon_1024_512.jpg'
    img = cv2.imread(image_path)
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    if img is None:
        raise FileNotFoundError(f"could not read image: {image_path}")
    H,W = img.shape[:2]
    resize_img = cv2.resize(img,(512,1024))
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        detected_map = apply_hed(torch.from_numpy(img).permute(2,0,1).unsqueeze(0).to(device))
        resize_detected_map = apply_hed(torch.from_numpy(resize_img).permute(2,0,1).unsqueeze(0).to(device))
    cv2.imwrite('condition/example_hed_resize.jpg', resize_detected_map[0].cpu().detach().numpy())
    # Resample the resized result back to the original resolution for comparison.
    resize_detected_map = F.interpolate(resize_detected_map.unsqueeze(0).to(torch.float32), size=(H,W), mode='bilinear', align_corners=False, antialias=True)
    print(abs(detected_map - resize_detected_map).sum())
    print(img.shape, img.max(),img.min(),detected_map.shape, detected_map.max(),detected_map.min())
    cv2.imwrite('condition/example_hed.jpg', detected_map[0].cpu().detach().numpy())
    cv2.imwrite('condition/example_hed_resized.jpg', resize_detected_map[0,0].cpu().detach().numpy())
condition/lineart.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from controlnet_aux import LineartDetector
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ import torch.nn as nn
6
+
7
+
8
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip connection."""

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        def padded_conv():
            # pad -> 3x3 conv -> norm, channel count unchanged
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                norm_layer(in_features),
            ]

        layers = padded_conv() + [nn.ReLU(inplace=True)] + padded_conv()
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual
26
class LineArt(nn.Module):
    """Encoder / residual / decoder generator that maps an image to a line drawing."""

    def __init__(self, input_nc=3, output_nc=1, n_residual_blocks=3, sigmoid=True):
        """Build the network.

        Args:
            input_nc: channels of the input image.
            output_nc: channels of the output map.
            n_residual_blocks: number of residual blocks at the bottleneck.
            sigmoid: append a final ``nn.Sigmoid`` when true.
        """
        super(LineArt, self).__init__()

        # Stem: 7x7 convolution to 64 channels.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Two stride-2 convolutions, doubling the channel count each time.
        channels = 64
        down = []
        for _ in range(2):
            doubled = channels * 2
            down.extend([
                nn.Conv2d(channels, doubled, 3, stride=2, padding=1),
                norm_layer(doubled),
                nn.ReLU(inplace=True),
            ])
            channels = doubled
        self.model1 = nn.Sequential(*down)

        # Residual blocks at the bottleneck width.
        bottleneck = [ResidualBlock(channels) for _ in range(n_residual_blocks)]
        self.model2 = nn.Sequential(*bottleneck)

        # Two stride-2 transposed convolutions, halving the channel count each time.
        up = []
        for _ in range(2):
            halved = channels // 2
            up.extend([
                nn.ConvTranspose2d(channels, halved, 3, stride=2, padding=1, output_padding=1),
                norm_layer(halved),
                nn.ReLU(inplace=True),
            ])
            channels = halved
        self.model3 = nn.Sequential(*up)

        # Output head: 7x7 convolution, optionally followed by a sigmoid.
        head = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            head.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head)

    def forward(self, x, cond=None):
        """
        input: tensor (B,C,H,W)
        output: tensor (B,1,H,W) 0~1

        ``cond`` is accepted but not used.
        """
        out = x
        for stage in (self.model0, self.model1, self.model2, self.model3, self.model4):
            out = stage(out)
        return out
87
+
88
+
89
if __name__ == '__main__':
    # Smoke test: run the line-art model on a sample image and save the result.
    apply_lineart = LineArt()
    apply_lineart.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu')))
    image_path = 'condition/car_448_768.jpg'
    img = cv2.imread(image_path)
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    if img is None:
        raise FileNotFoundError(f"could not read image: {image_path}")
    # HWC -> 1xCxHxW, then repeated to a batch of 8.
    img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).repeat(8,1,1,1).float()
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        detected_map = apply_lineart(img)
    print(img.shape, img.max(),img.min(),detected_map.shape, detected_map.max(),detected_map.min())
    cv2.imwrite('condition/example_lineart.jpg', 255*detected_map[0,0].cpu().detach().numpy())
condition/midas/depth.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Midas Depth Estimation
2
+ # From https://github.com/isl-org/MiDaS
3
+ # MIT LICENSE
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import os
9
+ import sys
10
+ current_directory = os.getcwd()
11
+ sys.path.append(current_directory)
12
+ from einops import rearrange
13
+ # from .api import MiDaSInference
14
+ from condition.utils import annotator_ckpts_path
15
+ from condition.midas.midas.dpt_depth import DPTDepthModel
16
+ from condition.midas.midas.midas_net import MidasNet
17
+ from condition.midas.midas.midas_net_custom import MidasNet_small
18
+ from condition.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
19
+ import os
20
+ import torch.nn as nn
21
+ from torchvision.transforms import Compose
22
+
23
+ ISL_PATHS = {
24
+ "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
25
+ "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
26
+ "midas_v21": "",
27
+ "midas_v21_small": "",
28
+ }
29
+
30
+ remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
31
+
32
+
33
def disabled_train(self, mode=True):
    """Stand-in for ``model.train`` that never switches the train/eval mode.

    ``mode`` is accepted only for signature compatibility and is ignored; the
    first argument is handed back unchanged.
    """
    return self
37
+
38
+
39
def load_midas_transform(model_type):
    """Build the preprocessing transform for a MiDaS model type (no weights loaded).

    Settings follow https://github.com/isl-org/MiDaS/blob/master/run.py

    Args:
        model_type: one of "dpt_large", "dpt_hybrid", "midas_v21",
            "midas_v21_small".

    Returns:
        ``Compose`` of Resize -> NormalizeImage -> PrepareForNet.

    Raises:
        AssertionError: for any other ``model_type``.
    """
    half_norm = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    imagenet_norm = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # model_type -> (square network input size, resize method, (mean, std))
    presets = {
        "dpt_large": (384, "minimal", half_norm),          # DPT-Large
        "dpt_hybrid": (384, "minimal", half_norm),         # DPT-Hybrid
        "midas_v21": (384, "upper_bound", imagenet_norm),
        "midas_v21_small": (256, "upper_bound", imagenet_norm),
    }
    assert model_type in presets, f"model_type '{model_type}' not implemented, use: --model_type large"

    side, resize_mode, (mean, std) = presets[model_type]
    normalization = NormalizeImage(mean=mean, std=std)
    resize = Resize(
        side,
        side,
        resize_target=None,
        keep_aspect_ratio=True,
        ensure_multiple_of=32,
        resize_method=resize_mode,
        image_interpolation_method=cv2.INTER_CUBIC,
    )
    return Compose([resize, normalization, PrepareForNet()])
82
+
83
+
84
def load_model(model_type):
    """Instantiate a MiDaS network in eval mode together with its input transform.

    Network settings follow https://github.com/isl-org/MiDaS/blob/master/run.py

    Args:
        model_type: key of ``ISL_PATHS`` ("dpt_large", "dpt_hybrid",
            "midas_v21" or "midas_v21_small").

    Returns:
        tuple ``(model.eval(), transform)``.

    Raises:
        KeyError: if ``model_type`` is not a key of ``ISL_PATHS``.
    """
    model_path = ISL_PATHS[model_type]
    if model_type == "dpt_large":  # DPT-Large
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
        )

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        # Only this variant is fetched automatically when the checkpoint is missing.
        if not os.path.exists(model_path):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)

        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )

    elif model_type == "midas_v21":
        model = MidasNet(model_path, non_negative=True)

    elif model_type == "midas_v21_small":
        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
                               non_negative=True, blocks={'expand': True})

    else:
        print(f"model_type '{model_type}' not implemented, use: --model_type large")
        assert False

    # The per-model resize/normalization settings live in load_midas_transform;
    # reuse it instead of keeping a second copy of the same table here.
    transform = load_midas_transform(model_type)

    return model.eval(), transform
150
+
151
+
152
class MiDaSInference(nn.Module):
    """Gradient-free inference wrapper around a MiDaS model from ``load_model``."""

    MODEL_TYPES_TORCH_HUB = [
        "DPT_Large",
        "DPT_Hybrid",
        "MiDaS_small"
    ]
    MODEL_TYPES_ISL = [
        "dpt_large",
        "dpt_hybrid",
        "midas_v21",
        "midas_v21_small",
    ]

    def __init__(self, model_type):
        """Load the network for ``model_type`` (must be in ``MODEL_TYPES_ISL``)."""
        super().__init__()
        assert (model_type in self.MODEL_TYPES_ISL)
        # load_model also returns a transform; only the network is kept.
        self.model = load_model(model_type)[0]
        # NOTE(review): assigned as a plain instance attribute, so it is called
        # unbound -- the ``mode`` argument arrives in disabled_train's first
        # parameter and is simply returned.
        self.model.train = disabled_train

    @torch.no_grad()
    def forward(self, x):
        """Return the wrapped model's prediction for ``x`` without tracking gradients."""
        return self.model(x)
176
+
177
+
178
class MidasDetector:
    """Produces an 8-bit depth image from an input image tensor using MiDaS."""

    def __init__(self,device=torch.device('cuda:0'), model_type="dpt_hybrid"):
        """Load the MiDaS model of ``model_type`` onto ``device``."""
        self.device = device
        self.model = MiDaSInference(model_type=model_type).to(device)

    def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
        """Estimate depth and return it as a uint8 image.

        Args:
            input_image: 3-D tensor laid out as (h, w, c).
                NOTE(review): presumably pixel values in 0..255, since they are
                rescaled with ``/ 127.5 - 1.0`` -- confirm against callers.
            a: unused; kept for backward compatibility (it only fed the
                disabled normal-map output).
            bg_th: unused; kept for backward compatibility (same reason).

        Returns:
            uint8 numpy array: depth min-max normalized to 0..255.
        """
        assert input_image.ndim == 3
        with torch.no_grad():
            image_depth = input_image / 127.5 - 1.0
            image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
            depth = self.model(image_depth)[0]

            depth_pt = depth.clone()
            depth_pt -= torch.min(depth_pt)
            peak = torch.max(depth_pt)
            # A constant depth map has peak == 0; skip the division so the
            # result is all zeros instead of NaN.
            if peak > 0:
                depth_pt /= peak
            depth_pt = depth_pt.cpu().numpy()
            depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)

        # The Sobel-based surface-normal computation was removed: its result
        # was never returned (the normal-map output is disabled).
        return depth_image
209
+
210
if __name__ == '__main__':
    # Smoke test: estimate depth for a sample image, show it and save it.
    import matplotlib.pyplot as plt

    apply_depth = MidasDetector(device=torch.device('cuda:0'))
    # Repository-relative sample image instead of a machine-specific absolute path.
    image_path = 'condition/example/t2i/multi_resolution/car.jpg'
    img = cv2.imread(image_path)
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    if img is None:
        raise FileNotFoundError(f"could not read image: {image_path}")
    img = cv2.resize(img,(768,448))
    detected_map = apply_depth(torch.from_numpy(img).cuda().float())
    print(img.shape, img.max(),img.min(),detected_map.shape, detected_map.max(),detected_map.min())
    plt.imshow(detected_map, cmap='gray')
    plt.show()
    cv2.imwrite('condition/example_depth.jpg', detected_map)
condition/midas/midas/__init__.py ADDED
File without changes
condition/midas/midas/base_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
class BaseModel(torch.nn.Module):
    """Module base class that can restore its weights from a checkpoint file."""

    def load(self, path):
        """Load model from file.

        A checkpoint that contains an "optimizer" entry is treated as a
        training checkpoint and its "model" entry is used as the state dict;
        anything else is used as the state dict directly.

        Args:
            path (str): file path
        """
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        state_dict = checkpoint["model"] if "optimizer" in checkpoint else checkpoint
        self.load_state_dict(state_dict)
condition/midas/midas/blocks.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ )
10
+
11
+ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12
+ if backbone == "vitl16_384":
13
+ pretrained = _make_pretrained_vitl16_384(
14
+ use_pretrained, hooks=hooks, use_readout=use_readout
15
+ )
16
+ scratch = _make_scratch(
17
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
18
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
19
+ elif backbone == "vitb_rn50_384":
20
+ pretrained = _make_pretrained_vitb_rn50_384(
21
+ use_pretrained,
22
+ hooks=hooks,
23
+ use_vit_only=use_vit_only,
24
+ use_readout=use_readout,
25
+ )
26
+ scratch = _make_scratch(
27
+ [256, 512, 768, 768], features, groups=groups, expand=expand
28
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
29
+ elif backbone == "vitb16_384":
30
+ pretrained = _make_pretrained_vitb16_384(
31
+ use_pretrained, hooks=hooks, use_readout=use_readout
32
+ )
33
+ scratch = _make_scratch(
34
+ [96, 192, 384, 768], features, groups=groups, expand=expand
35
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
36
+ elif backbone == "resnext101_wsl":
37
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
39
+ elif backbone == "efficientnet_lite3":
40
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42
+ else:
43
+ print(f"Backbone '{backbone}' not implemented")
44
+ assert False
45
+
46
+ return pretrained, scratch
47
+
48
+
49
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50
+ scratch = nn.Module()
51
+
52
+ out_shape1 = out_shape
53
+ out_shape2 = out_shape
54
+ out_shape3 = out_shape
55
+ out_shape4 = out_shape
56
+ if expand==True:
57
+ out_shape1 = out_shape
58
+ out_shape2 = out_shape*2
59
+ out_shape3 = out_shape*4
60
+ out_shape4 = out_shape*8
61
+
62
+ scratch.layer1_rn = nn.Conv2d(
63
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64
+ )
65
+ scratch.layer2_rn = nn.Conv2d(
66
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67
+ )
68
+ scratch.layer3_rn = nn.Conv2d(
69
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70
+ )
71
+ scratch.layer4_rn = nn.Conv2d(
72
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73
+ )
74
+
75
+ return scratch
76
+
77
+
78
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    """Fetch ``tf_efficientnet_lite3`` through torch.hub and split it into four stages.

    Args:
        use_pretrained: passed to the hub entry point as ``pretrained``.
        exportable: passed to the hub entry point as ``exportable``.
    """
    hub_repo = "rwightman/gen-efficientnet-pytorch"
    hub_model = "tf_efficientnet_lite3"
    efficientnet = torch.hub.load(
        hub_repo, hub_model, pretrained=use_pretrained, exportable=exportable
    )
    return _make_efficientnet_backbone(efficientnet)
86
+
87
+
88
+ def _make_efficientnet_backbone(effnet):
89
+ pretrained = nn.Module()
90
+
91
+ pretrained.layer1 = nn.Sequential(
92
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93
+ )
94
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97
+
98
+ return pretrained
99
+
100
+
101
+ def _make_resnet_backbone(resnet):
102
+ pretrained = nn.Module()
103
+ pretrained.layer1 = nn.Sequential(
104
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105
+ )
106
+
107
+ pretrained.layer2 = resnet.layer2
108
+ pretrained.layer3 = resnet.layer3
109
+ pretrained.layer4 = resnet.layer4
110
+
111
+ return pretrained
112
+
113
+
114
def _make_pretrained_resnext101_wsl(use_pretrained):
    """Fetch ``resnext101_32x8d_wsl`` through torch.hub and split it into four stages.

    ``use_pretrained`` is accepted but not forwarded to the hub call.
    """
    hub_repo = "facebookresearch/WSL-Images"
    resnet = torch.hub.load(hub_repo, "resnext101_32x8d_wsl")
    return _make_resnet_backbone(resnet)
117
+
118
+
119
+
120
class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` with fixed settings."""

    def __init__(self, scale_factor, mode, align_corners=False):
        """Store the interpolation settings.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
            align_corners (bool): forwarded to ``interpolate`` unchanged
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resample ``x`` with the stored settings.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """
        settings = dict(
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
        return self.interp(x, **settings)
153
+
154
+
155
class ResidualConvUnit(nn.Module):
    """Pre-activation residual unit built from two 3x3 convolutions."""

    def __init__(self, features):
        """Create the unit.

        Args:
            features (int): number of features
        """
        super().__init__()

        conv_args = dict(kernel_size=3, stride=1, padding=1, bias=True)
        self.conv1 = nn.Conv2d(features, features, **conv_args)
        self.conv2 = nn.Conv2d(features, features, **conv_args)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply ReLU -> conv1 -> ReLU -> conv2 and add the skip connection.

        The ReLU is in-place, so the first activation overwrites ``x``; the
        skip connection therefore adds the rectified input, and the caller's
        tensor is modified.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        # Left operand is evaluated first, so ``x`` is already rectified when added.
        return self.conv2(self.relu(self.conv1(self.relu(x)))) + x
192
+
193
+
194
class FeatureFusionBlock(nn.Module):
    """Fuses one or two feature maps and upsamples the result by 2."""

    def __init__(self, features):
        """Create the block.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Fuse the inputs.

        With two inputs, the second is passed through ``resConfUnit1`` and
        added in-place to the first.

        Returns:
            tensor: output
        """
        fused = xs[0]

        if len(xs) == 2:
            # In-place accumulation into the first input.
            fused += self.resConfUnit1(xs[1])

        refined = self.resConfUnit2(fused)

        return nn.functional.interpolate(
            refined, scale_factor=2, mode="bilinear", align_corners=True
        )
227
+
228
+
229
+
230
+
231
class ResidualConvUnit_custom(nn.Module):
    """Residual unit with a configurable activation and optional batch norm."""

    def __init__(self, features, activation, bn):
        """Create the unit.

        Args:
            features (int): number of features
            activation: activation module applied before each convolution
            bn: batch norm follows each convolution when this equals ``True``
        """
        super().__init__()

        self.bn = bn

        self.groups=1

        conv_args = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_args)
        self.conv2 = nn.Conv2d(features, features, **conv_args)

        if self.bn==True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Apply activation -> conv (-> bn) twice and add the skip connection.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        with_bn = self.bn==True

        out = self.conv1(self.activation(x))
        if with_bn:
            out = self.bn1(out)

        out = self.conv2(self.activation(out))
        if with_bn:
            out = self.bn2(out)

        # ``groups`` is fixed to 1 in __init__, so this branch does not run;
        # no ``conv_merge`` attribute is created by this class.
        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)
289
+
290
+
291
class FeatureFusionBlock_custom(nn.Module):
    """Fuses one or two feature maps, upsamples by 2 and applies a 1x1 output conv."""

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
        """Create the block.

        Args:
            features (int): number of features
            activation: activation module handed to both residual units
            deconv: stored on the instance; not otherwise used in this class
            bn: handed to both residual units
            expand: when ``True`` the output conv halves the channel count
            align_corners: forwarded to the bilinear upsampling
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups=1

        self.expand = expand
        out_features = features//2 if self.expand==True else features

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Fuse the inputs.

        With two inputs, the second is passed through ``resConfUnit1`` and
        added to the first via ``skip_add``.

        Returns:
            tensor: output
        """
        fused = xs[0]

        if len(xs) == 2:
            skip = self.resConfUnit1(xs[1])
            fused = self.skip_add.add(fused, skip)

        refined = self.resConfUnit2(fused)
        upsampled = nn.functional.interpolate(
            refined, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        return self.out_conv(upsampled)
condition/midas/midas/dpt_depth.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .base_model import BaseModel
6
+ from .blocks import (
7
+ FeatureFusionBlock,
8
+ FeatureFusionBlock_custom,
9
+ Interpolate,
10
+ _make_encoder,
11
+ forward_vit,
12
+ )
13
+
14
+
15
def _make_fusion_block(features, use_bn):
    """Build a FeatureFusionBlock_custom with the fixed settings used by DPT.

    Args:
        features (int): number of features of the block
        use_bn (bool): passed to the block as its ``bn`` argument

    Returns:
        FeatureFusionBlock_custom: block with a non-inplace ReLU, no deconv,
        no expansion and ``align_corners=True``
    """
    fixed_options = dict(deconv=False, bn=use_bn, expand=False, align_corners=True)
    return FeatureFusionBlock_custom(features, nn.ReLU(False), **fixed_options)
24
+
25
+
26
class DPT(BaseModel):
    """Encoder plus four-stage fusion decoder with a pluggable output head."""

    def __init__(
        self,
        head,
        features=256,
        backbone="vitb_rn50_384",
        readout="project",
        channels_last=False,
        use_bn=False,
    ):
        """Init.

        Args:
            head (nn.Module): module applied to the last fusion output
            features (int, optional): number of decoder features. Defaults to 256.
            backbone (str, optional): one of "vitb_rn50_384", "vitb16_384",
                "vitl16_384" (the keys of the hook table below).
            readout (str, optional): passed to the encoder as ``use_readout``.
            channels_last (bool, optional): convert the input to channels_last
                memory format in forward. Defaults to False.
            use_bn (bool, optional): passed to the fusion blocks. Defaults to False.
        """
        super(DPT, self).__init__()

        self.channels_last = channels_last

        # Indices of the transformer blocks whose activations are tapped.
        hooks = {
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            False,  # Set to True if you want to train from scratch; uses ImageNet weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input image batch

        Returns:
            tensor: output of the head
        """
        if self.channels_last:
            # Tensor.contiguous() is not in-place: it returns the converted
            # tensor, so the result has to be rebound for the flag to matter.
            x = x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Fuse from the last tapped layer back to the first.
        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out
86
+
87
+
88
class DPTDepthModel(DPT):
    """DPT with a single-channel convolutional head."""

    def __init__(self, path=None, non_negative=True, **kwargs):
        """Init.

        Args:
            path (str, optional): checkpoint passed to ``self.load`` when not None.
            non_negative (bool, optional): end the head with a ReLU instead
                of an identity. Defaults to True.
            **kwargs: forwarded to ``DPT.__init__``; ``features`` (default 256)
                also sizes the head.
        """
        features = kwargs.get("features", 256)
        half_features = features // 2

        head = nn.Sequential(
            nn.Conv2d(features, half_features, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(half_features, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        """Run DPT and drop the channel dimension (dim 1) of the result."""
        prediction = super().forward(x)
        return prediction.squeeze(dim=1)
condition/midas/midas/midas_net.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10
+
11
+
12
class MidasNet(BaseModel):
    """Network for monocular depth estimation."""

    def __init__(self, path=None, features=256, non_negative=True):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
            non_negative (bool, optional): End the output head with a ReLU
                instead of an identity. Defaults to True.
        """
        print("Loading weights: ", path)

        super(MidasNet, self).__init__()

        use_pretrained = path is not None

        # The encoder backbone is fixed to "resnext101_wsl".
        self.pretrained, self.scratch = _make_encoder(
            backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained
        )

        self.scratch.refinenet4 = FeatureFusionBlock(features)
        self.scratch.refinenet3 = FeatureFusionBlock(features)
        self.scratch.refinenet2 = FeatureFusionBlock(features)
        self.scratch.refinenet1 = FeatureFusionBlock(features)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        # Encoder: each stage consumes the output of the previous one.
        enc_1 = self.pretrained.layer1(x)
        enc_2 = self.pretrained.layer2(enc_1)
        enc_3 = self.pretrained.layer3(enc_2)
        enc_4 = self.pretrained.layer4(enc_3)

        # Per-stage projection.
        proj_1 = self.scratch.layer1_rn(enc_1)
        proj_2 = self.scratch.layer2_rn(enc_2)
        proj_3 = self.scratch.layer3_rn(enc_3)
        proj_4 = self.scratch.layer4_rn(enc_4)

        # Decoder: fuse from the last stage back to the first.
        fused = self.scratch.refinenet4(proj_4)
        fused = self.scratch.refinenet3(fused, proj_3)
        fused = self.scratch.refinenet2(fused, proj_2)
        fused = self.scratch.refinenet1(fused, proj_1)

        depth = self.scratch.output_conv(fused)

        return torch.squeeze(depth, dim=1)
condition/midas/midas/midas_net_custom.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10
+
11
+
12
class MidasNet_small(BaseModel):
    """Network for monocular depth estimation."""

    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
        blocks=None):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 64.
            backbone (str, optional): Backbone network for encoder.
                Defaults to "efficientnet_lite3".
            non_negative (bool, optional): End the output head with a ReLU
                instead of an identity. Defaults to True.
            exportable (bool, optional): Passed to ``_make_encoder``.
            channels_last (bool, optional): Convert the input to channels_last
                memory format in forward. Defaults to False.
            align_corners (bool, optional): Passed to the fusion blocks.
            blocks (dict, optional): Block options; only the "expand" key is
                read. Defaults to ``{'expand': True}``.
        """
        print("Loading weights: ", path)

        super(MidasNet_small, self).__init__()

        # A fresh dict per instance: the options are stored on self, so a
        # shared default dict would leak mutations between models.
        if blocks is None:
            blocks = {'expand': True}

        use_pretrained = False if path else True

        self.channels_last = channels_last
        self.blocks = blocks
        self.backbone = backbone

        self.groups = 1

        features1 = features
        features2 = features
        features3 = features
        features4 = features
        self.expand = False
        if "expand" in self.blocks and self.blocks['expand'] == True:
            # Expanded decoder: channel count doubles at every deeper stage.
            self.expand = True
            features1 = features
            features2 = features * 2
            features3 = features * 4
            features4 = features * 8

        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)

        self.scratch.activation = nn.ReLU(False)

        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            self.scratch.activation,
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        if self.channels_last == True:
            print("self.channels_last = ", self.channels_last)
            # Tensor.contiguous() is not in-place: it returns the converted
            # tensor, so the result has to be rebound for the flag to matter.
            x = x.contiguous(memory_format=torch.channels_last)

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Fuse from the last stage back to the first.
        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)
106
+
107
+
108
+
109
def fuse_model(m):
    """Fuse Conv2d+BatchNorm2d(+ReLU) runs of ``m`` in place.

    Walks ``m.named_modules()`` while remembering the two previously visited
    modules. When those were a Conv2d followed by a BatchNorm2d, they are
    fused together with the current module if it is a ReLU, otherwise on
    their own.

    Args:
        m (nn.Module): model to fuse; modified in place.
    """
    # The placeholders are Identity instances, so they never compare equal
    # to a module class on the first two iterations.
    older_type, older_name = nn.Identity(), ''
    newer_type, newer_name = nn.Identity(), ''

    for name, module in m.named_modules():
        conv_then_bn = older_type == nn.Conv2d and newer_type == nn.BatchNorm2d

        if conv_then_bn and type(module) == nn.ReLU:
            torch.quantization.fuse_modules(m, [older_name, newer_name, name], inplace=True)
        elif conv_then_bn:
            torch.quantization.fuse_modules(m, [older_name, newer_name], inplace=True)

        # Slide the two-module window forward.
        older_type, older_name = newer_type, newer_name
        newer_type, newer_name = type(module), name
condition/midas/midas/transforms.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample to ensure the given size. Keeps aspect ratio.

    The sample is modified in place. Nothing is resized when the disparity
    map is already at least ``size`` in both dimensions.

    Args:
        sample (dict): sample with "image", "disparity" and "mask" entries
        size (tuple): minimum size, indexed like ``sample["disparity"].shape``
        image_interpolation_method: OpenCV interpolation flag for the image

    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        # Already large enough. Return the (unchanged) size so both exits
        # yield the documented tuple instead of the sample dict.
        return tuple(shape)

    # A single factor, the larger of the two, keeps the aspect ratio.
    scale = max(size[0] / shape[0], size[1] / shape[1])

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # cv2.resize takes (width, height), hence the reversed shape.
    target = tuple(shape[::-1])

    sample["image"] = cv2.resize(
        sample["image"], target, interpolation=image_interpolation_method
    )

    sample["disparity"] = cv2.resize(
        sample["disparity"], target, interpolation=cv2.INTER_NEAREST
    )
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        target,
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)
46
+
47
+
48
class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
            image_interpolation_method (optional): OpenCV interpolation flag
                used for the image. Defaults to cv2.INTER_AREA.
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Round ``x`` to a multiple of ``ensure_multiple_of``.

        Rounds to the nearest multiple first, falls back to rounding down
        when that exceeds ``max_val``, and to rounding up when the result is
        below ``min_val``.
        """
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        """Compute the output size for an input of the given size.

        Args:
            width (int): input width
            height (int): input height

        Returns:
            tuple: (new_width, new_height)

        Raises:
            ValueError: if the configured resize_method is not implemented.
        """
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            # Collapse the two factors into one; which one wins depends on
            # the resize method.
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self.__width
            )
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self.__width
            )
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        """Resize the sample in place and return it.

        The image is always resized. With ``resize_target`` set, any of
        "disparity", "depth" and "mask" present in the sample are resized
        with nearest-neighbour interpolation.
        """
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

            # The mask is optional like the other targets; samples without
            # one (e.g. image-only inference) must not raise KeyError.
            if "mask" in sample:
                sample["mask"] = cv2.resize(
                    sample["mask"].astype(np.float32),
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )
                sample["mask"] = sample["mask"].astype(bool)

        return sample
195
+
196
+
197
class NormalizeImage(object):
    """Normalize image by given mean and std."""

    def __init__(self, mean, std):
        """Store the mean and std applied in ``__call__``."""
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        """Replace ``sample["image"]`` by ``(image - mean) / std`` and return the sample."""
        centered = sample["image"] - self.__mean
        sample["image"] = centered / self.__std
        return sample
209
+
210
+
211
class PrepareForNet(object):
    """Prepare sample for usage as network input."""

    def __init__(self):
        pass

    def __call__(self, sample):
        """Convert the sample's arrays to contiguous float32 and return it.

        The image is additionally transposed with axes (2, 0, 1). "mask",
        "disparity" and "depth" are converted only when present.
        """
        channels_first = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(channels_first).astype(np.float32)

        for key in ("mask", "disparity", "depth"):
            if key in sample:
                as_float = sample[key].astype(np.float32)
                sample[key] = np.ascontiguousarray(as_float)

        return sample
condition/midas/midas/vit.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
class Slice(nn.Module):
    """Drop the first ``start_index`` entries along dimension 1."""

    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        """Return ``x`` without its first ``start_index`` entries on dim 1."""
        first_kept = self.start_index
        return x[:, first_kept:]
16
+
17
+
18
class AddReadout(nn.Module):
    """Add the readout entry (or entries) to every remaining entry of dim 1."""

    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        """Return ``x[:, start_index:]`` with the readout added to each entry.

        The readout is entry 0, or the mean of entries 0 and 1 when
        ``start_index`` is 2.
        """
        readout = (x[:, 0] + x[:, 1]) / 2 if self.start_index == 2 else x[:, 0]
        remaining = x[:, self.start_index :]
        return remaining + readout.unsqueeze(1)
29
+
30
+
31
class ProjectReadout(nn.Module):
    """Concatenate the readout entry to every remaining entry and project back."""

    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        # Linear maps the concatenated 2 * in_features back to in_features.
        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        """Project ``[entry, readout]`` pairs for every entry from ``start_index`` on."""
        remaining = x[:, self.start_index :]
        readout = x[:, 0].unsqueeze(1).expand_as(remaining)
        paired = torch.cat((remaining, readout), -1)
        return self.project(paired)
43
+
44
+
45
class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` for use inside nn.Sequential."""

    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        """Return ``x`` with ``dim0`` and ``dim1`` swapped."""
        return x.transpose(self.dim0, self.dim1)
54
+
55
+
56
def forward_vit(pretrained, x):
    """Run the backbone and return the four post-processed activations.

    Args:
        pretrained: wrapper exposing ``model``, ``activations`` and
            ``act_postprocess1`` .. ``act_postprocess4``
        x (tensor): input batch; must have four dimensions

    Returns:
        tuple: (layer_1, layer_2, layer_3, layer_4)
    """
    _, _, h, w = x.shape

    # The return value is not needed; the activations read below are taken
    # from pretrained.activations after this forward pass.
    pretrained.model.forward_flex(x)

    postprocess = (
        pretrained.act_postprocess1,
        pretrained.act_postprocess2,
        pretrained.act_postprocess3,
        pretrained.act_postprocess4,
    )
    layers = [pretrained.activations[key] for key in ("1", "2", "3", "4")]

    # Stage 1: the first two postprocess steps of each branch.
    layers = [post[0:2](layer) for post, layer in zip(postprocess, layers)]

    # Stage 2: replace the stored step at index 2 by an unflatten sized from
    # the actual input, applied only to tensors that are still 3-D.
    patch = pretrained.model.patch_size
    unflatten = nn.Sequential(
        nn.Unflatten(2, torch.Size([h // patch[1], w // patch[0]]))
    )
    layers = [unflatten(layer) if layer.ndim == 3 else layer for layer in layers]

    # Stage 3: the remaining postprocess steps, from index 3 on.
    layers = [post[3 : len(post)](layer) for post, layer in zip(postprocess, layers)]

    return tuple(layers)
98
+
99
+
100
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
101
+ posemb_tok, posemb_grid = (
102
+ posemb[:, : self.start_index],
103
+ posemb[0, self.start_index :],
104
+ )
105
+
106
+ gs_old = int(math.sqrt(len(posemb_grid)))
107
+
108
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111
+
112
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113
+
114
+ return posemb
115
+
116
+
117
def forward_flex(self, x):
    """Forward pass with a position embedding resized to the input.

    Args:
        x (tensor): input batch; must have four dimensions

    Returns:
        tensor: normalized output of the last block
    """
    batch, _, h, w = x.shape

    pos_embed = self._resize_pos_embed(
        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
    )

    patch_embed = self.patch_embed
    if hasattr(patch_embed, "backbone"):
        x = patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features

    x = patch_embed.proj(x).flatten(2).transpose(1, 2)

    # Prepend the class token, plus the distillation token when the model has one.
    prefix = [self.cls_token.expand(batch, -1, -1)]
    if getattr(self, "dist_token", None) is not None:
        prefix.append(self.dist_token.expand(batch, -1, -1))
    x = torch.cat((*prefix, x), dim=1)

    x = self.pos_drop(x + pos_embed)

    for blk in self.blocks:
        x = blk(x)

    return self.norm(x)
154
+
155
+
156
# Module-level store written by every hook created through get_activation.
activations = {}


def get_activation(name):
    """Return a forward hook that stores the module output in ``activations[name]``."""

    def hook(module, inputs, output):
        # Only the output is kept; the module and its inputs are ignored.
        activations[name] = output

    return hook
164
+
165
+
166
def get_readout_oper(vit_features, features, use_readout, start_index=1):
    """Build one readout operation per entry of ``features``.

    Args:
        vit_features (int): feature size handed to ``ProjectReadout``
        features (list): only its length is used
        use_readout (str): "ignore", "add" or "project"
        start_index (int, optional): passed to the readout modules.

    Returns:
        list: ``len(features)`` modules. For "ignore" and "add" the list
        repeats one shared instance; for "project" every entry is a separate
        ``ProjectReadout``.

    Raises:
        AssertionError: if ``use_readout`` is not one of the three options.
    """
    if use_readout == "ignore":
        return [Slice(start_index)] * len(features)
    if use_readout == "add":
        return [AddReadout(start_index)] * len(features)
    if use_readout == "project":
        return [ProjectReadout(vit_features, start_index) for _ in features]

    # Raised explicitly rather than via `assert False`, which is stripped
    # under `python -O` and would leave the result undefined.
    raise AssertionError(
        "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
    )
181
+
182
+
183
def _make_vit_b16_backbone(
    model,
    features=[96, 192, 384, 768],
    size=[384, 384],
    hooks=[2, 5, 8, 11],
    vit_features=768,
    use_readout="ignore",
    start_index=1,
):
    """Wrap a ViT model with activation hooks and four postprocess branches.

    Args:
        model: model exposing ``blocks``; modified in place (hooks,
            ``start_index``, ``patch_size``, ``forward_flex``,
            ``_resize_pos_embed``)
        features (list): output channels of the four branches
        size (list): nominal input size; divided by 16 for the unflatten step
        hooks (list): indices into ``model.blocks`` to tap
        vit_features (int): channel count of the tapped activations
        use_readout (str): readout handling, see ``get_readout_oper``
        start_index (int): index of the first non-readout token

    Returns:
        nn.Module: wrapper holding ``model``, ``activations`` and
        ``act_postprocess1`` .. ``act_postprocess4``
    """

    def to_grid():
        return nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16]))

    def project(out_channels):
        return nn.Conv2d(
            in_channels=vit_features,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def upsample(channels, factor):
        return nn.ConvTranspose2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=factor,
            stride=factor,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        )

    pretrained = nn.Module()
    pretrained.model = model

    # Tap the four selected blocks; their outputs are stored under "1".."4".
    for slot in range(4):
        pretrained.model.blocks[hooks[slot]].register_forward_hook(
            get_activation(str(slot + 1))
        )

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # 32, 48, 136, 384
    pretrained.act_postprocess1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        to_grid(),
        project(features[0]),
        upsample(features[0], 4),
    )

    pretrained.act_postprocess2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        to_grid(),
        project(features[1]),
        upsample(features[1], 2),
    )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        to_grid(),
        project(features[2]),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        to_grid(),
        project(features[3]),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject these functions into the VisionTransformer instance so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained
295
+
296
+
297
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
    """Create the "vit_large_patch16_384" timm model and wrap it as a backbone.

    Args:
        pretrained (bool): passed to ``timm.create_model``
        use_readout (str, optional): readout handling for the wrapper
        hooks (list, optional): block indices to tap. Defaults to [5, 11, 17, 23].

    Returns:
        nn.Module: wrapper built by ``_make_vit_b16_backbone``
    """
    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

    # Identity test: `== None` would dispatch to the argument's __eq__.
    hooks = [5, 11, 17, 23] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )
308
+
309
+
310
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
    """Create the "vit_base_patch16_384" timm model and wrap it as a backbone.

    Args:
        pretrained (bool): passed to ``timm.create_model``
        use_readout (str, optional): readout handling for the wrapper
        hooks (list, optional): block indices to tap. Defaults to [2, 5, 8, 11].

    Returns:
        nn.Module: wrapper built by ``_make_vit_b16_backbone``
    """
    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)

    # Identity test: `== None` would dispatch to the argument's __eq__.
    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )
317
+
318
+
319
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
    """Create the "vit_deit_base_patch16_384" timm model and wrap it as a backbone.

    Args:
        pretrained (bool): passed to ``timm.create_model``
        use_readout (str, optional): readout handling for the wrapper
        hooks (list, optional): block indices to tap. Defaults to [2, 5, 8, 11].

    Returns:
        nn.Module: wrapper built by ``_make_vit_b16_backbone``
    """
    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)

    # Identity test: `== None` would dispatch to the argument's __eq__.
    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )
326
+
327
+
328
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
    """Create the "vit_deit_base_distilled_patch16_384" timm model and wrap it.

    The wrapper is built with ``start_index=2``.

    Args:
        pretrained (bool): passed to ``timm.create_model``
        use_readout (str, optional): readout handling for the wrapper
        hooks (list, optional): block indices to tap. Defaults to [2, 5, 8, 11].

    Returns:
        nn.Module: wrapper built by ``_make_vit_b16_backbone``
    """
    model = timm.create_model(
        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
    )

    # Identity test: `== None` would dispatch to the argument's __eq__.
    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
        start_index=2,
    )
341
+
342
+
343
def _make_vit_b_rn50_backbone(
    model,
    features=[256, 512, 768, 768],
    size=[384, 384],
    hooks=[0, 1, 8, 11],
    vit_features=768,
    use_vit_only=False,
    use_readout="ignore",
    start_index=1,
):
    """Wrap a hybrid ResNet/ViT model with hooks and four postprocess branches.

    With ``use_vit_only`` the first two activations are tapped from
    ``model.blocks``; otherwise they come from the first two stages of
    ``model.patch_embed.backbone`` and pass through identity branches.

    Args:
        model: model to wrap; modified in place (hooks, ``start_index``,
            ``patch_size``, ``forward_flex``, ``_resize_pos_embed``)
        features (list): output channels of the four branches
        size (list): nominal input size; divided by 16 for the unflatten step
        hooks (list): indices into ``model.blocks`` to tap
        vit_features (int): channel count of the tapped transformer activations
        use_vit_only (bool): take all four activations from transformer blocks
        use_readout (str): readout handling, see ``get_readout_oper``
        start_index (int): index of the first non-readout token

    Returns:
        nn.Module: wrapper holding ``model``, ``activations`` and
        ``act_postprocess1`` .. ``act_postprocess4``
    """

    def to_grid():
        return nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16]))

    def project(out_channels):
        return nn.Conv2d(
            in_channels=vit_features,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def upsample(channels, factor):
        return nn.ConvTranspose2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=factor,
            stride=factor,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        )

    def passthrough():
        # Three identities keep the branch the same length as the others,
        # which forward_vit slices by position.
        return nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())

    vit_only = use_vit_only == True

    pretrained = nn.Module()
    pretrained.model = model

    if vit_only:
        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    else:
        stages = pretrained.model.patch_embed.backbone.stages
        stages[0].register_forward_hook(get_activation("1"))
        stages[1].register_forward_hook(get_activation("2"))

    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    if vit_only:
        pretrained.act_postprocess1 = nn.Sequential(
            readout_oper[0],
            Transpose(1, 2),
            to_grid(),
            project(features[0]),
            upsample(features[0], 4),
        )

        pretrained.act_postprocess2 = nn.Sequential(
            readout_oper[1],
            Transpose(1, 2),
            to_grid(),
            project(features[1]),
            upsample(features[1], 2),
        )
    else:
        pretrained.act_postprocess1 = passthrough()
        pretrained.act_postprocess2 = passthrough()

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        to_grid(),
        project(features[2]),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        to_grid(),
        project(features[3]),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject these functions into the VisionTransformer instance so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained
476
+
477
+
478
def _make_pretrained_vitb_rn50_384(
    pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
    """Create the "vit_base_resnet50_384" timm model and wrap it as a backbone.

    Args:
        pretrained (bool): passed to ``timm.create_model``
        use_readout (str, optional): readout handling for the wrapper
        hooks (list, optional): block indices to tap. Defaults to [0, 1, 8, 11].
        use_vit_only (bool, optional): passed to ``_make_vit_b_rn50_backbone``.

    Returns:
        nn.Module: wrapper built by ``_make_vit_b_rn50_backbone``
    """
    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)

    # Identity test: `== None` would dispatch to the argument's __eq__.
    hooks = [0, 1, 8, 11] if hooks is None else hooks
    return _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
    )
condition/utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import os
4
+
5
+
6
+ annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
7
+
8
+
9
+ def HWC3(x):
10
+ assert x.dtype == np.uint8
11
+ if x.ndim == 2:
12
+ x = x[:, :, None]
13
+ assert x.ndim == 3
14
+ H, W, C = x.shape
15
+ assert C == 1 or C == 3 or C == 4
16
+ if C == 3:
17
+ return x
18
+ if C == 1:
19
+ return np.concatenate([x, x, x], axis=2)
20
+ if C == 4:
21
+ color = x[:, :, 0:3].astype(np.float32)
22
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
23
+ y = color * alpha + 255.0 * (1.0 - alpha)
24
+ y = y.clip(0, 255).astype(np.uint8)
25
+ return y
26
+
27
+
28
+ def resize_image(input_image, resolution):
29
+ H, W, C = input_image.shape
30
+ H = float(H)
31
+ W = float(W)
32
+ k = float(resolution) / min(H, W)
33
+ H *= k
34
+ W *= k
35
+ H = int(np.round(H / 64.0)) * 64
36
+ W = int(np.round(W / 64.0)) * 64
37
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
38
+ return img
language/README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Language models for text-conditional image generation
2
+
3
+ ### Requirements
4
+ ```
5
+ pip install ftfy
6
+ pip install transformers
7
+ pip install accelerate
8
+ pip install sentencepiece
9
+ pip install pandas
10
+ pip install bs4
11
+ ```
12
+
13
+ ### Language Models
14
+ Download the flan-t5-xl model from [flan-t5-xl](https://huggingface.co/google/flan-t5-xl) and put it into the `./pretrained_models/t5-ckpt/` folder, so that the files end up in `./pretrained_models/t5-ckpt/flan-t5-xl/`.
language/extract_t5_feature.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ torch.backends.cuda.matmul.allow_tf32 = True
3
+ torch.backends.cudnn.allow_tf32 = True
4
+ import torch.distributed as dist
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from torch.utils.data.distributed import DistributedSampler
7
+ import numpy as np
8
+ import argparse
9
+ import os
10
+ import json
11
+
12
+ from utils.distributed import init_distributed_mode
13
+ from language.t5 import T5Embedder
14
+
15
# Maps a caption source name to its index inside each record's 'text' list.
CAPTION_KEY = {
    'blip': 0,
    'llava': 1,
    'llava_first': 2,
}


#################################################################################
#                               Dataset Helpers                                 #
#################################################################################
class CustomDataset(Dataset):
    """Captions read from a directory of ``.jsonl`` shards.

    Directory entries are sorted and the inclusive slice ``[start, end]`` is
    taken *before* non-``.jsonl`` entries are filtered out. Every line of a
    selected shard is parsed as JSON and contributes one item
    ``(caption, code_dir, line_idx)``, where ``code_dir`` is the shard file
    name up to its first ``.`` and ``line_idx`` is the zero-based line number.

    Args:
        lst_dir: Directory containing the ``.jsonl`` shards.
        start: Index of the first sorted directory entry to use.
        end: Index of the last sorted directory entry to use (inclusive).
        caption_key: Key of ``CAPTION_KEY`` selecting which caption of the
            record's ``'text'`` list is used.
        trunc_caption: When true, keep only the text before the first ``.``.
    """

    def __init__(self, lst_dir, start, end, caption_key, trunc_caption=False):
        caption_idx = CAPTION_KEY[caption_key]
        records = []
        for lst_name in sorted(os.listdir(lst_dir))[start: end + 1]:
            if not lst_name.endswith('.jsonl'):
                continue
            # The shard name is loop-invariant for all lines of the file, and
            # taking it from the entry name avoids splitting the joined path
            # on a hard-coded '/' separator.
            code_dir = lst_name.split('.')[0]
            with open(os.path.join(lst_dir, lst_name), 'r', encoding='utf-8') as file:
                for line_idx, line in enumerate(file):
                    caption = json.loads(line)['text'][caption_idx]
                    if trunc_caption:
                        caption = caption.split('.')[0]
                    records.append((caption, code_dir, line_idx))
        self.img_path_list = records

    def __len__(self):
        return len(self.img_path_list)

    def __getitem__(self, index):
        caption, code_dir, code_name = self.img_path_list[index]
        return caption, code_dir, code_name
47
+
48
+
49
+
50
+ #################################################################################
51
+ # Feature Extraction #
52
+ #################################################################################
53
def main(args):
    """Extract T5 text embeddings for every caption and save them as ``.npy`` files.

    The captions are sharded across the distributed ranks. For each caption
    the embedding is cut to its number of unmasked tokens, converted to
    float32 and written to ``<args.t5_path>/<code_dir>/<line_idx>.npy``.

    Args:
        args: Parsed command-line namespace. Reads ``t5_path``, ``data_path``,
            ``data_start``, ``data_end``, ``caption_key``, ``trunc_caption``,
            ``t5_model_path``, ``t5_model_type``, ``precision``,
            ``global_seed`` and ``num_workers``; it is also passed to
            ``init_distributed_mode``.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # Setup DDP:
    init_distributed_mode(args)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * world_size + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    # Setup a feature folder:
    if rank == 0:
        os.makedirs(args.t5_path, exist_ok=True)

    # Setup data:
    print("Dataset is preparing...")
    dataset = CustomDataset(args.data_path, args.data_start, args.data_end, args.caption_key, args.trunc_caption)
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=1,  # important! the slicing and `.item()` calls below assume one caption per batch
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False
    )
    print(f"Dataset contains {len(dataset):,} images")

    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    assert os.path.exists(args.t5_model_path)
    t5_embedder = T5Embedder(
        device=device,
        local_cache=True,
        cache_dir=args.t5_model_path,
        dir_or_name=args.t5_model_type,
        torch_dtype=precision
    )

    # Output directories already created by this process; avoids one
    # makedirs system call per caption.
    created_dirs = set()
    for caption, code_dir, code_name in loader:
        caption_embs, emb_masks = t5_embedder.get_text_embeddings(caption)
        # Keep only the leading positions covered by the attention mask.
        valid_caption_embs = caption_embs[:, :emb_masks.sum()]
        x = valid_caption_embs.to(torch.float32).detach().cpu().numpy()
        out_dir = os.path.join(args.t5_path, code_dir[0])
        if out_dir not in created_dirs:
            os.makedirs(out_dir, exist_ok=True)
            created_dirs.add(out_dir)
        np.save(os.path.join(out_dir, '{}.npy'.format(code_name.item())), x)
        print(code_name.item())

    dist.destroy_process_group()
113
+
114
+
115
+ if __name__ == "__main__":
116
+ parser = argparse.ArgumentParser()
117
+ parser.add_argument("--data-path", type=str, required=True)
118
+ parser.add_argument("--t5-path", type=str, required=True)
119
+ parser.add_argument("--data-start", type=int, required=True)
120
+ parser.add_argument("--data-end", type=int, required=True)
121
+ parser.add_argument("--caption-key", type=str, default='blip', choices=list(CAPTION_KEY.keys()))
122
+ parser.add_argument("--trunc-caption", action='store_true', default=False)
123
+ parser.add_argument("--t5-model-path", type=str, default='./pretrained_models/t5-ckpt')
124
+ parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
125
+ parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
126
+ parser.add_argument("--global-seed", type=int, default=0)
127
+ parser.add_argument("--num-workers", type=int, default=24)
128
+ args = parser.parse_args()
129
+ main(args)
language/t5.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/t5.py
3
+ import os
4
+ import re
5
+ import html
6
+ import urllib.parse as ul
7
+
8
+ import ftfy
9
+ import torch
10
+ from bs4 import BeautifulSoup
11
+ from transformers import T5EncoderModel, AutoTokenizer
12
+ from huggingface_hub import hf_hub_download
13
+
14
+
15
class T5Embedder:
    """Tokenize captions and encode them with a T5 encoder.

    The tokenizer and encoder are loaded either from a local directory
    (``local_cache=True``) or from files downloaded with ``hf_hub_download``.
    ``get_text_embeddings`` returns the encoder's last hidden state together
    with the attention mask; ``clean_caption`` implements the caption
    normalization applied before tokenization.
    """

    available_models = ['t5-v1_1-xxl', 't5-v1_1-xl', 'flan-t5-xl']
    # Runs of "bad" punctuation. Written as one raw string: the original
    # concatenation relied on invalid escape sequences such as '\)' in
    # non-raw literals, which newer Python versions warn about. The compiled
    # pattern is unchanged.
    bad_punct_regex = re.compile(r'[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\*]{1,}')  # noqa
    # Separators of slug-like captions ("this-is-my-cute-cat"); compiled once
    # instead of on every clean_caption call.
    _slug_separator_regex = re.compile(r'(?:\-|\_)')

    _TOKENIZER_FILES = ('config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json')
    _WEIGHT_FILES = (
        'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin',
    )

    def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None,
                 use_text_preprocessing=True, t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None,
                 model_max_length=120):
        """Load the tokenizer and the T5 encoder.

        Args:
            device: Anything accepted by ``torch.device``.
            dir_or_name: Model directory or model name.
            local_cache: When true, load both tokenizer and model from
                ``<cache_dir>/<dir_or_name>`` without downloading anything.
            cache_dir: Base cache directory; defaults to ``~/.cache/IF_``.
            hf_token: Token forwarded to ``hf_hub_download``.
            use_text_preprocessing: Whether ``text_preprocessing`` runs
                ``clean_caption`` or only lower-cases and strips.
            t5_model_kwargs: Keyword arguments for
                ``T5EncoderModel.from_pretrained``; built from ``torch_dtype``
                and ``device`` when ``None``.
            torch_dtype: Encoder dtype; defaults to ``torch.bfloat16``.
            use_offload_folder: Accepted for interface compatibility; unused.
            model_max_length: Length captions are padded/truncated to.
        """
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype or torch.bfloat16
        if t5_model_kwargs is None:
            t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
            t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}

        self.use_text_preprocessing = use_text_preprocessing
        self.hf_token = hf_token
        self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
        self.dir_or_name = dir_or_name
        tokenizer_path, path = dir_or_name, dir_or_name
        if local_cache:
            cache_dir = os.path.join(self.cache_dir, dir_or_name)
            tokenizer_path, path = cache_dir, cache_dir
        elif dir_or_name in self.available_models:
            cache_dir = os.path.join(self.cache_dir, dir_or_name)
            self._download(f'DeepFloyd/{dir_or_name}', self._TOKENIZER_FILES + self._WEIGHT_FILES, cache_dir)
            tokenizer_path, path = cache_dir, cache_dir
        else:
            # Unknown name/directory: only the tokenizer files are fetched;
            # the model is still loaded from ``dir_or_name``.
            cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
            self._download('DeepFloyd/t5-v1_1-xxl', self._TOKENIZER_FILES, cache_dir)
            tokenizer_path = cache_dir

        print(tokenizer_path)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
        self.model_max_length = model_max_length

    def _download(self, repo_id, filenames, cache_dir):
        """Download each of ``filenames`` from ``repo_id`` into ``cache_dir``."""
        for filename in filenames:
            # NOTE(review): ``force_filename`` is a legacy hf_hub_download
            # argument - confirm the installed huggingface_hub still accepts it.
            hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir,
                            force_filename=filename, token=self.hf_token)

    def get_text_embeddings(self, texts):
        """Encode ``texts`` and return ``(last_hidden_state, attention_mask)``.

        Each text is preprocessed, then the batch is tokenized with padding
        and truncation to ``model_max_length``. Both returned tensors are on
        ``self.device``.
        """
        texts = [self.text_preprocessing(text) for text in texts]

        text_tokens_and_mask = self.tokenizer(
            texts,
            max_length=self.model_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )
        input_ids = text_tokens_and_mask['input_ids'].to(self.device)
        attention_mask = text_tokens_and_mask['attention_mask'].to(self.device)

        with torch.no_grad():
            text_encoder_embs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )['last_hidden_state'].detach()
        return text_encoder_embs, attention_mask

    def text_preprocessing(self, text):
        """Return the cleaned caption, or just ``text.lower().strip()`` when cleaning is disabled."""
        if self.use_text_preprocessing:
            # The exact text cleaning as was in the training stage:
            # clean_caption is deliberately applied twice.
            text = self.clean_caption(text)
            text = self.clean_caption(text)
            return text
        else:
            return text.lower().strip()

    @staticmethod
    def basic_clean(text):
        """Fix text with ftfy, unescape HTML entities twice and strip whitespace."""
        text = ftfy.fix_text(text)
        text = html.unescape(html.unescape(text))
        return text.strip()

    def clean_caption(self, caption):
        """Normalize a raw caption.

        The caption is URL-unquoted, lower-cased and stripped, then a fixed
        sequence of substitutions removes URLs, HTML markup, @-handles, CJK
        ranges, ids, file names and boilerplate phrases, and normalizes
        dashes, quotes, punctuation and whitespace. The order of the
        substitutions matters and is kept exactly as in the upstream code.
        """
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub('<person>', 'person', caption)
        # urls:
        caption = re.sub(
            r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
            '', caption)  # regex for urls
        caption = re.sub(
            r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
            '', caption)  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features='html.parser').text

        # @<nickname>
        caption = re.sub(r'@[\w\d]+\b', '', caption)

        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
        caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
        caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
        caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
        caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
        caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
        caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
        #######################################################

        # all types of dash --> "-"
        caption = re.sub(
            r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+',  # noqa
            '-', caption)

        # normalize quotation marks to a single standard
        caption = re.sub(r'[`´«»“”¨]', '"', caption)
        caption = re.sub(r'[‘’]', "'", caption)

        # &quot;
        caption = re.sub(r'&quot;?', '', caption)
        # &amp
        caption = re.sub(r'&amp', '', caption)

        # ip addresses:
        caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)

        # article ids:
        caption = re.sub(r'\d:\d\d\s+$', '', caption)

        # \n
        caption = re.sub(r'\\n', ' ', caption)

        # "#123"
        caption = re.sub(r'#\d{1,3}\b', '', caption)
        # "#12345.."
        caption = re.sub(r'#\d{5,}\b', '', caption)
        # "123456.."
        caption = re.sub(r'\b\d{6,}\b', '', caption)
        # filenames:
        caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)

        #
        caption = re.sub(r'[\"\']{2,}', r'"', caption)  # """AUSVERKAUFT"""
        caption = re.sub(r'[\.]{2,}', r' ', caption)  # """AUSVERKAUFT"""

        caption = re.sub(self.bad_punct_regex, r' ', caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        caption = re.sub(r'\s+\.\s+', r' ', caption)  # " . "

        # this-is-my-cute-cat / this_is_my_cute_cat
        if len(re.findall(self._slug_separator_regex, caption)) > 3:
            caption = re.sub(self._slug_separator_regex, ' ', caption)

        caption = self.basic_clean(caption)

        caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption)  # jc6640
        caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption)  # jc6640vc
        caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption)  # 6640vc231

        caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
        caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
        caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
        caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
        caption = re.sub(r'\bpage\s+\d+\b', '', caption)

        caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption)  # j2d1a2a...

        caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)

        caption = re.sub(r'\b\s+\:\s+', r': ', caption)
        caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
        caption = re.sub(r'\s+', ' ', caption)

        # NOTE(review): the original code called ``caption.strip()`` here and
        # discarded the result, so the caption is NOT stripped before the
        # anchored substitutions below. The dead statement is dropped, but the
        # behaviour is kept on purpose so cleaning stays identical to what
        # the pretrained models saw.

        caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
        caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
        caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
        caption = re.sub(r'^\.\S+$', '', caption)

        return caption.strip()
model.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import spaces
3
+ from safetensors.torch import load_file
4
+ from autoregressive.models.gpt_t2i import GPT_models
5
+ from tokenizer.tokenizer_image.vq_model import VQ_models
6
+ from language.t5 import T5Embedder
7
+ import torch
8
+ import numpy as np
9
+ import PIL
10
+ from PIL import Image
11
+ from condition.canny import CannyDetector
12
+ import time
13
+ from autoregressive.models.generate import generate
14
+ from condition.midas.depth import MidasDetector
15
+
16
+ models = {
17
+ "canny": "checkpoints/t2i/canny_MR.safetensors",
18
+ "depth": "checkpoints/t2i/depth_MR.safetensors",
19
+ }
20
+
21
+
22
+ def resize_image_to_16_multiple(image, condition_type='canny'):
23
+ if isinstance(image, np.ndarray):
24
+ image = Image.fromarray(image)
25
+ # image = Image.open(image_path)
26
+ width, height = image.size
27
+
28
+ if condition_type == 'depth': # The depth model requires a side length that is a multiple of 32
29
+ new_width = (width + 31) // 32 * 32
30
+ new_height = (height + 31) // 32 * 32
31
+ else:
32
+ new_width = (width + 15) // 16 * 16
33
+ new_height = (height + 15) // 16 * 16
34
+
35
+ resized_image = image.resize((new_width, new_height))
36
+ return resized_image
37
+
38
+
39
+ class Model:
40
+
41
+ def __init__(self):
42
+ self.device = torch.device(
43
+ "cuda:0" if torch.cuda.is_available() else "cpu")
44
+ self.base_model_id = ""
45
+ self.task_name = ""
46
+ self.vq_model = self.load_vq()
47
+ self.t5_model = self.load_t5()
48
+ self.gpt_model_canny = self.load_gpt(condition_type='canny')
49
+ self.gpt_model_depth = self.load_gpt(condition_type='depth')
50
+ self.get_control_canny = CannyDetector()
51
+ self.get_control_depth = MidasDetector(device=self.device)
52
+
53
+ def load_vq(self):
54
+ vq_model = VQ_models["VQ-16"](codebook_size=16384,
55
+ codebook_embed_dim=8)
56
+ vq_model.to(self.device)
57
+ vq_model.eval()
58
+ checkpoint = torch.load(f"checkpoints/vq_ds16_t2i.pt",
59
+ map_location="cpu")
60
+ vq_model.load_state_dict(checkpoint["model"])
61
+ del checkpoint
62
+ print(f"image tokenizer is loaded")
63
+ return vq_model
64
+
65
+ def load_gpt(self, condition_type='canny'):
66
+ gpt_ckpt = models[condition_type]
67
+ precision = torch.bfloat16
68
+ latent_size = 768 // 16
69
+ gpt_model = GPT_models["GPT-XL"](
70
+ block_size=latent_size**2,
71
+ cls_token_num=120,
72
+ model_type='t2i',
73
+ condition_type=condition_type,
74
+ ).to(device=self.device, dtype=precision)
75
+
76
+ model_weight = load_file(gpt_ckpt)
77
+ gpt_model.load_state_dict(model_weight, strict=False)
78
+ gpt_model.eval()
79
+ print(f"gpt model is loaded")
80
+ return gpt_model
81
+
82
+ def load_t5(self):
83
+ precision = torch.bfloat16
84
+ t5_model = T5Embedder(
85
+ device=self.device,
86
+ local_cache=True,
87
+ # cache_dir='checkpoints/t5-ckpt',
88
+ dir_or_name='flan-t5-xl',
89
+ torch_dtype=precision,
90
+ model_max_length=120,
91
+ )
92
+ return t5_model
93
+
94
+ @torch.no_grad()
95
+ @spaces.GPU(enable_queue=True)
96
+ def process_canny(
97
+ self,
98
+ image: np.ndarray,
99
+ prompt: str,
100
+ cfg_scale: float,
101
+ temperature: float,
102
+ top_k: int,
103
+ top_p: int,
104
+ seed: int,
105
+ low_threshold: int,
106
+ high_threshold: int,
107
+ ) -> list[PIL.Image.Image]:
108
+
109
+ image = resize_image_to_16_multiple(image, 'canny')
110
+ W, H = image.size
111
+ print(W, H)
112
+ condition_img = self.get_control_canny(np.array(image), low_threshold,
113
+ high_threshold)
114
+ condition_img = torch.from_numpy(condition_img[None, None,
115
+ ...]).repeat(
116
+ 2, 3, 1, 1)
117
+ condition_img = condition_img.to(self.device)
118
+ condition_img = 2 * (condition_img / 255 - 0.5)
119
+ prompts = [prompt] * 2
120
+ caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
121
+
122
+ print(f"processing left-padding...")
123
+ new_emb_masks = torch.flip(emb_masks, dims=[-1])
124
+ new_caption_embs = []
125
+ for idx, (caption_emb,
126
+ emb_mask) in enumerate(zip(caption_embs, emb_masks)):
127
+ valid_num = int(emb_mask.sum().item())
128
+ print(f' prompt {idx} token len: {valid_num}')
129
+ new_caption_emb = torch.cat(
130
+ [caption_emb[valid_num:], caption_emb[:valid_num]])
131
+ new_caption_embs.append(new_caption_emb)
132
+ new_caption_embs = torch.stack(new_caption_embs)
133
+ c_indices = new_caption_embs * new_emb_masks[:, :, None]
134
+ c_emb_masks = new_emb_masks
135
+ qzshape = [len(c_indices), 8, H // 16, W // 16]
136
+ t1 = time.time()
137
+ index_sample = generate(
138
+ self.gpt_model_canny,
139
+ c_indices,
140
+ (H // 16) * (W // 16),
141
+ c_emb_masks,
142
+ condition=condition_img,
143
+ cfg_scale=cfg_scale,
144
+ temperature=temperature,
145
+ top_k=top_k,
146
+ top_p=top_p,
147
+ sample_logits=True,
148
+ )
149
+ sampling_time = time.time() - t1
150
+ print(f"Full sampling takes about {sampling_time:.2f} seconds.")
151
+
152
+ t2 = time.time()
153
+ print(index_sample.shape)
154
+ samples = self.vq_model.decode_code(
155
+ index_sample, qzshape) # output value is between [-1, 1]
156
+ decoder_time = time.time() - t2
157
+ print(f"decoder takes about {decoder_time:.2f} seconds.")
158
+
159
+ samples = torch.cat((condition_img[0:1], samples), dim=0)
160
+ samples = 255 * (samples * 0.5 + 0.5)
161
+ samples = [image] + [
162
+ Image.fromarray(
163
+ sample.permute(1, 2, 0).cpu().detach().numpy().clip(
164
+ 0, 255).astype(np.uint8)) for sample in samples
165
+ ]
166
+ del condition_img
167
+ torch.cuda.empty_cache()
168
+ return samples
169
+
170
+ @torch.no_grad()
171
+ @spaces.GPU(enable_queue=True)
172
+ def process_depth(
173
+ self,
174
+ image: np.ndarray,
175
+ prompt: str,
176
+ cfg_scale: float,
177
+ temperature: float,
178
+ top_k: int,
179
+ top_p: int,
180
+ seed: int,
181
+ ) -> list[PIL.Image.Image]:
182
+ image = resize_image_to_16_multiple(image, 'depth')
183
+ W, H = image.size
184
+ print(W, H)
185
+ image_tensor = torch.from_numpy(np.array(image)).to(self.device)
186
+ condition_img = torch.from_numpy(
187
+ self.get_control_depth(image_tensor)).unsqueeze(0)
188
+ condition_img = condition_img.unsqueeze(0).repeat(2, 3, 1, 1)
189
+ condition_img = condition_img.to(self.device)
190
+ condition_img = 2 * (condition_img / 255 - 0.5)
191
+ prompts = [prompt] * 2
192
+ caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
193
+
194
+ print(f"processing left-padding...")
195
+ new_emb_masks = torch.flip(emb_masks, dims=[-1])
196
+ new_caption_embs = []
197
+ for idx, (caption_emb,
198
+ emb_mask) in enumerate(zip(caption_embs, emb_masks)):
199
+ valid_num = int(emb_mask.sum().item())
200
+ print(f' prompt {idx} token len: {valid_num}')
201
+ new_caption_emb = torch.cat(
202
+ [caption_emb[valid_num:], caption_emb[:valid_num]])
203
+ new_caption_embs.append(new_caption_emb)
204
+ new_caption_embs = torch.stack(new_caption_embs)
205
+
206
+ c_indices = new_caption_embs * new_emb_masks[:, :, None]
207
+ c_emb_masks = new_emb_masks
208
+ qzshape = [len(c_indices), 8, H // 16, W // 16]
209
+ t1 = time.time()
210
+ index_sample = generate(
211
+ self.gpt_model_depth,
212
+ c_indices,
213
+ (H // 16) * (W // 16),
214
+ c_emb_masks,
215
+ condition=condition_img,
216
+ cfg_scale=cfg_scale,
217
+ temperature=temperature,
218
+ top_k=top_k,
219
+ top_p=top_p,
220
+ sample_logits=True,
221
+ )
222
+ sampling_time = time.time() - t1
223
+ print(f"Full sampling takes about {sampling_time:.2f} seconds.")
224
+
225
+ t2 = time.time()
226
+ print(index_sample.shape)
227
+ samples = self.vq_model.decode_code(index_sample, qzshape)
228
+ decoder_time = time.time() - t2
229
+ print(f"decoder takes about {decoder_time:.2f} seconds.")
230
+ condition_img = condition_img.cpu()
231
+ samples = samples.cpu()
232
+ samples = torch.cat((condition_img[0:1], samples), dim=0)
233
+ samples = 255 * (samples * 0.5 + 0.5)
234
+ samples = [image] + [
235
+ Image.fromarray(
236
+ sample.permute(1, 2, 0).numpy().clip(0, 255).astype(np.uint8))
237
+ for sample in samples
238
+ ]
239
+ del image_tensor
240
+ del condition_img
241
+ torch.cuda.empty_cache()
242
+ return samples
style.css ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+
5
+ #duplicate-button {
6
+ margin: auto;
7
+ color: #fff;
8
+ background: #1565c0;
9
+ border-radius: 100vh;
10
+ }
tokenizer/consistencydecoder/README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Consistency Decoder from OpenAI
2
+
3
+ ### install
4
+ ```
5
+ pip install diffusers
6
+ pip install accelerate
7
+ ```
8
+
9
+ ### demo
10
+ ```
11
+ cd ${THIS_REPO_ROOT}
12
+ python3 tokenizer/consistencydecoder/cd_demo.py
13
+ ```
14
+
tokenizer/consistencydecoder/cd_demo.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from PIL import Image
6
+ from diffusers import ConsistencyDecoderVAE
7
+
8
+
9
def main(args):
    """Reconstruct one image with the consistency-decoder VAE and save the result.

    The image at ``args.image_path`` is resized to a square of
    ``args.image_size``, encoded and decoded by the VAE, resized back to its
    original size and written next to the input with ``_cd`` inserted before
    the file extension.

    Args:
        args: Parsed command-line namespace with ``seed``, ``image_path`` and
            ``image_size``.
    """
    # Local import: this module does not import ``os`` at file level.
    import os

    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model
    vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to(device)

    # load image
    img_path = args.image_path
    # Insert "_cd" before the extension. The previous chained ``str.replace``
    # rewrote ".jpg"/".png" anywhere in the path and left the path unchanged
    # for any other extension, so the output overwrote the input.
    root, ext = os.path.splitext(img_path)
    out_path = f"{root}_cd{ext}"
    input_size = args.image_size
    img = Image.open(img_path).convert("RGB")

    # preprocess
    size_org = img.size
    img = img.resize((input_size, input_size))
    img = np.array(img) / 255.
    x = 2.0 * img - 1.0  # x value is between [-1, 1]
    x = torch.tensor(x)
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)
    x_input = x.half().to(device)

    # inference
    with torch.no_grad():
        # Map input images to latent space + normalize latents:
        latent = vae.encode(x_input).latent_dist.sample().mul_(0.18215)
        # reconstruct:
        output = vae.decode(latent / 0.18215).sample  # output value is between [-1, 1]

    # postprocess: PIL's size is (width, height) while interpolate expects [height, width]
    output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0]
    sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()

    # save
    Image.fromarray(sample).save(out_path)
    print("Reconstructed image is saved to {}".format(out_path))
48
+
49
+
50
+
51
+ if __name__ == "__main__":
52
+ parser = argparse.ArgumentParser()
53
+ parser.add_argument("--image-path", type=str, default="assets/example.jpg")
54
+ parser.add_argument("--image-size", type=int, choices=[256, 512, 1024], default=512)
55
+ parser.add_argument("--seed", type=int, default=0)
56
+ args = parser.parse_args()
57
+ main(args)
tokenizer/consistencydecoder/reconstruction_cd_ddp.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ torch.backends.cuda.matmul.allow_tf32 = True
3
+ torch.backends.cudnn.allow_tf32 = True
4
+ import torch.distributed as dist
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from torch.utils.data.distributed import DistributedSampler
7
+ from torchvision.datasets import ImageFolder
8
+ from torchvision import transforms
9
+ from tqdm import tqdm
10
+ import os
11
+ import itertools
12
+ from PIL import Image
13
+ import numpy as np
14
+ import argparse
15
+ import random
16
+
17
+ from skimage.metrics import peak_signal_noise_ratio as psnr_loss
18
+ from skimage.metrics import structural_similarity as ssim_loss
19
+ from diffusers.models import ConsistencyDecoderVAE
20
+
21
+
22
class SingleFolderDataset(Dataset):
    """Dataset over the regular files directly inside one directory.

    Each item is ``(image, torch.tensor(0))``: the file opened as an RGB
    image (passed through ``transform`` when one is given) and a constant
    dummy label.

    Args:
        directory: Folder whose files are used; sub-directories are skipped.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(self, directory, transform=None):
        super().__init__()
        self.directory = directory
        self.transform = transform
        # Sorted so the index -> file mapping is deterministic; os.listdir
        # returns entries in arbitrary, OS-dependent order.
        candidates = (os.path.join(directory, file_name) for file_name in sorted(os.listdir(directory)))
        self.image_paths = [path for path in candidates if os.path.isfile(path)]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(0)
39
+
40
+
41
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.

    Reads ``<sample_dir>/000000.png`` ... up to ``num`` files, shuffles them,
    stacks them into one uint8 array and saves it as ``<sample_dir>.npz``
    under the key ``arr_0``.

    Args:
        sample_dir: Folder containing the zero-padded, six-digit .png files.
        num: Number of samples to read.

    Returns:
        Path of the written .npz file.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        # Context manager so the file handle is released right away instead
        # of whenever the image object happens to be collected.
        with Image.open(f"{sample_dir}/{i:06d}.png") as sample_pil:
            samples.append(np.asarray(sample_pil).astype(np.uint8))

    random.shuffle(samples)  # This is very important for IS(Inception Score) !!!
    samples = np.stack(samples)
    # Expect a 4-D (num, H, W, 3) array.
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
58
+
59
+
60
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    The image is halved with box resampling while its shorter side is at
    least twice ``image_size``, then rescaled with bicubic resampling so the
    shorter side equals ``image_size``, and finally center-cropped to an
    ``image_size`` x ``image_size`` square.
    """
    # Coarse stage: repeated halving.
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(x // 2 for x in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Fine stage: bring the shorter side to exactly image_size.
    scale = image_size / min(*pil_image.size)
    scaled = tuple(round(x * scale) for x in pil_image.size)
    pil_image = pil_image.resize(scaled, resample=Image.BICUBIC)

    arr = np.array(pil_image)
    top = (arr.shape[0] - image_size) // 2
    left = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[top: top + image_size, left: left + image_size])
79
+
80
+
81
def main(args):
    """Reconstruct a dataset with the consistency-decoder VAE and report PSNR/SSIM.

    Every rank encodes and decodes its shard of the dataset, saves each
    reconstruction as ``<index:06d>.png`` and records per-image PSNR and
    SSIM. The metrics are gathered from all ranks; rank 0 prints the means,
    writes them to ``<sample folder>_results.txt`` and packs the .png files
    into one .npz file.

    Args:
        args: Parsed command-line namespace with ``data_path``, ``dataset``,
            ``image_size``, ``sample_dir``, ``per_proc_batch_size``,
            ``global_seed`` and ``num_workers``.
    """
    # Setup PyTorch:
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup env
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # create and load model
    vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to("cuda:{}".format(device))

    # Create folder to save samples:
    folder_name = f"openai-consistencydecoder-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    if args.dataset == 'imagenet':
        dataset = ImageFolder(args.data_path, transform=transform)
        num_fid_samples = 50000
    elif args.dataset == 'coco':
        dataset = SingleFolderDataset(args.data_path, transform=transform)
        num_fid_samples = 5000
    else:
        raise Exception("please check dataset")
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=False,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=args.per_proc_batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False
    )

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    psnr_val_rgb = []
    ssim_val_rgb = []

    loader = tqdm(loader) if rank == 0 else loader
    total = 0
    for x, _ in loader:
        rgb_gts = x
        rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0  # rgb_gt value is between [0, 1]
        x = x.half().to("cuda:{}".format(device))
        with torch.no_grad():
            # Map input images to latent space + normalize latents:
            latent = vae.encode(x).latent_dist.sample().mul_(0.18215)
            # reconstruct:
            samples = vae.decode(latent / 0.18215).sample  # output value is between [-1, 1]
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files
        for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)):
            # Global index of the i-th item of this rank's current batch,
            # with ranks interleaved across consecutive indices.
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
            # metric
            rgb_restored = sample.astype(np.float32) / 255.  # rgb_restored value is between [0, 1]
            psnr = psnr_loss(rgb_restored, rgb_gt)
            # Both images are in [0, 1], so the data range is 1.0. The
            # previous value of 2.0 corresponds to [-1, 1] inputs, so SSIM
            # values from earlier runs are not directly comparable.
            ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=1.0, channel_axis=-1)
            psnr_val_rgb.append(psnr)
            ssim_val_rgb.append(ssim)
        total += global_batch_size

    # ------------------------------------
    #       Summary
    # ------------------------------------
    # Make sure all processes have finished saving their samples
    dist.barrier()
    world_size = dist.get_world_size()
    gather_psnr_val = [None for _ in range(world_size)]
    gather_ssim_val = [None for _ in range(world_size)]
    dist.all_gather_object(gather_psnr_val, psnr_val_rgb)
    dist.all_gather_object(gather_ssim_val, ssim_val_rgb)

    if rank == 0:
        gather_psnr_val = list(itertools.chain(*gather_psnr_val))
        gather_ssim_val = list(itertools.chain(*gather_ssim_val))
        psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val)
        ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val)
        print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb))

        result_file = f"{sample_folder_dir}_results.txt"
        print("writing results to {}".format(result_file))
        with open(result_file, 'w') as f:
            print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f)

        create_npz_from_sample_folder(sample_folder_dir, num_fid_samples)
        print("Done.")

    dist.barrier()
    dist.destroy_process_group()
196
+
197
+
198
+ if __name__ == "__main__":
199
+ parser = argparse.ArgumentParser()
200
+ parser.add_argument("--data-path", type=str, required=True)
201
+ parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet')
202
+ parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
203
+ parser.add_argument("--sample-dir", type=str, default="reconstructions")
204
+ parser.add_argument("--per-proc-batch-size", type=int, default=32)
205
+ parser.add_argument("--global-seed", type=int, default=0)
206
+ parser.add_argument("--num-workers", type=int, default=4)
207
+ args = parser.parse_args()
208
+ main(args)
tokenizer/tokenizer_image/cache/vgg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
3
+ size 7289
tokenizer/tokenizer_image/discriminator.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # taming-transformers: https://github.com/CompVis/taming-transformers
3
+ # stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
4
+ # maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py
5
+ import functools
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ try:
10
+ from kornia.filters import filter2d
11
+ except:
12
+ pass
13
+
14
+ #################################################################################
15
+ # PatchGAN #
16
+ #################################################################################
17
class PatchGANDiscriminator(nn.Module):
    """PatchGAN discriminator in the style of Pix2Pix.

    A stack of 4x4 convolutions ending in a one-channel prediction map.
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Build the convolution stack.

        Parameters:
            input_nc (int)     -- number of channels in the input images
            ndf (int)          -- number of filters produced by the first conv layer
            n_layers (int)     -- depth control: n_layers - 1 normalized stride-2
                                  stages follow the first conv
            use_actnorm (bool) -- normalize with ActNorm instead of nn.BatchNorm2d
        """
        super().__init__()
        norm_layer = ActNorm if use_actnorm else nn.BatchNorm2d
        # BatchNorm2d carries its own affine shift, so the conv in front of it
        # does not need a bias term.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls != nn.BatchNorm2d

        kernel, pad = 4, 1

        def conv(c_in, c_out, stride, **extra):
            return nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad, **extra)

        layers = [conv(input_nc, ndf, 2), nn.LeakyReLU(0.2, True)]

        # Stride-2 stages: the channel multiplier doubles each time, capped at 8.
        mult = 1
        for depth in range(1, n_layers):
            prev_mult, mult = mult, min(2 ** depth, 8)
            layers.extend([
                conv(ndf * prev_mult, ndf * mult, 2, bias=use_bias),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ])

        # One more normalized stage at stride 1.
        prev_mult, mult = mult, min(2 ** n_layers, 8)
        layers.extend([
            conv(ndf * prev_mult, ndf * mult, 1, bias=use_bias),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
        ])

        # Output head: 1-channel prediction map.
        layers.append(conv(ndf * mult, 1, 1))
        self.main = nn.Sequential(*layers)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw conv weights from N(0, 0.02) and BatchNorm weights from N(1, 0.02)."""
        if isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight.data, 0.0, 0.02)
            return
        if isinstance(module, nn.BatchNorm2d):
            nn.init.normal_(module.weight.data, 1.0, 0.02)
            nn.init.constant_(module.bias.data, 0)

    def forward(self, input):
        """Run the input through the sequential stack."""
        return self.main(input)
77
+
78
+
79
class ActNorm(nn.Module):
    """Per-channel affine layer with data-dependent initialization.

    The first training-mode call sets ``loc`` and ``scale`` from the statistics
    of that batch; afterwards the layer computes ``scale * (input + loc)``.
    """

    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        # Flag stored as a buffer; set to 1 once the statistics are captured.
        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        """Set loc to the negated channel mean and scale to 1 / (channel std + 1e-6)."""
        with torch.no_grad():
            per_channel = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = per_channel.mean(1).view(1, -1, 1, 1)
            std = per_channel.std(1).view(1, -1, 1, 1)

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        """Normalize ``input``; with ``reverse=True`` delegate to :meth:`reverse`.

        Returns ``h``, or ``(h, logdet)`` when the layer was built with
        ``logdet=True``. A 2-D input is treated as having 1x1 spatial size and
        is returned 2-D again.
        """
        if reverse:
            return self.reverse(input)

        squeeze = input.dim() == 2
        if squeeze:
            input = input[:, :, None, None]

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)
        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if not self.logdet:
            return h

        log_abs = torch.log(torch.abs(self.scale))
        logdet = height * width * torch.sum(log_abs)
        # One log-determinant value per batch element.
        logdet = logdet * torch.ones(input.shape[0]).to(input)
        return h, logdet

    def reverse(self, output):
        """Invert the affine map: ``output / scale - loc``.

        Raises:
            RuntimeError: in training mode when the layer is still
                uninitialized and ``allow_reverse_init`` is False.
        """
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            self.initialize(output)
            self.initialized.fill_(1)

        squeeze = output.dim() == 2
        if squeeze:
            output = output[:, :, None, None]

        h = output / self.scale - self.loc
        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
162
+
163
+
164
+
165
+ #################################################################################
166
+ # StyleGAN #
167
+ #################################################################################
168
class StyleGANDiscriminator(nn.Module):
    """StyleGAN-style discriminator: residual downsampling blocks plus a linear head.

    ``ndf`` and ``n_layers`` are accepted but not read by this constructor.
    ``image_size`` must be one of the resolutions in the channel table
    (4 ... 1024).
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256):
        super().__init__()
        # Feature width used at each square resolution.
        width_at = {res: 512 for res in (4, 8, 16, 32)}
        width_at.update({
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        })

        log_size = int(math.log(image_size, 2))
        width = width_at[image_size]

        stages = [nn.Conv2d(input_nc, width, 3, padding=1), leaky_relu()]
        # One block per resolution step from image_size / 2 down to 4.
        for exponent in range(log_size - 1, 1, -1):
            next_width = width_at[2 ** exponent]
            stages.append(DiscriminatorBlock(width, next_width))
            width = next_width
        self.blocks = nn.ModuleList(stages)

        self.final_conv = nn.Sequential(
            nn.Conv2d(width, width_at[4], 3, padding=1),
            leaky_relu(),
        )
        # The head consumes the flattened width_at[4] x 4 x 4 feature map.
        self.final_linear = nn.Sequential(
            nn.Linear(width_at[4] * 4 * 4, width_at[4]),
            leaky_relu(),
            nn.Linear(width_at[4], 1)
        )

    def forward(self, x):
        """Apply every stage, flatten per sample, and return the linear head output."""
        for stage in self.blocks:
            x = stage(x)
        features = self.final_conv(x)
        flat = features.view(features.shape[0], -1)
        return self.final_linear(flat)
210
+
211
+
212
class DiscriminatorBlock(nn.Module):
    """Residual block: two 3x3 convs on the main path, a 1x1 conv on the skip path.

    With ``downsample=True`` the main path ends in a blur followed by a
    stride-2 conv, and the skip conv uses stride 2 as well.
    """

    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        skip_stride = 2 if downsample else 1
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=skip_stride)

        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu()
        )

        if downsample:
            self.downsample = nn.Sequential(
                Blur(),
                nn.Conv2d(filters, filters, 3, padding=1, stride=2)
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Return (main path + skip path) scaled by 1 / sqrt(2)."""
        skip = self.conv_res(x)
        out = self.net(x)
        if self.downsample is not None:
            out = self.downsample(out)
        return (out + skip) * (1 / math.sqrt(2))
236
+
237
+
238
class Blur(nn.Module):
    """Filter the input with the normalized outer product of the taps [1, 2, 1]."""

    def __init__(self):
        super().__init__()
        self.register_buffer('f', torch.Tensor([1, 2, 1]))

    def forward(self, x):
        taps = self.f
        # Outer product of the 1-D taps, with a leading singleton dimension.
        kernel = taps[None, None, :] * taps[None, :, None]
        return filter2d(x, kernel, normalized=True)
248
+
249
+
250
def leaky_relu(p=0.2):
    """Return an in-place nn.LeakyReLU with negative slope ``p``."""
    activation = nn.LeakyReLU(negative_slope=p, inplace=True)
    return activation
252
+
253
+
254
def exists(val):
    """Return False for None and True for anything else."""
    if val is None:
        return False
    return True
tokenizer/tokenizer_image/discriminator_patchgan.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # taming-transformers: https://github.com/CompVis/taming-transformers
3
+ import functools
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator in the style of Pix2Pix.

    A stack of 4x4 convolutions ending in a one-channel prediction map.
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Build the convolution stack.

        Parameters:
            input_nc (int)     -- number of channels in the input images
            ndf (int)          -- number of filters produced by the first conv layer
            n_layers (int)     -- depth control: n_layers - 1 normalized stride-2
                                  stages follow the first conv
            use_actnorm (bool) -- normalize with ActNorm instead of nn.BatchNorm2d
        """
        super().__init__()
        norm_layer = ActNorm if use_actnorm else nn.BatchNorm2d
        # BatchNorm2d carries its own affine shift, so the conv in front of it
        # does not need a bias term.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls != nn.BatchNorm2d

        kernel, pad = 4, 1

        def conv(c_in, c_out, stride, **extra):
            return nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad, **extra)

        layers = [conv(input_nc, ndf, 2), nn.LeakyReLU(0.2, True)]

        # Stride-2 stages: the channel multiplier doubles each time, capped at 8.
        mult = 1
        for depth in range(1, n_layers):
            prev_mult, mult = mult, min(2 ** depth, 8)
            layers.extend([
                conv(ndf * prev_mult, ndf * mult, 2, bias=use_bias),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ])

        # One more normalized stage at stride 1.
        prev_mult, mult = mult, min(2 ** n_layers, 8)
        layers.extend([
            conv(ndf * prev_mult, ndf * mult, 1, bias=use_bias),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
        ])

        # Output head: 1-channel prediction map.
        layers.append(conv(ndf * mult, 1, 1))
        self.main = nn.Sequential(*layers)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw conv weights from N(0, 0.02) and BatchNorm weights from N(1, 0.02)."""
        if isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight.data, 0.0, 0.02)
            return
        if isinstance(module, nn.BatchNorm2d):
            nn.init.normal_(module.weight.data, 1.0, 0.02)
            nn.init.constant_(module.bias.data, 0)

    def forward(self, input):
        """Run the input through the sequential stack."""
        return self.main(input)
68
+
69
+
70
class ActNorm(nn.Module):
    """Per-channel affine layer with data-dependent initialization.

    The first training-mode call sets ``loc`` and ``scale`` from the statistics
    of that batch; afterwards the layer computes ``scale * (input + loc)``.
    """

    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        # Flag stored as a buffer; set to 1 once the statistics are captured.
        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        """Set loc to the negated channel mean and scale to 1 / (channel std + 1e-6)."""
        with torch.no_grad():
            per_channel = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = per_channel.mean(1).view(1, -1, 1, 1)
            std = per_channel.std(1).view(1, -1, 1, 1)

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        """Normalize ``input``; with ``reverse=True`` delegate to :meth:`reverse`.

        Returns ``h``, or ``(h, logdet)`` when the layer was built with
        ``logdet=True``. A 2-D input is treated as having 1x1 spatial size and
        is returned 2-D again.
        """
        if reverse:
            return self.reverse(input)

        squeeze = input.dim() == 2
        if squeeze:
            input = input[:, :, None, None]

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)
        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if not self.logdet:
            return h

        log_abs = torch.log(torch.abs(self.scale))
        logdet = height * width * torch.sum(log_abs)
        # One log-determinant value per batch element.
        logdet = logdet * torch.ones(input.shape[0]).to(input)
        return h, logdet

    def reverse(self, output):
        """Invert the affine map: ``output / scale - loc``.

        Raises:
            RuntimeError: in training mode when the layer is still
                uninitialized and ``allow_reverse_init`` is False.
        """
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            self.initialize(output)
            self.initialized.fill_(1)

        squeeze = output.dim() == 2
        if squeeze:
            output = output[:, :, None, None]

        h = output / self.scale - self.loc
        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
tokenizer/tokenizer_image/discriminator_stylegan.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # stylegan2-pytorch: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/stylegan2_pytorch.py
3
+ # stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
4
+ # maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py
5
+ import math
6
+ import torch
7
+ import torch.nn as nn
8
+ try:
9
+ from kornia.filters import filter2d
10
+ except:
11
+ pass
12
+
13
class Discriminator(nn.Module):
    """StyleGAN-style discriminator: residual downsampling blocks plus a linear head.

    ``ndf`` and ``n_layers`` are accepted but not read by this constructor.
    ``image_size`` must be one of the resolutions in the channel table
    (4 ... 1024).
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256):
        super().__init__()
        # Feature width used at each square resolution.
        width_at = {res: 512 for res in (4, 8, 16, 32)}
        width_at.update({
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        })

        log_size = int(math.log(image_size, 2))
        width = width_at[image_size]

        stages = [nn.Conv2d(input_nc, width, 3, padding=1), leaky_relu()]
        # One block per resolution step from image_size / 2 down to 4.
        for exponent in range(log_size - 1, 1, -1):
            next_width = width_at[2 ** exponent]
            stages.append(DiscriminatorBlock(width, next_width))
            width = next_width
        self.blocks = nn.ModuleList(stages)

        self.final_conv = nn.Sequential(
            nn.Conv2d(width, width_at[4], 3, padding=1),
            leaky_relu(),
        )
        # The head consumes the flattened width_at[4] x 4 x 4 feature map.
        self.final_linear = nn.Sequential(
            nn.Linear(width_at[4] * 4 * 4, width_at[4]),
            leaky_relu(),
            nn.Linear(width_at[4], 1)
        )

    def forward(self, x):
        """Apply every stage, flatten per sample, and return the linear head output."""
        for stage in self.blocks:
            x = stage(x)
        features = self.final_conv(x)
        flat = features.view(features.shape[0], -1)
        return self.final_linear(flat)
55
+
56
+
57
class DiscriminatorBlock(nn.Module):
    """Residual block: two 3x3 convs on the main path, a 1x1 conv on the skip path.

    With ``downsample=True`` the main path ends in a blur followed by a
    stride-2 conv, and the skip conv uses stride 2 as well.
    """

    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        skip_stride = 2 if downsample else 1
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=skip_stride)

        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu()
        )

        if downsample:
            self.downsample = nn.Sequential(
                Blur(),
                nn.Conv2d(filters, filters, 3, padding=1, stride=2)
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Return (main path + skip path) scaled by 1 / sqrt(2)."""
        skip = self.conv_res(x)
        out = self.net(x)
        if self.downsample is not None:
            out = self.downsample(out)
        return (out + skip) * (1 / math.sqrt(2))
81
+
82
+
83
+
84
class Blur(nn.Module):
    """Filter the input with the normalized outer product of the taps [1, 2, 1]."""

    def __init__(self):
        super().__init__()
        self.register_buffer('f', torch.Tensor([1, 2, 1]))

    def forward(self, x):
        taps = self.f
        # Outer product of the 1-D taps, with a leading singleton dimension.
        kernel = taps[None, None, :] * taps[None, :, None]
        return filter2d(x, kernel, normalized=True)
94
+
95
+
96
def leaky_relu(p=0.2):
    """Return an in-place nn.LeakyReLU with negative slope ``p``."""
    activation = nn.LeakyReLU(negative_slope=p, inplace=True)
    return activation
98
+
99
+
100
def exists(val):
    """Return False for None and True for anything else."""
    if val is None:
        return False
    return True
tokenizer/tokenizer_image/lpips.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ import os, hashlib
4
+ import requests
5
+ from tqdm import tqdm
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torchvision import models
10
+ from collections import namedtuple
11
+
12
# Download location of each supported pretrained checkpoint.
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}

# File name that get_ckpt_path joins onto its root directory.
CKPT_MAP = {"vgg_lpips": "vgg.pth"}

# Expected MD5 digest of each checkpoint file.
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
23
+
24
def download(url, local_path, chunk_size=1024):
    """Stream ``url`` into ``local_path`` while showing a progress bar.

    Args:
        url: address to fetch with an HTTP GET.
        local_path: destination file; its parent directory is created if needed.
        chunk_size: number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            error page is never written out in place of the file.
    """
    target_dir = os.path.dirname(local_path)
    if target_dir:  # a bare file name has no directory to create
        os.makedirs(target_dir, exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar, \
                open(local_path, "wb") as f:
            for data in r.iter_content(chunk_size=chunk_size):
                if data:
                    f.write(data)
                    # Advance by what was actually written; a chunk (typically
                    # the last one) can be shorter than chunk_size.
                    pbar.update(len(data))
34
+
35
+
36
def md5_hash(path):
    """Return the hexadecimal MD5 digest of the file at ``path``."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # Feed the file to the hash in 1 MiB pieces.
        for piece in iter(lambda: f.read(1 << 20), b""):
            digest.update(piece)
    return digest.hexdigest()
40
+
41
+
42
def get_ckpt_path(name, root, check=False):
    """Return the local path of checkpoint ``name`` under ``root``, downloading it if needed.

    The file is fetched when it is missing, or when ``check`` is True and its
    MD5 digest differs from the expected one. A fresh download is verified
    against the expected digest.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])

    needs_download = not os.path.exists(path)
    if not needs_download and check:
        needs_download = md5_hash(path) != MD5_MAP[name]

    if needs_download:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
51
+
52
+
53
class LPIPS(nn.Module):
    """Learned perceptual metric computed on five VGG16 feature stages.

    All parameters are frozen after the pretrained weights are loaded.
    """

    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 feature widths
        self.net = vgg16(pretrained=True, requires_grad=False)
        # One 1x1 linear head per feature stage, registered as lin0 ... lin4.
        for idx, width in enumerate(self.chns):
            setattr(self, f"lin{idx}", NetLinLayer(width, use_dropout=use_dropout))
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        """Load checkpoint ``name`` from the ``cache`` directory next to this file."""
        cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache")
        ckpt = get_ckpt_path(name, cache_dir)
        state = torch.load(ckpt, map_location=torch.device("cpu"))
        self.load_state_dict(state, strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        """Build a model and load checkpoint ``name``; only "vgg_lpips" is supported."""
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache")
        ckpt = get_ckpt_path(name, cache_dir)
        state = torch.load(ckpt, map_location=torch.device("cpu"))
        model.load_state_dict(state, strict=False)
        return model

    def forward(self, input, target):
        """Return the summed per-stage distance between ``input`` and ``target``."""
        feats_in = self.net(self.scaling_layer(input))
        feats_tg = self.net(self.scaling_layer(target))
        heads = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]

        per_stage = []
        for idx, head in enumerate(heads):
            # Squared difference of channel-normalized features, projected by
            # the stage's 1x1 head and averaged over the spatial dimensions.
            diff = (normalize_tensor(feats_in[idx]) - normalize_tensor(feats_tg[idx])) ** 2
            per_stage.append(spatial_average(head.model(diff), keepdim=True))

        val = per_stage[0]
        for term in per_stage[1:]:
            val += term
        return val
97
+
98
+
99
class ScalingLayer(nn.Module):
    """Shift and scale a 3-channel input with fixed per-channel constants."""

    def __init__(self):
        super().__init__()
        shift = torch.Tensor([-.030, -.088, -.188])
        scale = torch.Tensor([.458, .448, .450])
        # Stored as (1, 3, 1, 1) buffers so they broadcast over batch and space.
        self.register_buffer('shift', shift.view(1, -1, 1, 1))
        self.register_buffer('scale', scale.view(1, -1, 1, 1))

    def forward(self, inp):
        centered = inp - self.shift
        return centered / self.scale
107
+
108
+
109
class NetLinLayer(nn.Module):
    """A bias-free 1x1 convolution, optionally preceded by dropout."""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        layers = []
        if use_dropout:
            layers.append(nn.Dropout())
        layers.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*layers)
116
+
117
+
118
class vgg16(torch.nn.Module):
    """torchvision VGG16 feature layers 0-29 split into five sequential slices."""

    def __init__(self, requires_grad=False, pretrained=True):
        super().__init__()
        features = models.vgg16(pretrained=pretrained).features
        self.N_slices = 5
        # Slice k holds feature layers bounds[k-1] .. bounds[k] - 1, each
        # registered under its original index.
        bounds = [0, 4, 9, 16, 23, 30]
        for k, (lo, hi) in enumerate(zip(bounds, bounds[1:]), start=1):
            block = torch.nn.Sequential()
            for idx in range(lo, hi):
                block.add_module(str(idx), features[idx])
            setattr(self, f"slice{k}", block)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the output of each slice as a ``VggOutputs`` namedtuple."""
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        taps = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            taps.append(h)
        return vgg_outputs(*taps)
156
+
157
+
158
def normalize_tensor(x,eps=1e-10):
    """Divide ``x`` by its L2 norm over dim 1; ``eps`` is added to the norm."""
    squared_sum = (x ** 2).sum(dim=1, keepdim=True)
    return x / (squared_sum.sqrt() + eps)
161
+
162
+
163
def spatial_average(x, keepdim=True):
    """Return the mean of ``x`` over dims 2 and 3."""
    return torch.mean(x, dim=(2, 3), keepdim=keepdim)
tokenizer/tokenizer_image/reconstruction_vq_ddp.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ torch.backends.cuda.matmul.allow_tf32 = True
3
+ torch.backends.cudnn.allow_tf32 = True
4
+ import torch.nn.functional as F
5
+ import torch.distributed as dist
6
+ from torch.utils.data import DataLoader
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ from torchvision import transforms
9
+ from tqdm import tqdm
10
+ import os
11
+ from PIL import Image
12
+ import numpy as np
13
+ import argparse
14
+ import itertools
15
+
16
+ from skimage.metrics import peak_signal_noise_ratio as psnr_loss
17
+ from skimage.metrics import structural_similarity as ssim_loss
18
+ from dataset.augmentation import center_crop_arr
19
+ from dataset.build import build_dataset
20
+ from tokenizer.tokenizer_image.vq_model import VQ_models
21
+
22
+
23
+
24
def create_npz_from_sample_folder(sample_dir, num=50000):
    """
    Builds a single .npz file from a folder of .png samples.

    Reads ``{sample_dir}/000000.png`` ... for ``num`` consecutive indices,
    stacks them into one uint8 array and saves it as ``{sample_dir}.npz``
    under the key ``arr_0``.

    Args:
        sample_dir: folder holding the zero-padded, six-digit .png files.
        num: number of samples to read, starting at index 0.

    Returns:
        Path of the written .npz file.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        # Release each file handle as soon as its pixels are copied out, so a
        # long run does not depend on garbage collection to close them.
        with Image.open(f"{sample_dir}/{i:06d}.png") as sample_pil:
            sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    # Every sample must be a 3-channel image of one common size.
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
39
+
40
+
41
+
42
def main(args):
    """Reconstruct condition images with a VQ tokenizer under DDP and report PSNR/SSIM.

    Every rank encodes and decodes its shard of the dataset and scores each
    reconstruction against its input. The scores are gathered; rank 0 averages
    them, prints them and writes them to ``<sample_folder_dir>_results.txt``.

    Args:
        args: parsed command-line namespace. Fields read here: vq_model,
            codebook_size, codebook_embed_dim, vq_ckpt, dataset, image_size,
            image_size_eval, sample_dir, per_proc_batch_size, num_workers and
            global_seed. The whole namespace is also handed to build_dataset.

    Raises:
        Exception: if the checkpoint has none of the keys "ema", "model",
            "state_dict", or if ``args.dataset`` is not a handled name.
    """
    # Setup PyTorch:
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    if "ema" in checkpoint:  # ema
        model_weight = checkpoint["ema"]
    elif "model" in checkpoint:  # ddp
        model_weight = checkpoint["model"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight")
    vq_model.load_state_dict(model_weight)
    del checkpoint

    # Create folder to save samples:
    folder_name = (f"{args.vq_model}-{args.dataset}-size-{args.image_size}-size-{args.image_size_eval}"
                   f"-codebook-size-{args.codebook_size}-dim-{args.codebook_embed_dim}-seed-{args.global_seed}")
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

    if args.dataset == 'imagenet':
        dataset = build_dataset(args, transform=transform)
        num_fid_samples = 50000
    elif args.dataset == 'coco':
        dataset = build_dataset(args, transform=transform)
        num_fid_samples = 5000
    elif args.dataset == 'imagenet_code':
        dataset = build_dataset(args)
        num_fid_samples = 50000
    else:
        raise Exception("please check dataset")

    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=False,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=args.per_proc_batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False
    )

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()

    psnr_val_rgb = []
    ssim_val_rgb = []
    loader = tqdm(loader) if rank == 0 else loader
    total = 0
    for batch in loader:
        # The condition image is tiled three times along dim 1 before encoding.
        # NOTE(review): presumably it is single-channel — confirm against the dataset.
        x = batch['condition_imgs'].repeat(1,3,1,1)
        if args.image_size_eval != args.image_size:
            rgb_gts = F.interpolate(x, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
        else:
            rgb_gts = x
        rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1]
        x = x.to(device, non_blocking=True)
        with torch.no_grad():
            latent, _, [_, _, indices] = vq_model.encode(x.float())
            # A leftover pdb breakpoint used to sit here; it would stop every
            # rank on the first batch, so it has been removed.
            samples = vq_model.decode_code(indices, latent.shape) # output value is between [-1, 1]
            if args.image_size_eval != args.image_size:
                samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)):
            # Global sample index; it named the per-sample .png file.
            index = i * dist.get_world_size() + rank + total
            # NOTE(review): saving is currently disabled:
            # Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
            # metric
            rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1]
            psnr = psnr_loss(rgb_restored, rgb_gt)
            # NOTE(review): both images are scaled to [0, 1] while data_range
            # is 2.0; left unchanged so reported numbers stay comparable — confirm.
            ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1)
            psnr_val_rgb.append(psnr)
            ssim_val_rgb.append(ssim)

        total += global_batch_size

    # ------------------------------------
    #       Summary
    # ------------------------------------
    # Make sure all processes have finished their shard
    dist.barrier()
    world_size = dist.get_world_size()
    gather_psnr_val = [None for _ in range(world_size)]
    gather_ssim_val = [None for _ in range(world_size)]
    dist.all_gather_object(gather_psnr_val, psnr_val_rgb)
    dist.all_gather_object(gather_ssim_val, ssim_val_rgb)

    if rank == 0:
        gather_psnr_val = list(itertools.chain(*gather_psnr_val))
        gather_ssim_val = list(itertools.chain(*gather_ssim_val))
        psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val)
        ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val)
        print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb))

        result_file = f"{sample_folder_dir}_results.txt"
        print("writing results to {}".format(result_file))
        with open(result_file, 'w') as f:
            print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f)

        # NOTE(review): this reads {sample_folder_dir}/{i:06d}.png, but the
        # per-sample save in the loop is disabled — confirm which is intended.
        create_npz_from_sample_folder(sample_folder_dir, num_fid_samples)
        print("Done.")

    dist.barrier()
    dist.destroy_process_group()
187
+
188
+
189
def str2bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` treats every non-empty string as True, so the literal text
    "False" would switch the option on. This converter maps the usual
    spellings explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays False, as it was under type=bool.
    if lowered in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default=None)
    parser.add_argument("--code-path", type=str, required=True)
    parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco', 'imagenet_code'], default='imagenet')
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256)
    parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
    parser.add_argument("--sample-dir", type=str, default="reconstructions")
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--condition", type=str, choices=['canny', 'hed'], default='canny')
    parser.add_argument("--get-condition-img", type=str2bool, default=False)
    args = parser.parse_args()
    main(args)
tokenizer/tokenizer_image/vq_demo.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ import os
5
+ import argparse
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ from tokenizer.tokenizer_image.vq_model import VQ_models
10
+ from dataset.augmentation import center_crop_arr
11
+
12
+
13
def main(args):
    """Reconstruct one image with a VQ tokenizer and save the result.

    Builds the VQ model selected by ``args.vq_model``, restores its weights
    from ``args.vq_ckpt``, encodes the image at ``args.image_path`` into
    codebook indices, decodes those indices back to pixels and writes the
    reconstruction into ``args.output_dir`` as
    ``<image stem>_<suffix><image extension>``.

    Args:
        args: Parsed command-line namespace. Attributes read here: ``seed``,
            ``vq_model``, ``codebook_size``, ``codebook_embed_dim``,
            ``vq_ckpt``, ``output_dir``, ``image_path``, ``suffix`` and
            ``image_size``.

    Raises:
        ValueError: If ``args.vq_ckpt`` is not provided.
        Exception: If the checkpoint contains none of the keys ``"ema"``,
            ``"model"`` or ``"state_dict"``.
    """
    # Fail early with a clear message instead of letting torch.load(None)
    # raise an obscure error after the model has already been built.
    if args.vq_ckpt is None:
        raise ValueError("--vq-ckpt is required: path to the VQ model checkpoint")

    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model
    model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    model.to(device)
    model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    # The weights may be stored under different keys depending on how the
    # checkpoint was written; the first matching key wins.
    if "ema" in checkpoint:  # ema
        model_weight = checkpoint["ema"]
    elif "model" in checkpoint:  # ddp
        model_weight = checkpoint["model"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight")
    model.load_state_dict(model_weight)
    del checkpoint

    # output dir: keep the input file name, inserting "_<suffix>" before the
    # extension. splitext handles any extension (and upper-case ones), not
    # only .jpg/.jpeg/.png.
    os.makedirs(args.output_dir, exist_ok=True)
    stem, ext = os.path.splitext(os.path.basename(args.image_path))
    out_path = os.path.join(args.output_dir, "{}_{}{}".format(stem, args.suffix, ext))

    # load image
    pil_image = Image.open(args.image_path).convert("RGB")
    # center_crop_arr is a project helper; presumably it yields a square
    # image with side args.image_size -- confirm against dataset.augmentation.
    img = center_crop_arr(pil_image, args.image_size)
    img = np.array(img) / 255.
    x = 2.0 * img - 1.0  # x value is between [-1, 1]
    x = torch.tensor(x)
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)
    # Use the device chosen above so the demo also runs on CPU-only hosts
    # (previously hard-coded to "cuda").
    x_input = x.float().to(device)

    # inference
    with torch.no_grad():
        latent, _, [_, _, indices] = model.encode(x_input)
        output = model.decode_code(indices, latent.shape)  # output value is between [-1, 1]

    # postprocess: resize to the requested size, move channels last and map
    # back to uint8 pixel values.
    output = F.interpolate(output, size=[args.image_size, args.image_size], mode='bicubic').permute(0, 2, 3, 1)[0]
    sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()

    # save
    Image.fromarray(sample).save(out_path)
    print("Reconstructed image is saved to {}".format(out_path))
70
+
71
+
72
if __name__ == "__main__":
    # Command-line options for the single-image VQ reconstruction demo,
    # declared as (flag, add_argument keyword arguments) pairs and
    # registered in the order listed.
    option_table = (
        ("--image-path", dict(type=str, default="assets/example.jpg")),
        ("--output-dir", dict(type=str, default="output_vq_demo")),
        ("--suffix", dict(type=str, default="tokenizer_image")),
        ("--vq-model", dict(type=str, choices=list(VQ_models.keys()), default="VQ-16")),
        ("--vq-ckpt", dict(type=str, default=None, help="ckpt path for vq model")),
        ("--codebook-size", dict(type=int, default=16384, help="codebook size for vector quantization")),
        ("--codebook-embed-dim", dict(type=int, default=8, help="codebook dimension for vector quantization")),
        ("--image-size", dict(type=int, choices=[256, 384, 448, 512, 1024], default=512)),
        ("--seed", dict(type=int, default=0)),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    args = cli.parse_args()
    main(args)