customdiffusion360 committed
Commit 8d3da67 (parent: 226ece7)

reduce memory usage, use github code

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitignore +3 -0
  2. Dockerfile +2 -0
  3. app.py +2 -5
  4. configs/train_co3d_concept.yaml +0 -198
  5. sampling_for_demo.py +3 -4
  6. sgm/__init__.py +0 -4
  7. sgm/data/__init__.py +0 -1
  8. sgm/data/data_co3d.py +0 -762
  9. sgm/lr_scheduler.py +0 -135
  10. sgm/models/__init__.py +0 -2
  11. sgm/models/autoencoder.py +0 -335
  12. sgm/models/diffusion.py +0 -556
  13. sgm/modules/__init__.py +0 -6
  14. sgm/modules/attention.py +0 -1202
  15. sgm/modules/autoencoding/__init__.py +0 -0
  16. sgm/modules/autoencoding/lpips/__init__.py +0 -0
  17. sgm/modules/autoencoding/lpips/loss.py +0 -0
  18. sgm/modules/autoencoding/lpips/loss/LICENSE +0 -23
  19. sgm/modules/autoencoding/lpips/loss/__init__.py +0 -0
  20. sgm/modules/autoencoding/lpips/loss/lpips.py +0 -147
  21. sgm/modules/autoencoding/lpips/model/LICENSE +0 -58
  22. sgm/modules/autoencoding/lpips/model/__init__.py +0 -0
  23. sgm/modules/autoencoding/lpips/model/model.py +0 -88
  24. sgm/modules/autoencoding/lpips/util.py +0 -128
  25. sgm/modules/autoencoding/lpips/vqperceptual.py +0 -17
  26. sgm/modules/autoencoding/regularizers/__init__.py +0 -31
  27. sgm/modules/autoencoding/regularizers/base.py +0 -40
  28. sgm/modules/autoencoding/regularizers/quantize.py +0 -487
  29. sgm/modules/diffusionmodules/__init__.py +0 -0
  30. sgm/modules/diffusionmodules/denoiser.py +0 -79
  31. sgm/modules/diffusionmodules/denoiser_scaling.py +0 -41
  32. sgm/modules/diffusionmodules/denoiser_weighting.py +0 -24
  33. sgm/modules/diffusionmodules/discretizer.py +0 -69
  34. sgm/modules/diffusionmodules/guiders.py +0 -167
  35. sgm/modules/diffusionmodules/loss.py +0 -216
  36. sgm/modules/diffusionmodules/loss_weighting.py +0 -32
  37. sgm/modules/diffusionmodules/model.py +0 -748
  38. sgm/modules/diffusionmodules/openaimodel.py +0 -1352
  39. sgm/modules/diffusionmodules/sampling.py +0 -465
  40. sgm/modules/diffusionmodules/sampling_utils.py +0 -48
  41. sgm/modules/diffusionmodules/sigma_sampling.py +0 -54
  42. sgm/modules/diffusionmodules/util.py +0 -344
  43. sgm/modules/diffusionmodules/wrappers.py +0 -35
  44. sgm/modules/distributions/__init__.py +0 -0
  45. sgm/modules/distributions/distributions.py +0 -102
  46. sgm/modules/distributions/distributions1.py +0 -102
  47. sgm/modules/ema.py +0 -86
  48. sgm/modules/encoders/__init__.py +0 -0
  49. sgm/modules/encoders/modules.py +0 -1154
  50. sgm/modules/nerfsd_pytorch3d.py +0 -468
.gitignore CHANGED
@@ -3,3 +3,6 @@ __pycache__/
 
 # ignore sdxl
 pretrained-models/*.safetensors
+
+# ignore internal code repo
+custom-diffusion360/
Dockerfile CHANGED
@@ -32,6 +32,8 @@ RUN pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
 RUN wget https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P pretrained-models
 RUN wget https://huggingface.co/stabilityai/sdxl-vae/resolve/main/sdxl_vae.safetensors -P pretrained-models
 
+RUN git clone https://github.com/customdiffusion360/custom-diffusion360.git
+
 ENV GRADIO_SERVER_NAME=0.0.0.0
 
 ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "pose", "python", "app.py"]
app.py CHANGED
@@ -13,13 +13,10 @@ import sys
 # Mesh imports
 from pytorch3d.io import load_objs_as_meshes
 from pytorch3d.vis.plotly_vis import AxisArgs, plot_scene
-from pytorch3d.transforms import Transform3d, RotateAxisAngle, Translate, Rotate
+from pytorch3d.transforms import RotateAxisAngle, Translate
 
 from sampling_for_demo import load_and_return_model_and_data, sample, load_base_model
 
-# add current directory to path
-# sys.path.append(os.path.dirname(os.path.realpath(__file__)))
-
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
 
@@ -184,7 +181,7 @@ current_data = None
 current_model = None
 
 global base_model
-BASE_CONFIG = "configs/train_co3d_concept.yaml"
+BASE_CONFIG = "custom-diffusion360/configs/train_co3d_concept.yaml"
 BASE_CKPT = "pretrained-models/sd_xl_base_1.0.safetensors"
 
 start_time = time.time()
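app.py now resolves the training config from the cloned repository instead of a local configs/ copy (which is deleted below). The following is a hedged sketch of how such a path is typically consumed in sgm-based code, assuming the config is loaded with OmegaConf; the existence check is an illustrative addition, and load_base_model is the helper imported from sampling_for_demo.py above.

import os
from omegaconf import OmegaConf

BASE_CONFIG = "custom-diffusion360/configs/train_co3d_concept.yaml"
BASE_CKPT = "pretrained-models/sd_xl_base_1.0.safetensors"

# Fail early with a clear message if the repository was not cloned (illustrative guard).
if not os.path.exists(BASE_CONFIG):
    raise FileNotFoundError(
        f"{BASE_CONFIG} not found; clone https://github.com/customdiffusion360/custom-diffusion360 first."
    )

config = OmegaConf.load(BASE_CONFIG)
base_model = load_base_model(config, ckpt=BASE_CKPT)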
configs/train_co3d_concept.yaml DELETED
@@ -1,198 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-4
3
- target: sgm.models.diffusion.DiffusionEngine
4
- params:
5
- scale_factor: 0.13025
6
- disable_first_stage_autocast: True
7
- trainkeys: pose
8
- multiplier: 0.05
9
- loss_rgb_lambda: 5
10
- loss_fg_lambda: 10
11
- loss_bg_lambda: 10
12
- log_keys:
13
- - txt
14
-
15
- denoiser_config:
16
- target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
17
- params:
18
- num_idx: 1000
19
-
20
- weighting_config:
21
- target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
22
- scaling_config:
23
- target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
24
- discretization_config:
25
- target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
26
-
27
- network_config:
28
- target: sgm.modules.diffusionmodules.openaimodel.UNetModel
29
- params:
30
- adm_in_channels: 2816
31
- num_classes: sequential
32
- use_checkpoint: False
33
- in_channels: 4
34
- out_channels: 4
35
- model_channels: 320
36
- attention_resolutions: [4, 2]
37
- num_res_blocks: 2
38
- channel_mult: [1, 2, 4]
39
- num_head_channels: 64
40
- use_linear_in_transformer: True
41
- transformer_depth: [1, 2, 10]
42
- context_dim: 2048
43
- spatial_transformer_attn_type: softmax-xformers
44
- image_cross_blocks: [0, 2, 4, 6, 8, 10]
45
- rgb: True
46
- far: 2
47
- num_samples: 24
48
- not_add_context_in_triplane: False
49
- rgb_predict: True
50
- add_lora: False
51
- average: False
52
- use_prev_weights_imp_sample: True
53
- stratified: True
54
- imp_sampling_percent: 0.9
55
-
56
- conditioner_config:
57
- target: sgm.modules.GeneralConditioner
58
- params:
59
- emb_models:
60
- # crossattn cond
61
- - is_trainable: False
62
- input_keys: txt,txt_ref
63
- target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
64
- params:
65
- layer: hidden
66
- layer_idx: 11
67
- modifier_token: <new1>
68
- # crossattn and vector cond
69
- - is_trainable: False
70
- input_keys: txt,txt_ref
71
- target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
72
- params:
73
- arch: ViT-bigG-14
74
- version: laion2b_s39b_b160k
75
- layer: penultimate
76
- always_return_pooled: True
77
- legacy: False
78
- modifier_token: <new1>
79
- # vector cond
80
- - is_trainable: False
81
- input_keys: original_size_as_tuple,original_size_as_tuple_ref
82
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
83
- params:
84
- outdim: 256 # multiplied by two
85
- # vector cond
86
- - is_trainable: False
87
- input_keys: crop_coords_top_left,crop_coords_top_left_ref
88
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
89
- params:
90
- outdim: 256 # multiplied by two
91
- # vector cond
92
- - is_trainable: False
93
- input_keys: target_size_as_tuple,target_size_as_tuple_ref
94
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
95
- params:
96
- outdim: 256 # multiplied by two
97
-
98
- first_stage_config:
99
- target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
100
- params:
101
- ckpt_path: pretrained-models/sdxl_vae.safetensors
102
- embed_dim: 4
103
- monitor: val/rec_loss
104
- ddconfig:
105
- attn_type: vanilla-xformers
106
- double_z: true
107
- z_channels: 4
108
- resolution: 256
109
- in_channels: 3
110
- out_ch: 3
111
- ch: 128
112
- ch_mult: [1, 2, 4, 4]
113
- num_res_blocks: 2
114
- attn_resolutions: []
115
- dropout: 0.0
116
- lossconfig:
117
- target: torch.nn.Identity
118
-
119
- loss_fn_config:
120
- target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
121
- params:
122
- sigma_sampler_config:
123
- target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
124
- params:
125
- num_idx: 1000
126
- discretization_config:
127
- target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
128
- sigma_sampler_config_ref:
129
- target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
130
- params:
131
- num_idx: 50
132
-
133
- discretization_config:
134
- target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
135
-
136
- sampler_config:
137
- target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
138
- params:
139
- num_steps: 50
140
-
141
- discretization_config:
142
- target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
143
-
144
- guider_config:
145
- target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
146
- params:
147
- scale: 7.5
148
-
149
- data:
150
- target: sgm.data.data_co3d.CustomDataDictLoader
151
- params:
152
- batch_size: 1
153
- num_workers: 4
154
- category: teddybear
155
- img_size: 512
156
- skip: 2
157
- num_images: 5
158
- mask_images: True
159
- single_id: 0
160
- bbox: True
161
- addreg: True
162
- drop_ratio: 0.25
163
- drop_txt: 0.1
164
- modifier_token: <new1>
165
-
166
- lightning:
167
- modelcheckpoint:
168
- params:
169
- every_n_train_steps: 1600
170
- save_top_k: -1
171
- save_on_train_epoch_end: False
172
-
173
- callbacks:
174
- metrics_over_trainsteps_checkpoint:
175
- params:
176
- every_n_train_steps: 25000
177
-
178
- image_logger:
179
- target: main.ImageLogger
180
- params:
181
- disabled: False
182
- enable_autocast: False
183
- batch_frequency: 5000
184
- max_images: 8
185
- increase_log_steps: False
186
- log_first_step: False
187
- log_images_kwargs:
188
- use_ema_scope: False
189
- N: 1
190
- n_rows: 2
191
-
192
- trainer:
193
- devices: 0,1,2,3
194
- benchmark: True
195
- num_sanity_val_steps: 0
196
- accumulate_grad_batches: 1
197
- max_steps: 1610
198
- # val_check_interval: 400
sampling_for_demo.py CHANGED
@@ -14,7 +14,7 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 
 import json
 
-sys.path.append('./')
+sys.path.append('./custom-diffusion360/')
 from sgm.util import instantiate_from_config, load_safetensors
 
 choices = []
@@ -49,7 +49,6 @@ def load_base_model(config, ckpt=None, verbose=True):
 
     m, u = model.load_state_dict(sd, strict=False)
 
-    model.cuda()
     model.eval()
     return model
 
@@ -84,7 +83,6 @@ def load_delta_model(model, delta_ckpt=None, verbose=True, freeze=True):
     for param in model.parameters():
         param.requires_grad = False
 
-    model.cuda()
     model.eval()
     return model, msg
 
@@ -290,7 +288,7 @@ def process_camera_json(camera_json, example_cam):
 
 
 def load_and_return_model_and_data(config, model,
-                                   ckpt="/data/gdsu/customization3d/stable-diffusion-xl-base-1.0/sd_xl_base_1.0.safetensors",
+                                   ckpt="pretrained-models/sd_xl_base_1.0.safetensors",
                                    delta_ckpt=None,
                                    train=False,
                                    valid=False,
@@ -318,6 +316,7 @@ def load_and_return_model_and_data(config, model,
     # print(f"Total images in dataset: {total_images}")
 
     model, msg = load_delta_model(model, delta_ckpt,)
+    model = model.cuda()
 
     # change forward methods to store rendered features and use the pre-calculated reference features
    def register_recr(net_):
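The memory saving in this commit comes from the load order shown above: model.cuda() is removed from both load_base_model and load_delta_model, and the pipeline is moved to the GPU only once, after the delta (per-object customization) checkpoint has been merged, so the base weights are loaded and patched in CPU RAM first. Below is a minimal sketch of that ordering, assuming the load_base_model and load_delta_model helpers from sampling_for_demo.py above; the wrapper function name is illustrative, not part of the repository.

import torch

def build_pipeline(config, ckpt, delta_ckpt):
    # Load SDXL weights and merge the customization delta on CPU first ...
    model = load_base_model(config, ckpt=ckpt)
    model, _ = load_delta_model(model, delta_ckpt)
    # ... then do a single CPU-to-GPU transfer at the end, instead of one
    # transfer inside each loader as before this commit.
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()
    return model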
sgm/__init__.py DELETED
@@ -1,4 +0,0 @@
-from .models import AutoencodingEngine, DiffusionEngine
-from .util import get_configs_path, instantiate_from_config
-
-__version__ = "0.1.0"
sgm/data/__init__.py DELETED
@@ -1 +0,0 @@
-# from .dataset import StableDataModuleFromConfig
sgm/data/data_co3d.py DELETED
@@ -1,762 +0,0 @@
1
- # code taken and modified from https://github.com/amyxlase/relpose-plus-plus/blob/b33f7d5000cf2430bfcda6466c8e89bc2dcde43f/relpose/dataset/co3d_v2.py#L346)
2
- import os.path as osp
3
- import random
4
-
5
- import numpy as np
6
- import torch
7
- import pytorch_lightning as pl
8
-
9
- from PIL import Image, ImageFile
10
- import json
11
- import gzip
12
- from torch.utils.data import DataLoader, Dataset
13
- from torchvision import transforms
14
- from pytorch3d.renderer.cameras import PerspectiveCameras
15
- from pytorch3d.renderer.camera_utils import join_cameras_as_batch
16
- from pytorch3d.implicitron.dataset.utils import adjust_camera_to_bbox_crop_, adjust_camera_to_image_scale_
17
- from pytorch3d.transforms import Rotate, Translate
18
-
19
-
20
- CO3D_DIR = "data/training/"
21
-
22
- Image.MAX_IMAGE_PIXELS = None
23
- ImageFile.LOAD_TRUNCATED_IMAGES = True
24
-
25
-
26
- # Added: normalize camera poses
27
- def intersect_skew_line_groups(p, r, mask):
28
- # p, r both of shape (B, N, n_intersected_lines, 3)
29
- # mask of shape (B, N, n_intersected_lines)
30
- p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
31
- _, p_line_intersect = _point_line_distance(
32
- p, r, p_intersect[..., None, :].expand_as(p)
33
- )
34
- intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
35
- dim=-1
36
- )
37
- return p_intersect, p_line_intersect, intersect_dist_squared, r
38
-
39
-
40
- def intersect_skew_lines_high_dim(p, r, mask=None):
41
- # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
42
- dim = p.shape[-1]
43
- # make sure the heading vectors are l2-normed
44
- if mask is None:
45
- mask = torch.ones_like(p[..., 0])
46
- r = torch.nn.functional.normalize(r, dim=-1)
47
-
48
- eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
49
- I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
50
- sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
51
- p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
52
-
53
- if torch.any(torch.isnan(p_intersect)):
54
- print(p_intersect)
55
- assert False
56
- return p_intersect, r
57
-
58
-
59
- def _point_line_distance(p1, r1, p2):
60
- df = p2 - p1
61
- proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
62
- line_pt_nearest = p2 - proj_vector
63
- d = (proj_vector).norm(dim=-1)
64
- return d, line_pt_nearest
65
-
66
-
67
- def compute_optical_axis_intersection(cameras):
68
- centers = cameras.get_camera_center()
69
- principal_points = cameras.principal_point
70
-
71
- one_vec = torch.ones((len(cameras), 1))
72
- optical_axis = torch.cat((principal_points, one_vec), -1)
73
-
74
- pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
75
-
76
- pp2 = torch.zeros((pp.shape[0], 3))
77
- for i in range(0, pp.shape[0]):
78
- pp2[i] = pp[i][i]
79
-
80
- directions = pp2 - centers
81
- centers = centers.unsqueeze(0).unsqueeze(0)
82
- directions = directions.unsqueeze(0).unsqueeze(0)
83
-
84
- p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
85
- p=centers, r=directions, mask=None
86
- )
87
-
88
- p_intersect = p_intersect.squeeze().unsqueeze(0)
89
- dist = (p_intersect - centers).norm(dim=-1)
90
-
91
- return p_intersect, dist, p_line_intersect, pp2, r
92
-
93
-
94
- def normalize_cameras(cameras, scale=1.0):
95
- """
96
- Normalizes cameras such that the optical axes point to the origin and the average
97
- distance to the origin is 1.
98
-
99
- Args:
100
- cameras (List[camera]).
101
- """
102
-
103
- # Let distance from first camera to origin be unit
104
- new_cameras = cameras.clone()
105
- new_transform = new_cameras.get_world_to_view_transform()
106
-
107
- p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
108
- cameras
109
- )
110
- t = Translate(p_intersect)
111
-
112
- # scale = dist.squeeze()[0]
113
- scale = max(dist.squeeze())
114
-
115
- # Degenerate case
116
- if scale == 0:
117
- print(cameras.T)
118
- print(new_transform.get_matrix()[:, 3, :3])
119
- return -1
120
- assert scale != 0
121
-
122
- new_transform = t.compose(new_transform)
123
- new_cameras.R = new_transform.get_matrix()[:, :3, :3]
124
- new_cameras.T = new_transform.get_matrix()[:, 3, :3] / scale
125
- return new_cameras, p_intersect, p_line_intersect, pp, r
126
-
127
-
128
- def centerandalign(cameras, scale=1.0):
129
- """
130
- Normalizes cameras such that the optical axes point to the origin and the average
131
- distance to the origin is 1.
132
-
133
- Args:
134
- cameras (List[camera]).
135
- """
136
-
137
- # Let distance from first camera to origin be unit
138
- new_cameras = cameras.clone()
139
- new_transform = new_cameras.get_world_to_view_transform()
140
-
141
- p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
142
- cameras
143
- )
144
- t = Translate(p_intersect)
145
-
146
- centers = [cam.get_camera_center() for cam in new_cameras]
147
- centers = torch.concat(centers, 0).cpu().numpy()
148
- m = len(cameras)
149
-
150
- # https://math.stackexchange.com/questions/99299/best-fitting-plane-given-a-set-of-points
151
- A = np.hstack((centers[:m, :2], np.ones((m, 1))))
152
- B = centers[:m, 2:]
153
- if A.shape[0] == 2:
154
- x = A.T @ np.linalg.inv(A @ A.T) @ B
155
- else:
156
- x = np.linalg.inv(A.T @ A) @ A.T @ B
157
- a, b, c = x.flatten()
158
- n = np.array([a, b, 1])
159
- n /= np.linalg.norm(n)
160
-
161
- # https://math.stackexchange.com/questions/180418/calculate-rotation-matrix-to-align-vector-a-to-vector-b-in-3d
162
- v = np.cross(n, [0, 1, 0])
163
- s = np.linalg.norm(v)
164
- c = np.dot(n, [0, 1, 0])
165
- V = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
166
- rot = torch.from_numpy(np.eye(3) + V + V @ V * (1 - c) / s**2).float()
167
-
168
- scale = dist.squeeze()[0]
169
-
170
- # Degenerate case
171
- if scale == 0:
172
- print(cameras.T)
173
- print(new_transform.get_matrix()[:, 3, :3])
174
- return -1
175
- assert scale != 0
176
-
177
- rot = Rotate(rot.T)
178
-
179
- new_transform = rot.compose(t).compose(new_transform)
180
- new_cameras.R = new_transform.get_matrix()[:, :3, :3]
181
- new_cameras.T = new_transform.get_matrix()[:, 3, :3] / scale
182
- return new_cameras
183
-
184
-
185
- def square_bbox(bbox, padding=0.0, astype=None):
186
- """
187
- Computes a square bounding box, with optional padding parameters.
188
-
189
- Args:
190
- bbox: Bounding box in xyxy format (4,).
191
-
192
- Returns:
193
- square_bbox in xyxy format (4,).
194
- """
195
- if astype is None:
196
- astype = type(bbox[0])
197
- bbox = np.array(bbox)
198
- center = ((bbox[:2] + bbox[2:]) / 2).round().astype(int)
199
- extents = (bbox[2:] - bbox[:2]) / 2
200
- s = (max(extents) * (1 + padding)).round().astype(int)
201
- square_bbox = np.array(
202
- [center[0] - s, center[1] - s, center[0] + s, center[1] + s],
203
- dtype=astype,
204
- )
205
-
206
- return square_bbox
207
-
208
-
209
- class Co3dDataset(Dataset):
210
- def __init__(
211
- self,
212
- category,
213
- split="train",
214
- skip=2,
215
- img_size=1024,
216
- num_images=4,
217
- mask_images=False,
218
- single_id=0,
219
- bbox=False,
220
- modifier_token=None,
221
- addreg=False,
222
- drop_ratio=0.5,
223
- drop_txt=0.1,
224
- categoryname=None,
225
- aligncameras=False,
226
- repeat=100,
227
- addlen=False,
228
- onlyref=False,
229
- ):
230
- """
231
- Args:
232
- category (iterable): List of categories to use. If "all" is in the list,
233
- all training categories are used.
234
- num_images (int): Default number of images in each batch.
235
- normalize_cameras (bool): If True, normalizes cameras so that the
236
- intersection of the optical axes is placed at the origin and the norm
237
- of the first camera translation is 1.
238
- mask_images (bool): If True, masks out the background of the images.
239
- """
240
- # category = CATEGORIES
241
- category = sorted(category.split(','))
242
- self.category = category
243
- self.single_id = single_id
244
- self.addlen = addlen
245
- self.onlyref = onlyref
246
- self.categoryname = categoryname
247
- self.bbox = bbox
248
- self.modifier_token = modifier_token
249
- self.addreg = addreg
250
- self.drop_txt = drop_txt
251
- self.skip = skip
252
- if self.addreg:
253
- with open(f'data/regularization/{category[0]}_sp_generated/caption.txt', "r") as f:
254
- self.regcaptions = f.read().splitlines()
255
- self.reglen = len(self.regcaptions)
256
- self.regimpath = f'data/regularization/{category[0]}_sp_generated'
257
-
258
- self.low_quality_translations = []
259
- self.rotations = {}
260
- self.category_map = {}
261
- co3d_dir = CO3D_DIR
262
- for c in category:
263
- subset = 'fewview_dev'
264
- category_dir = osp.join(co3d_dir, c)
265
- frame_file = osp.join(category_dir, "frame_annotations.jgz")
266
- sequence_file = osp.join(category_dir, "sequence_annotations.jgz")
267
- subset_lists_file = osp.join(category_dir, f"set_lists/set_lists_{subset}.json")
268
- bbox_file = osp.join(category_dir, f"{c}_bbox.jgz")
269
-
270
- with open(subset_lists_file) as f:
271
- subset_lists_data = json.load(f)
272
-
273
- with gzip.open(sequence_file, "r") as fin:
274
- sequence_data = json.loads(fin.read())
275
-
276
- with gzip.open(bbox_file, "r") as fin:
277
- bbox_data = json.loads(fin.read())
278
-
279
- with gzip.open(frame_file, "r") as fin:
280
- frame_data = json.loads(fin.read())
281
-
282
- frame_data_processed = {}
283
- for f_data in frame_data:
284
- sequence_name = f_data["sequence_name"]
285
- if sequence_name not in frame_data_processed:
286
- frame_data_processed[sequence_name] = {}
287
- frame_data_processed[sequence_name][f_data["frame_number"]] = f_data
288
-
289
- good_quality_sequences = set()
290
- for seq_data in sequence_data:
291
- if seq_data["viewpoint_quality_score"] > 0.5:
292
- good_quality_sequences.add(seq_data["sequence_name"])
293
-
294
- for subset in ["train"]:
295
- for seq_name, frame_number, filepath in subset_lists_data[subset]:
296
- if seq_name not in good_quality_sequences:
297
- continue
298
-
299
- if seq_name not in self.rotations:
300
- self.rotations[seq_name] = []
301
- self.category_map[seq_name] = c
302
-
303
- mask_path = filepath.replace("images", "masks").replace(".jpg", ".png")
304
-
305
- frame_data = frame_data_processed[seq_name][frame_number]
306
-
307
- self.rotations[seq_name].append(
308
- {
309
- "filepath": filepath,
310
- "R": frame_data["viewpoint"]["R"],
311
- "T": frame_data["viewpoint"]["T"],
312
- "focal_length": frame_data["viewpoint"]["focal_length"],
313
- "principal_point": frame_data["viewpoint"]["principal_point"],
314
- "mask": mask_path,
315
- "txt": "a car",
316
- "bbox": bbox_data[mask_path]
317
- }
318
- )
319
-
320
- for seq_name in self.rotations:
321
- seq_data = self.rotations[seq_name]
322
- cameras = PerspectiveCameras(
323
- focal_length=[data["focal_length"] for data in seq_data],
324
- principal_point=[data["principal_point"] for data in seq_data],
325
- R=[data["R"] for data in seq_data],
326
- T=[data["T"] for data in seq_data],
327
- )
328
-
329
- normalized_cameras, _, _, _, _ = normalize_cameras(cameras)
330
- if aligncameras:
331
- normalized_cameras = centerandalign(cameras)
332
-
333
- if normalized_cameras == -1:
334
- print("Error in normalizing cameras: camera scale was 0")
335
- del self.rotations[seq_name]
336
- continue
337
-
338
- for i, data in enumerate(seq_data):
339
- self.rotations[seq_name][i]["R"] = normalized_cameras.R[i]
340
- self.rotations[seq_name][i]["T"] = normalized_cameras.T[i]
341
- self.rotations[seq_name][i]["R_original"] = torch.from_numpy(np.array(seq_data[i]["R"]))
342
- self.rotations[seq_name][i]["T_original"] = torch.from_numpy(np.array(seq_data[i]["T"]))
343
-
344
- # Make sure translations are not ridiculous
345
- if self.rotations[seq_name][i]["T"][0] + self.rotations[seq_name][i]["T"][1] + self.rotations[seq_name][i]["T"][2] > 1e5:
346
- bad_seq = True
347
- self.low_quality_translations.append(seq_name)
348
- break
349
-
350
- for seq_name in self.low_quality_translations:
351
- if seq_name in self.rotations:
352
- del self.rotations[seq_name]
353
-
354
- self.sequence_list = list(self.rotations.keys())
355
-
356
- self.transform = transforms.Compose(
357
- [
358
- transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
359
- transforms.ToTensor(),
360
- transforms.Lambda(lambda x: x * 2.0 - 1.0)
361
- ]
362
- )
363
- self.transformim = transforms.Compose(
364
- [
365
- transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
366
- transforms.CenterCrop(img_size),
367
- transforms.ToTensor(),
368
- transforms.Lambda(lambda x: x * 2.0 - 1.0)
369
- ]
370
- )
371
- self.transformmask = transforms.Compose(
372
- [
373
- transforms.Resize(img_size // 8),
374
- transforms.ToTensor(),
375
- ]
376
- )
377
-
378
- self.num_images = num_images
379
- self.image_size = img_size
380
- self.normalize_cameras = normalize_cameras
381
- self.mask_images = mask_images
382
- self.drop_ratio = drop_ratio
383
- self.kernel_tensor = torch.ones((1, 1, 7, 7))
384
- self.repeat = repeat
385
- print(self.sequence_list, "$$$$$$$$$$$$$$$$$$$$$")
386
- self.valid_ids = np.arange(0, len(self.rotations[self.sequence_list[self.single_id]]), skip).tolist()
387
- if split == 'test':
388
- self.valid_ids = list(set(np.arange(0, len(self.rotations[self.sequence_list[self.single_id]])).tolist()).difference(self.valid_ids))
389
-
390
- print(
391
- f"Low quality translation sequences, not used: {self.low_quality_translations}"
392
- )
393
- print(f"Data size: {len(self)}")
394
-
395
- def __len__(self):
396
- return (len(self.valid_ids))*self.repeat + (1 if self.addlen else 0)
397
-
398
- def _padded_bbox(self, bbox, w, h):
399
- if w < h:
400
- bbox = np.array([0, 0, w, h])
401
- else:
402
- bbox = np.array([0, 0, w, h])
403
- return square_bbox(bbox.astype(np.float32))
404
-
405
- def _crop_bbox(self, bbox, w, h):
406
- bbox = square_bbox(bbox.astype(np.float32))
407
-
408
- side_length = bbox[2] - bbox[0]
409
- center = (bbox[:2] + bbox[2:]) / 2
410
- extent = side_length / 2
411
-
412
- # Final coordinates need to be integer for cropping.
413
- ul = (center - extent).round().astype(int)
414
- lr = ul + np.round(2 * extent).astype(int)
415
- return np.concatenate((ul, lr))
416
-
417
- def _crop_image(self, image, bbox, white_bg=False):
418
- if white_bg:
419
- # Only support PIL Images
420
- image_crop = Image.new(
421
- "RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255)
422
- )
423
- image_crop.paste(image, (-bbox[0], -bbox[1]))
424
- else:
425
- image_crop = transforms.functional.crop(
426
- image,
427
- top=bbox[1],
428
- left=bbox[0],
429
- height=bbox[3] - bbox[1],
430
- width=bbox[2] - bbox[0],
431
- )
432
- return image_crop
433
-
434
- def __getitem__(self, index, specific_id=None, validation=False):
435
- sequence_name = self.sequence_list[self.single_id]
436
-
437
- metadata = self.rotations[sequence_name]
438
-
439
- if validation:
440
- drop_text = False
441
- drop_im = False
442
- else:
443
- drop_im = np.random.uniform(0, 1) < self.drop_ratio
444
- if not drop_im:
445
- drop_text = np.random.uniform(0, 1) < self.drop_txt
446
- else:
447
- drop_text = False
448
-
449
- size = self.image_size
450
-
451
- # sample reference ids
452
- listofindices = self.valid_ids.copy()
453
- max_diff = len(listofindices) // (self.num_images-1)
454
- if (index*self.skip) % len(metadata) in listofindices:
455
- listofindices.remove((index*self.skip) % len(metadata))
456
- references = np.random.choice(np.arange(0, len(listofindices)+1, max_diff), self.num_images-1, replace=False)
457
- rem = np.random.randint(0, max_diff)
458
- references = [listofindices[(x + rem) % len(listofindices)] for x in references]
459
- ids = [(index*self.skip) % len(metadata)] + references
460
-
461
- # special case to save features corresponding to ref image as part of model buffer
462
- if self.onlyref:
463
- ids = references + [(index*self.skip) % len(metadata)]
464
- if specific_id is not None: # remove this later
465
- ids = specific_id
466
-
467
- # get data
468
- batch = self.get_data(index=self.single_id, ids=ids)
469
-
470
- # text prompt
471
- if self.modifier_token is not None:
472
- name = self.category[0] if self.categoryname is None else self.categoryname
473
- batch['txt'] = [f'photo of a {self.modifier_token} {name}' for _ in range(len(batch['txt']))]
474
-
475
- # replace with regularization image if drop_im
476
- if drop_im and self.addreg:
477
- select_id = np.random.randint(0, self.reglen)
478
- batch["image"] = [self.transformim(Image.open(f'{self.regimpath}/images/{select_id}.png').convert('RGB'))]
479
- batch['txt'] = [self.regcaptions[select_id]]
480
- batch["original_size_as_tuple"] = torch.ones_like(batch["original_size_as_tuple"])*1024
481
-
482
- # create camera class and adjust intrinsics for crop
483
- cameras = [PerspectiveCameras(R=batch['R'][i].unsqueeze(0),
484
- T=batch['T'][i].unsqueeze(0),
485
- focal_length=batch['focal_lengths'][i].unsqueeze(0),
486
- principal_point=batch['principal_points'][i].unsqueeze(0),
487
- image_size=self.image_size
488
- )
489
- for i in range(len(ids))]
490
- for i, cam in enumerate(cameras):
491
- adjust_camera_to_bbox_crop_(cam, batch["original_size_as_tuple"][i, :2], batch["crop_coords"][i])
492
- adjust_camera_to_image_scale_(cam, batch["original_size_as_tuple"][i, 2:], torch.tensor([self.image_size, self.image_size]))
493
-
494
- # create mask and dilated mask for mask based losses
495
- batch["depth"] = batch["mask"].clone()
496
- batch["mask"] = torch.clamp(torch.nn.functional.conv2d(batch["mask"], self.kernel_tensor, padding='same'), 0, 1)
497
- if not self.mask_images:
498
- batch["mask"] = [None for i in range(len(ids))]
499
-
500
- # special case to save features corresponding to zero image
501
- if index == self.__len__()-1 and self.addlen:
502
- batch["image"][0] *= 0.
503
-
504
- return {"jpg": batch["image"][0],
505
- "txt": batch["txt"][0] if not drop_text else "",
506
- "jpg_ref": batch["image"][1:] if not drop_im else torch.stack([2*torch.rand_like(batch["image"][0])-1. for _ in range(len(ids)-1)], dim=0),
507
- "txt_ref": batch["txt"][1:] if not drop_im else ["" for _ in range(len(ids)-1)],
508
- "pose": cameras,
509
- "mask": batch["mask"][0] if not drop_im else torch.ones_like(batch["mask"][0]),
510
- "mask_ref": batch["masks_padding"][1:],
511
- "depth": batch["depth"][0] if len(batch["depth"]) > 0 else None,
512
- "filepaths": batch["filepaths"],
513
- "original_size_as_tuple": batch["original_size_as_tuple"][0][2:],
514
- "target_size_as_tuple": torch.ones_like(batch["original_size_as_tuple"][0][2:])*size,
515
- "crop_coords_top_left": torch.zeros_like(batch["crop_coords"][0][:2]),
516
- "original_size_as_tuple_ref": batch["original_size_as_tuple"][1:][:, 2:],
517
- "target_size_as_tuple_ref": torch.ones_like(batch["original_size_as_tuple"][1:][:, 2:])*size,
518
- "crop_coords_top_left_ref": torch.zeros_like(batch["crop_coords"][1:][:, :2]),
519
- "drop_im": torch.Tensor([1-drop_im*1.])
520
- }
521
-
522
- def get_data(self, index=None, sequence_name=None, ids=(0, 1)):
523
- if sequence_name is None:
524
- sequence_name = self.sequence_list[index]
525
- metadata = self.rotations[sequence_name]
526
- category = self.category_map[sequence_name]
527
- annos = [metadata[i] for i in ids]
528
- images = []
529
- rotations = []
530
- translations = []
531
- focal_lengths = []
532
- principal_points = []
533
- txts = []
534
- masks = []
535
- filepaths = []
536
- images_transformed = []
537
- masks_transformed = []
538
- original_size_as_tuple = []
539
- crop_parameters = []
540
- masks_padding = []
541
- depths = []
542
-
543
- for counter, anno in enumerate(annos):
544
- filepath = anno["filepath"]
545
- filepaths.append(filepath)
546
- image = Image.open(osp.join(CO3D_DIR, filepath)).convert("RGB")
547
-
548
- mask_name = osp.basename(filepath.replace(".jpg", ".png"))
549
-
550
- mask_path = osp.join(
551
- CO3D_DIR, category, sequence_name, "masks", mask_name
552
- )
553
- mask = Image.open(mask_path).convert("L")
554
-
555
- if mask.size != image.size:
556
- mask = mask.resize(image.size)
557
-
558
- mask_padded = Image.fromarray((np.ones_like(mask) > 0))
559
- mask = Image.fromarray((np.array(mask) > 125))
560
- masks.append(mask)
561
-
562
- # crop image around object
563
- w, h = image.width, image.height
564
- bbox = np.array(anno["bbox"])
565
- if len(bbox) == 0:
566
- bbox = np.array([0, 0, w, h])
567
-
568
- if self.bbox and counter > 0:
569
- bbox = self._crop_bbox(bbox, w, h)
570
- else:
571
- bbox = self._padded_bbox(None, w, h)
572
- image = self._crop_image(image, bbox)
573
- mask = self._crop_image(mask, bbox)
574
- mask_padded = self._crop_image(mask_padded, bbox)
575
- masks_padding.append(self.transformmask(mask_padded))
576
- images_transformed.append(self.transform(image))
577
- masks_transformed.append(self.transformmask(mask))
578
-
579
- crop_parameters.append(torch.tensor([bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1] ]).int())
580
- original_size_as_tuple.append(torch.tensor([w, h, bbox[2] - bbox[0], bbox[3] - bbox[1]]))
581
- images.append(image)
582
- rotations.append(anno["R"])
583
- translations.append(anno["T"])
584
- focal_lengths.append(torch.tensor(anno["focal_length"]))
585
- principal_points.append(torch.tensor(anno["principal_point"]))
586
- txts.append(anno["txt"])
587
-
588
- images = images_transformed
589
- batch = {
590
- "model_id": sequence_name,
591
- "category": category,
592
- "original_size_as_tuple": torch.stack(original_size_as_tuple),
593
- "crop_coords": torch.stack(crop_parameters),
594
- "n": len(metadata),
595
- "ind": torch.tensor(ids),
596
- "txt": txts,
597
- "filepaths": filepaths,
598
- "masks_padding": torch.stack(masks_padding) if len(masks_padding) > 0 else [],
599
- "depth": torch.stack(depths) if len(depths) > 0 else [],
600
- }
601
-
602
- batch["R"] = torch.stack(rotations)
603
- batch["T"] = torch.stack(translations)
604
- batch["focal_lengths"] = torch.stack(focal_lengths)
605
- batch["principal_points"] = torch.stack(principal_points)
606
-
607
- # Add images
608
- if self.transform is None:
609
- batch["image"] = images
610
- else:
611
- batch["image"] = torch.stack(images)
612
- batch["mask"] = torch.stack(masks_transformed)
613
-
614
- return batch
615
-
616
- @staticmethod
617
- def collate_fn(batch):
618
- """A function to collate the data across batches. This function must be passed to pytorch's DataLoader to collate batches.
619
- Args:
620
- batch(list): List of objects returned by this class' __getitem__ function. This is given by pytorch's dataloader that calls __getitem__
621
- multiple times and expects a collated batch.
622
- Returns:
623
- dict: The collated dictionary representing the data in the batch.
624
- """
625
- result = {
626
- "jpg": [],
627
- "txt": [],
628
- "jpg_ref": [],
629
- "txt_ref": [],
630
- "pose": [],
631
- "original_size_as_tuple": [],
632
- "original_size_as_tuple_ref": [],
633
- "crop_coords_top_left": [],
634
- "crop_coords_top_left_ref": [],
635
- "target_size_as_tuple_ref": [],
636
- "target_size_as_tuple": [],
637
- "drop_im": [],
638
- "mask_ref": [],
639
- }
640
- if batch[0]["mask"] is not None:
641
- result["mask"] = []
642
- if batch[0]["depth"] is not None:
643
- result["depth"] = []
644
-
645
- for batch_obj in batch:
646
- for key in result.keys():
647
- result[key].append(batch_obj[key])
648
- for key in result.keys():
649
- if not (key == 'pose' or 'txt' in key or 'size_as_tuple_ref' in key or 'coords_top_left_ref' in key):
650
- result[key] = torch.stack(result[key], dim=0)
651
- elif 'txt_ref' in key:
652
- result[key] = [item for sublist in result[key] for item in sublist]
653
- elif 'size_as_tuple_ref' in key or 'coords_top_left_ref' in key:
654
- result[key] = torch.cat(result[key], dim=0)
655
- elif 'pose' in key:
656
- result[key] = [join_cameras_as_batch(cameras) for cameras in result[key]]
657
-
658
- return result
659
-
660
-
661
- class CustomDataDictLoader(pl.LightningDataModule):
662
- def __init__(
663
- self,
664
- category,
665
- batch_size,
666
- mask_images=False,
667
- skip=1,
668
- img_size=1024,
669
- num_images=4,
670
- num_workers=0,
671
- shuffle=True,
672
- single_id=0,
673
- modifier_token=None,
674
- bbox=False,
675
- addreg=False,
676
- drop_ratio=0.5,
677
- jitter=False,
678
- drop_txt=0.1,
679
- categoryname=None,
680
- ):
681
- super().__init__()
682
-
683
- self.batch_size = batch_size
684
- self.num_workers = num_workers
685
- self.shuffle = shuffle
686
- self.train_dataset = Co3dDataset(category,
687
- img_size=img_size,
688
- mask_images=mask_images,
689
- skip=skip,
690
- num_images=num_images,
691
- single_id=single_id,
692
- modifier_token=modifier_token,
693
- bbox=bbox,
694
- addreg=addreg,
695
- drop_ratio=drop_ratio,
696
- drop_txt=drop_txt,
697
- categoryname=categoryname,
698
- )
699
- self.val_dataset = Co3dDataset(category,
700
- img_size=img_size,
701
- mask_images=mask_images,
702
- skip=skip,
703
- num_images=2,
704
- single_id=single_id,
705
- modifier_token=modifier_token,
706
- bbox=bbox,
707
- addreg=addreg,
708
- drop_ratio=0.,
709
- drop_txt=0.,
710
- categoryname=categoryname,
711
- repeat=1,
712
- addlen=True,
713
- onlyref=True,
714
- )
715
- self.test_dataset = Co3dDataset(category,
716
- img_size=img_size,
717
- mask_images=mask_images,
718
- split="test",
719
- skip=skip,
720
- num_images=2,
721
- single_id=single_id,
722
- modifier_token=modifier_token,
723
- bbox=False,
724
- addreg=addreg,
725
- drop_ratio=0.,
726
- drop_txt=0.,
727
- categoryname=categoryname,
728
- repeat=1,
729
- )
730
- self.collate_fn = Co3dDataset.collate_fn
731
-
732
- def prepare_data(self):
733
- pass
734
-
735
- def train_dataloader(self):
736
- return DataLoader(
737
- self.train_dataset,
738
- batch_size=self.batch_size,
739
- shuffle=self.shuffle,
740
- num_workers=self.num_workers,
741
- collate_fn=self.collate_fn,
742
- drop_last=True,
743
- )
744
-
745
- def test_dataloader(self):
746
- return DataLoader(
747
- self.train_dataset,
748
- batch_size=self.batch_size,
749
- shuffle=False,
750
- num_workers=self.num_workers,
751
- collate_fn=self.collate_fn,
752
- )
753
-
754
- def val_dataloader(self):
755
- return DataLoader(
756
- self.val_dataset,
757
- batch_size=self.batch_size,
758
- shuffle=False,
759
- num_workers=self.num_workers,
760
- collate_fn=self.collate_fn,
761
- drop_last=True
762
- )
sgm/lr_scheduler.py DELETED
@@ -1,135 +0,0 @@
1
- import numpy as np
2
-
3
-
4
- class LambdaWarmUpCosineScheduler:
5
- """
6
- note: use with a base_lr of 1.0
7
- """
8
-
9
- def __init__(
10
- self,
11
- warm_up_steps,
12
- lr_min,
13
- lr_max,
14
- lr_start,
15
- max_decay_steps,
16
- verbosity_interval=0,
17
- ):
18
- self.lr_warm_up_steps = warm_up_steps
19
- self.lr_start = lr_start
20
- self.lr_min = lr_min
21
- self.lr_max = lr_max
22
- self.lr_max_decay_steps = max_decay_steps
23
- self.last_lr = 0.0
24
- self.verbosity_interval = verbosity_interval
25
-
26
- def schedule(self, n, **kwargs):
27
- if self.verbosity_interval > 0:
28
- if n % self.verbosity_interval == 0:
29
- print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30
- if n < self.lr_warm_up_steps:
31
- lr = (
32
- self.lr_max - self.lr_start
33
- ) / self.lr_warm_up_steps * n + self.lr_start
34
- self.last_lr = lr
35
- return lr
36
- else:
37
- t = (n - self.lr_warm_up_steps) / (
38
- self.lr_max_decay_steps - self.lr_warm_up_steps
39
- )
40
- t = min(t, 1.0)
41
- lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
42
- 1 + np.cos(t * np.pi)
43
- )
44
- self.last_lr = lr
45
- return lr
46
-
47
- def __call__(self, n, **kwargs):
48
- return self.schedule(n, **kwargs)
49
-
50
-
51
- class LambdaWarmUpCosineScheduler2:
52
- """
53
- supports repeated iterations, configurable via lists
54
- note: use with a base_lr of 1.0.
55
- """
56
-
57
- def __init__(
58
- self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
59
- ):
60
- assert (
61
- len(warm_up_steps)
62
- == len(f_min)
63
- == len(f_max)
64
- == len(f_start)
65
- == len(cycle_lengths)
66
- )
67
- self.lr_warm_up_steps = warm_up_steps
68
- self.f_start = f_start
69
- self.f_min = f_min
70
- self.f_max = f_max
71
- self.cycle_lengths = cycle_lengths
72
- self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
73
- self.last_f = 0.0
74
- self.verbosity_interval = verbosity_interval
75
-
76
- def find_in_interval(self, n):
77
- interval = 0
78
- for cl in self.cum_cycles[1:]:
79
- if n <= cl:
80
- return interval
81
- interval += 1
82
-
83
- def schedule(self, n, **kwargs):
84
- cycle = self.find_in_interval(n)
85
- n = n - self.cum_cycles[cycle]
86
- if self.verbosity_interval > 0:
87
- if n % self.verbosity_interval == 0:
88
- print(
89
- f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90
- f"current cycle {cycle}"
91
- )
92
- if n < self.lr_warm_up_steps[cycle]:
93
- f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
94
- cycle
95
- ] * n + self.f_start[cycle]
96
- self.last_f = f
97
- return f
98
- else:
99
- t = (n - self.lr_warm_up_steps[cycle]) / (
100
- self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
101
- )
102
- t = min(t, 1.0)
103
- f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
104
- 1 + np.cos(t * np.pi)
105
- )
106
- self.last_f = f
107
- return f
108
-
109
- def __call__(self, n, **kwargs):
110
- return self.schedule(n, **kwargs)
111
-
112
-
113
- class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
114
- def schedule(self, n, **kwargs):
115
- cycle = self.find_in_interval(n)
116
- n = n - self.cum_cycles[cycle]
117
- if self.verbosity_interval > 0:
118
- if n % self.verbosity_interval == 0:
119
- print(
120
- f"current step: {n}, recent lr-multiplier: {self.last_f}, "
121
- f"current cycle {cycle}"
122
- )
123
-
124
- if n < self.lr_warm_up_steps[cycle]:
125
- f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
126
- cycle
127
- ] * n + self.f_start[cycle]
128
- self.last_f = f
129
- return f
130
- else:
131
- f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
132
- self.cycle_lengths[cycle] - n
133
- ) / (self.cycle_lengths[cycle])
134
- self.last_f = f
135
- return f
sgm/models/__init__.py DELETED
@@ -1,2 +0,0 @@
-from .autoencoder import AutoencodingEngine
-from .diffusion import DiffusionEngine
sgm/models/autoencoder.py DELETED
@@ -1,335 +0,0 @@
1
- import re
2
- from abc import abstractmethod
3
- from contextlib import contextmanager
4
- from typing import Any, Dict, Tuple, Union
5
-
6
- import pytorch_lightning as pl
7
- import torch
8
- from omegaconf import ListConfig
9
- from packaging import version
10
- from safetensors.torch import load_file as load_safetensors
11
-
12
- from ..modules.diffusionmodules.model import Decoder, Encoder
13
- from ..modules.distributions.distributions import DiagonalGaussianDistribution
14
- from ..modules.ema import LitEma
15
- from ..util import default, get_obj_from_str, instantiate_from_config
16
-
17
-
18
- class AbstractAutoencoder(pl.LightningModule):
19
- """
20
- This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
21
- unCLIP models, etc. Hence, it is fairly general, and specific features
22
- (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
23
- """
24
-
25
- def __init__(
26
- self,
27
- ema_decay: Union[None, float] = None,
28
- monitor: Union[None, str] = None,
29
- input_key: str = "jpg",
30
- ckpt_path: Union[None, str] = None,
31
- ignore_keys: Union[Tuple, list, ListConfig] = (),
32
- ):
33
- super().__init__()
34
- self.input_key = input_key
35
- self.use_ema = ema_decay is not None
36
- if monitor is not None:
37
- self.monitor = monitor
38
-
39
- if self.use_ema:
40
- self.model_ema = LitEma(self, decay=ema_decay)
41
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
42
-
43
- if ckpt_path is not None:
44
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45
-
46
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
47
- self.automatic_optimization = False
48
-
49
- def init_from_ckpt(
50
- self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
51
- ) -> None:
52
- if path.endswith("ckpt"):
53
- sd = torch.load(path, map_location="cpu")["state_dict"]
54
- elif path.endswith("safetensors"):
55
- sd = load_safetensors(path)
56
- else:
57
- raise NotImplementedError
58
-
59
- keys = list(sd.keys())
60
- for k in keys:
61
- for ik in ignore_keys:
62
- if re.match(ik, k):
63
- print("Deleting key {} from state_dict.".format(k))
64
- del sd[k]
65
- missing, unexpected = self.load_state_dict(sd, strict=False)
66
- print(
67
- f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
68
- )
69
- if len(missing) > 0:
70
- print(f"Missing Keys: {missing}")
71
- if len(unexpected) > 0:
72
- print(f"Unexpected Keys: {unexpected}")
73
-
74
- @abstractmethod
75
- def get_input(self, batch) -> Any:
76
- raise NotImplementedError()
77
-
78
- def on_train_batch_end(self, *args, **kwargs):
79
- # for EMA computation
80
- if self.use_ema:
81
- self.model_ema(self)
82
-
83
- @contextmanager
84
- def ema_scope(self, context=None):
85
- if self.use_ema:
86
- self.model_ema.store(self.parameters())
87
- self.model_ema.copy_to(self)
88
- if context is not None:
89
- print(f"{context}: Switched to EMA weights")
90
- try:
91
- yield None
92
- finally:
93
- if self.use_ema:
94
- self.model_ema.restore(self.parameters())
95
- if context is not None:
96
- print(f"{context}: Restored training weights")
97
-
98
- @abstractmethod
99
- def encode(self, *args, **kwargs) -> torch.Tensor:
100
- raise NotImplementedError("encode()-method of abstract base class called")
101
-
102
- @abstractmethod
103
- def decode(self, *args, **kwargs) -> torch.Tensor:
104
- raise NotImplementedError("decode()-method of abstract base class called")
105
-
106
- def instantiate_optimizer_from_config(self, params, lr, cfg):
107
- print(f"loading >>> {cfg['target']} <<< optimizer from config")
108
- return get_obj_from_str(cfg["target"])(
109
- params, lr=lr, **cfg.get("params", dict())
110
- )
111
-
112
- def configure_optimizers(self) -> Any:
113
- raise NotImplementedError()
114
-
115
-
116
- class AutoencodingEngine(AbstractAutoencoder):
117
- """
118
- Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
119
- (we also restore them explicitly as special cases for legacy reasons).
120
- Regularizations such as KL or VQ are moved to the regularizer class.
121
- """
122
-
123
- def __init__(
124
- self,
125
- *args,
126
- encoder_config: Dict,
127
- decoder_config: Dict,
128
- loss_config: Dict,
129
- regularizer_config: Dict,
130
- optimizer_config: Union[Dict, None] = None,
131
- lr_g_factor: float = 1.0,
132
- **kwargs,
133
- ):
134
- super().__init__(*args, **kwargs)
135
- # todo: add options to freeze encoder/decoder
136
- self.encoder = instantiate_from_config(encoder_config)
137
- self.decoder = instantiate_from_config(decoder_config)
138
- self.loss = instantiate_from_config(loss_config)
139
- self.regularization = instantiate_from_config(regularizer_config)
140
- self.optimizer_config = default(
141
- optimizer_config, {"target": "torch.optim.Adam"}
142
- )
143
- self.lr_g_factor = lr_g_factor
144
-
145
- def get_input(self, batch: Dict) -> torch.Tensor:
146
- # assuming unified data format, dataloader returns a dict.
147
- # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc)
148
- return batch[self.input_key]
149
-
150
- def get_autoencoder_params(self) -> list:
151
- params = (
152
- list(self.encoder.parameters())
153
- + list(self.decoder.parameters())
154
- + list(self.regularization.get_trainable_parameters())
155
- + list(self.loss.get_trainable_autoencoder_parameters())
156
- )
157
- return params
158
-
159
- def get_discriminator_params(self) -> list:
160
- params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
161
- return params
162
-
163
- def get_last_layer(self):
164
- return self.decoder.get_last_layer()
165
-
166
- def encode(self, x: Any, return_reg_log: bool = False) -> Any:
167
- z = self.encoder(x)
168
- z, reg_log = self.regularization(z)
169
- if return_reg_log:
170
- return z, reg_log
171
- return z
172
-
173
- def decode(self, z: Any) -> torch.Tensor:
174
- x = self.decoder(z)
175
- return x
176
-
177
- def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178
- z, reg_log = self.encode(x, return_reg_log=True)
179
- dec = self.decode(z)
180
- return z, dec, reg_log
181
-
182
- def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
183
- x = self.get_input(batch)
184
- z, xrec, regularization_log = self(x)
185
-
186
- if optimizer_idx == 0:
187
- # autoencode
188
- aeloss, log_dict_ae = self.loss(
189
- regularization_log,
190
- x,
191
- xrec,
192
- optimizer_idx,
193
- self.global_step,
194
- last_layer=self.get_last_layer(),
195
- split="train",
196
- )
197
-
198
- self.log_dict(
199
- log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
200
- )
201
- return aeloss
202
-
203
- if optimizer_idx == 1:
204
- # discriminator
205
- discloss, log_dict_disc = self.loss(
206
- regularization_log,
207
- x,
208
- xrec,
209
- optimizer_idx,
210
- self.global_step,
211
- last_layer=self.get_last_layer(),
212
- split="train",
213
- )
214
- self.log_dict(
215
- log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
216
- )
217
- return discloss
218
-
219
- def validation_step(self, batch, batch_idx) -> Dict:
220
- log_dict = self._validation_step(batch, batch_idx)
221
- with self.ema_scope():
222
- log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
223
- log_dict.update(log_dict_ema)
224
- return log_dict
225
-
226
- def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
227
- x = self.get_input(batch)
228
-
229
- z, xrec, regularization_log = self(x)
230
- aeloss, log_dict_ae = self.loss(
231
- regularization_log,
232
- x,
233
- xrec,
234
- 0,
235
- self.global_step,
236
- last_layer=self.get_last_layer(),
237
- split="val" + postfix,
238
- )
239
-
240
- discloss, log_dict_disc = self.loss(
241
- regularization_log,
242
- x,
243
- xrec,
244
- 1,
245
- self.global_step,
246
- last_layer=self.get_last_layer(),
247
- split="val" + postfix,
248
- )
249
- self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
250
- log_dict_ae.update(log_dict_disc)
251
- self.log_dict(log_dict_ae)
252
- return log_dict_ae
253
-
254
- def configure_optimizers(self) -> Any:
255
- ae_params = self.get_autoencoder_params()
256
- disc_params = self.get_discriminator_params()
257
-
258
- opt_ae = self.instantiate_optimizer_from_config(
259
- ae_params,
260
- default(self.lr_g_factor, 1.0) * self.learning_rate,
261
- self.optimizer_config,
262
- )
263
- opt_disc = self.instantiate_optimizer_from_config(
264
- disc_params, self.learning_rate, self.optimizer_config
265
- )
266
-
267
- return [opt_ae, opt_disc], []
268
-
269
- @torch.no_grad()
270
- def log_images(self, batch: Dict, **kwargs) -> Dict:
271
- log = dict()
272
- x = self.get_input(batch)
273
- _, xrec, _ = self(x)
274
- log["inputs"] = x
275
- log["reconstructions"] = xrec
276
- with self.ema_scope():
277
- _, xrec_ema, _ = self(x)
278
- log["reconstructions_ema"] = xrec_ema
279
- return log
280
-
281
-
282
- class AutoencoderKL(AutoencodingEngine):
283
- def __init__(self, embed_dim: int, **kwargs):
284
- ddconfig = kwargs.pop("ddconfig")
285
- ckpt_path = kwargs.pop("ckpt_path", None)
286
- ignore_keys = kwargs.pop("ignore_keys", ())
287
- super().__init__(
288
- encoder_config={"target": "torch.nn.Identity"},
289
- decoder_config={"target": "torch.nn.Identity"},
290
- regularizer_config={"target": "torch.nn.Identity"},
291
- loss_config=kwargs.pop("lossconfig"),
292
- **kwargs,
293
- )
294
- assert ddconfig["double_z"]
295
- self.encoder = Encoder(**ddconfig)
296
- self.decoder = Decoder(**ddconfig)
297
- self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
298
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
299
- self.embed_dim = embed_dim
300
-
301
- if ckpt_path is not None:
302
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
303
-
304
- def encode(self, x):
305
- assert (
306
- not self.training
307
- ), f"{self.__class__.__name__} only supports inference currently"
308
- h = self.encoder(x)
309
- moments = self.quant_conv(h)
310
- posterior = DiagonalGaussianDistribution(moments)
311
- return posterior
312
-
313
- def decode(self, z, **decoder_kwargs):
314
- z = self.post_quant_conv(z)
315
- dec = self.decoder(z, **decoder_kwargs)
316
- return dec
317
-
318
-
319
- class AutoencoderKLInferenceWrapper(AutoencoderKL):
320
- def encode(self, x):
321
- return super().encode(x).sample()
322
-
323
-
324
- class IdentityFirstStage(AbstractAutoencoder):
325
- def __init__(self, *args, **kwargs):
326
- super().__init__(*args, **kwargs)
327
-
328
- def get_input(self, x: Any) -> Any:
329
- return x
330
-
331
- def encode(self, x: Any, *args, **kwargs) -> Any:
332
- return x
333
-
334
- def decode(self, x: Any, *args, **kwargs) -> Any:
335
- return x
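The AutoencoderKL removed above encodes an image into the moments of a diagonal Gaussian (via quant_conv) and then samples a latent from them; DiagonalGaussianDistribution itself lives in sgm/modules/distributions, which this commit also deletes. A minimal sketch of that reparameterized sampling in plain PyTorch, assuming the conventional mean/logvar split and clamp range used by latent-diffusion codebases:

    import torch

    def sample_latent(moments: torch.Tensor) -> torch.Tensor:
        # moments: (b, 2 * z_channels, h, w), as produced by quant_conv
        mean, logvar = torch.chunk(moments, 2, dim=1)
        logvar = torch.clamp(logvar, -30.0, 20.0)       # keep std numerically sane
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)      # reparameterized sample

    moments = torch.randn(1, 2 * 4, 32, 32)             # 4 latent channels on a 32x32 grid
    z = sample_latent(moments)
    print(z.shape)                                       # torch.Size([1, 4, 32, 32])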
 
sgm/models/diffusion.py DELETED
@@ -1,556 +0,0 @@
1
- from contextlib import contextmanager
2
- from typing import Any, Dict, List, Tuple, Union, DefaultDict
3
-
4
- import pytorch_lightning as pl
5
- import torch
6
- from omegaconf import ListConfig, OmegaConf
7
- from safetensors.torch import load_file as load_safetensors
8
- from torch.optim.lr_scheduler import LambdaLR
9
- from einops import rearrange
10
- import math
11
- import torch.nn as nn
12
- from ..modules import UNCONDITIONAL_CONFIG
13
- from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
14
- from ..modules.ema import LitEma
15
- from ..util import (
16
- default,
17
- disabled_train,
18
- get_obj_from_str,
19
- instantiate_from_config,
20
- log_txt_as_img,
21
- )
22
-
23
-
24
- import collections
25
- from functools import partial
26
-
27
-
28
- def save_activations(
29
- activations: DefaultDict,
30
- name: str,
31
- module: nn.Module,
32
- inp: Tuple,
33
- out: torch.Tensor
34
- ) -> None:
35
- """PyTorch Forward hook to save outputs at each forward
36
- pass. Mutates specified dict objects with each fwd pass.
37
- """
38
- if isinstance(out, tuple):
39
- if out[1] is None:
40
- activations[name].append(out[0].detach())
41
-
42
- class DiffusionEngine(pl.LightningModule):
43
- def __init__(
44
- self,
45
- network_config,
46
- denoiser_config,
47
- first_stage_config,
48
- conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
49
- sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
50
- optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
51
- scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
52
- loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
53
- network_wrapper: Union[None, str] = None,
54
- ckpt_path: Union[None, str] = None,
55
- use_ema: bool = False,
56
- ema_decay_rate: float = 0.9999,
57
- scale_factor: float = 1.0,
58
- disable_first_stage_autocast=False,
59
- input_key: str = "jpg",
60
- log_keys: Union[List, None] = None,
61
- no_cond_log: bool = False,
62
- compile_model: bool = False,
63
- trainkeys='pose',
64
- multiplier=0.05,
65
- loss_rgb_lambda=20.,
66
- loss_fg_lambda=10.,
67
- loss_bg_lambda=20.,
68
- ):
69
- super().__init__()
70
- self.log_keys = log_keys
71
- self.input_key = input_key
72
- self.trainkeys = trainkeys
73
- self.multiplier = multiplier
74
- self.loss_rgb_lambda = loss_rgb_lambda
75
- self.loss_fg_lambda = loss_fg_lambda
76
- self.loss_bg_lambda = loss_bg_lambda
77
- self.rgb = network_config.params.rgb
78
- self.rgb_predict = network_config.params.rgb_predict
79
- self.add_token = ('modifier_token' in conditioner_config.params.emb_models[1].params)
80
- self.optimizer_config = default(
81
- optimizer_config, {"target": "torch.optim.AdamW"}
82
- )
83
- model = instantiate_from_config(network_config)
84
- self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
85
- model, compile_model=compile_model
86
- )
87
-
88
- self.denoiser = instantiate_from_config(denoiser_config)
89
- self.sampler = (
90
- instantiate_from_config(sampler_config)
91
- if sampler_config is not None
92
- else None
93
- )
94
- self.conditioner = instantiate_from_config(
95
- default(conditioner_config, UNCONDITIONAL_CONFIG)
96
- )
97
- self.scheduler_config = scheduler_config
98
- self._init_first_stage(first_stage_config)
99
-
100
- self.loss_fn = (
101
- instantiate_from_config(loss_fn_config)
102
- if loss_fn_config is not None
103
- else None
104
- )
105
-
106
- self.use_ema = use_ema
107
- if self.use_ema:
108
- self.model_ema = LitEma(self.model, decay=ema_decay_rate)
109
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
110
-
111
- self.scale_factor = scale_factor
112
- self.disable_first_stage_autocast = disable_first_stage_autocast
113
- self.no_cond_log = no_cond_log
114
-
115
- if ckpt_path is not None:
116
- self.init_from_ckpt(ckpt_path)
117
-
118
- blocks = []
119
- if self.trainkeys == 'poseattn':
120
- for x in self.model.diffusion_model.named_parameters():
121
- if not ('pose' in x[0] or 'transformer_blocks' in x[0]):
122
- x[1].requires_grad = False
123
- else:
124
- if 'pose' in x[0]:
125
- x[1].requires_grad = True
126
- blocks.append(x[0].split('.pose')[0])
127
-
128
- blocks = set(blocks)
129
- for x in self.model.diffusion_model.named_parameters():
130
- if 'transformer_blocks' in x[0]:
131
- reqgrad = False
132
- for each in blocks:
133
- if each in x[0] and ('attn1' in x[0] or 'attn2' in x[0] or 'pose' in x[0]):
134
- reqgrad = True
135
- x[1].requires_grad = True
136
- if not reqgrad:
137
- x[1].requires_grad = False
138
- elif self.trainkeys == 'pose':
139
- for x in self.model.diffusion_model.named_parameters():
140
- if not ('pose' in x[0]):
141
- x[1].requires_grad = False
142
- else:
143
- x[1].requires_grad = True
144
- elif self.trainkeys == 'all':
145
- for x in self.model.diffusion_model.named_parameters():
146
- x[1].requires_grad = True
147
-
148
- self.model = self.model.to(memory_format=torch.channels_last)
149
-
150
- def register_activation_hooks(
151
- self,
152
- ) -> None:
153
- self.activations_dict = collections.defaultdict(list)
154
- handles = []
155
- for name, module in self.model.diffusion_model.named_modules():
156
- if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks':
157
- if hasattr(module, 'pose_emb_layers'):
158
- handle = module.register_forward_hook(
159
- partial(save_activations, self.activations_dict, name)
160
- )
161
- handles.append(handle)
162
- self.handles = handles
163
-
164
- def clear_rendered_feat(self,):
165
- for name, module in self.model.diffusion_model.named_modules():
166
- if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks':
167
- if hasattr(module, 'pose_emb_layers'):
168
- module.rendered_feat = None
169
-
170
- def remove_activation_hooks(
171
- self, handles
172
- ) -> None:
173
- for handle in handles:
174
- handle.remove()
175
-
176
- def init_from_ckpt(
177
- self,
178
- path: str,
179
- ) -> None:
180
- if path.endswith("ckpt"):
181
- sd = torch.load(path, map_location="cpu")["state_dict"]
182
- elif path.endswith("safetensors"):
183
- sd = load_safetensors(path)
184
- else:
185
- raise NotImplementedError
186
-
187
- missing, unexpected = self.load_state_dict(sd, strict=False)
188
- print(
189
- f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
190
- )
191
- if len(missing) > 0:
192
- print(f"Missing Keys: {missing}")
193
- if len(unexpected) > 0:
194
- print(f"Unexpected Keys: {unexpected}")
195
-
196
- def _init_first_stage(self, config):
197
- model = instantiate_from_config(config).eval()
198
- model.train = disabled_train
199
- for param in model.parameters():
200
- param.requires_grad = False
201
- self.first_stage_model = model
202
-
203
- def get_input(self, batch):
204
- return batch[self.input_key], batch[self.input_key + '_ref'] if self.input_key + '_ref' in batch else None, batch['pose'] if 'pose' in batch else None, batch['mask'] if "mask" in batch else None, batch['mask_ref'] if "mask_ref" in batch else None, batch['depth'] if "depth" in batch else None, batch['drop_im'] if "drop_im" in batch else 0.
205
-
206
- @torch.no_grad()
207
- def decode_first_stage(self, z):
208
- z = 1.0 / self.scale_factor * z
209
- with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
210
- out = self.first_stage_model.decode(z)
211
- return out
212
-
213
- @torch.no_grad()
214
- def encode_first_stage(self, x):
215
- with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
216
- z = self.first_stage_model.encode(x)
217
- z = self.scale_factor * z
218
- return z
219
-
220
- def forward(self, x, x_rgb, xr, pose, mask, mask_ref, opacity, drop_im, batch):
221
- loss, loss_fg, loss_bg, loss_rgb = self.loss_fn(self.model, self.denoiser, self.conditioner, x, x_rgb, xr, pose, mask, mask_ref, opacity, batch)
222
- loss_mean = loss.mean()
223
- loss_dict = {"loss": loss_mean.item()}
224
- if self.rgb and self.global_step > 0:
225
- loss_fg = (loss_fg.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12)
226
- loss_bg = (loss_bg.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12)
227
- loss_mean += self.loss_fg_lambda*loss_fg
228
- loss_mean += self.loss_bg_lambda*loss_bg
229
- loss_dict["loss_fg"] = loss_fg.item()
230
- loss_dict["loss_bg"] = loss_bg.item()
231
- if self.rgb_predict and loss_rgb.mean() > 0:
232
- loss_rgb = (loss_rgb.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12)
233
- loss_mean += self.loss_rgb_lambda*loss_rgb
234
- loss_dict["loss_rgb"] = loss_rgb.item()
235
- return loss_mean, loss_dict
236
-
237
- def shared_step(self, batch: Dict) -> Any:
238
- x, xr, pose, mask, mask_ref, opacity, drop_im = self.get_input(batch)
239
- x_rgb = x.clone().detach()
240
- x = self.encode_first_stage(x)
241
- x = x.to(memory_format=torch.channels_last)
242
- if xr is not None:
243
- b, n = xr.shape[0], xr.shape[1]
244
- xr = rearrange(self.encode_first_stage(rearrange(xr, "b n ... -> (b n) ...")), "(b n) ... -> b n ...", b=b, n=n)
245
- xr = drop_im.reshape(b, 1, 1, 1, 1)*xr + (1-drop_im.reshape(b, 1, 1, 1, 1))*torch.zeros_like(xr)
246
- batch["global_step"] = self.global_step
247
- loss, loss_dict = self(x, x_rgb, xr, pose, mask, mask_ref, opacity, drop_im, batch)
248
- return loss, loss_dict
249
-
250
- def training_step(self, batch, batch_idx):
251
- loss, loss_dict = self.shared_step(batch)
252
-
253
- self.log_dict(
254
- loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
255
- )
256
-
257
- self.log(
258
- "global_step",
259
- self.global_step,
260
- prog_bar=True,
261
- logger=True,
262
- on_step=True,
263
- on_epoch=False,
264
- )
265
-
266
- if self.scheduler_config is not None:
267
- lr = self.optimizers().param_groups[0]["lr"]
268
- self.log(
269
- "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
270
- )
271
- return loss
272
-
273
- def validation_step(self, batch, batch_idx):
274
- # print("validation data", len(self.trainer.val_dataloaders))
275
- loss, loss_dict = self.shared_step(batch)
276
- return loss
277
-
278
- def on_train_start(self, *args, **kwargs):
279
- if self.sampler is None or self.loss_fn is None:
280
- raise ValueError("Sampler and loss function need to be set for training.")
281
-
282
- def on_train_batch_end(self, *args, **kwargs):
283
- if self.use_ema:
284
- self.model_ema(self.model)
285
-
286
- def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
287
- optimizer.zero_grad(set_to_none=True)
288
-
289
- @contextmanager
290
- def ema_scope(self, context=None):
291
- if self.use_ema:
292
- self.model_ema.store(self.model.parameters())
293
- self.model_ema.copy_to(self.model)
294
- if context is not None:
295
- print(f"{context}: Switched to EMA weights")
296
- try:
297
- yield None
298
- finally:
299
- if self.use_ema:
300
- self.model_ema.restore(self.model.parameters())
301
- if context is not None:
302
- print(f"{context}: Restored training weights")
303
-
304
- def instantiate_optimizer_from_config(self, params, lr, cfg):
305
- return get_obj_from_str(cfg["target"])(
306
- params, lr=lr, **cfg.get("params", dict())
307
- )
308
-
309
- def configure_optimizers(self):
310
- lr = self.learning_rate
311
- params = []
312
- blocks = []
313
- lowlrparams = []
314
- if self.trainkeys == 'poseattn':
315
- lowlrparams = []
316
- for x in self.model.diffusion_model.named_parameters():
317
- if ('pose' in x[0]):
318
- params += [x[1]]
319
- blocks.append(x[0].split('.pose')[0])
320
- print(x[0])
321
- blocks = set(blocks)
322
- for x in self.model.diffusion_model.named_parameters():
323
- if 'transformer_blocks' in x[0]:
324
- for each in blocks:
325
- if each in x[0] and not ('pose' in x[0]) and ('attn1' in x[0] or 'attn2' in x[0]):
326
- lowlrparams += [x[1]]
327
- elif self.trainkeys == 'pose':
328
- for x in self.model.diffusion_model.named_parameters():
329
- if ('pose' in x[0]):
330
- params += [x[1]]
331
- print(x[0])
332
- elif self.trainkeys == 'all':
333
- lowlrparams = []
334
- for x in self.model.diffusion_model.named_parameters():
335
- if ('pose' in x[0]):
336
- params += [x[1]]
337
- print(x[0])
338
- else:
339
- lowlrparams += [x[1]]
340
-
341
- for i, embedder in enumerate(self.conditioner.embedders[:2]):
342
- if embedder.is_trainable:
343
- params = params + list(embedder.parameters())
344
- if self.add_token:
345
- if i == 0:
346
- for name, param in embedder.transformer.get_input_embeddings().named_parameters():
347
- param.requires_grad = True
348
- print(name, "conditional model param")
349
- params += [param]
350
- else:
351
- for name, param in embedder.model.token_embedding.named_parameters():
352
- param.requires_grad = True
353
- print(name, "conditional model param")
354
- params += [param]
355
-
356
- if len(lowlrparams) > 0:
357
- print("different optimizer groups")
358
- opt = self.instantiate_optimizer_from_config([{'params': params}, {'params': lowlrparams, 'lr': self.multiplier*lr}], lr, self.optimizer_config)
359
- else:
360
- opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
361
- if self.scheduler_config is not None:
362
- scheduler = instantiate_from_config(self.scheduler_config)
363
- print("Setting up LambdaLR scheduler...")
364
- scheduler = [
365
- {
366
- "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
367
- "interval": "step",
368
- "frequency": 1,
369
- }
370
- ]
371
- return [opt], scheduler
372
- return opt
373
-
374
- @torch.no_grad()
375
- def sample(
376
- self,
377
- cond: Dict,
378
- uc: Union[Dict, None] = None,
379
- batch_size: int = 16,
380
- num_steps=None,
381
- randn=None,
382
- shape: Union[None, Tuple, List] = None,
383
- return_rgb=False,
384
- mask=None,
385
- init_im=None,
386
- **kwargs,
387
- ):
388
- if randn is None:
389
- randn = torch.randn(batch_size, *shape)
390
-
391
- denoiser = lambda input, sigma, c: self.denoiser(
392
- self.model, input, sigma, c, **kwargs
393
- )
394
- if mask is not None:
395
- samples, rgb_list = self.sampler(denoiser, randn.to(self.device), cond, uc=uc, mask=mask, init_im=init_im, num_steps=num_steps)
396
- else:
397
- samples, rgb_list = self.sampler(denoiser, randn.to(self.device), cond, uc=uc, num_steps=num_steps)
398
- if return_rgb:
399
- return samples, rgb_list
400
- return samples
401
-
402
- @torch.no_grad()
403
- def samplemulti(
404
- self,
405
- cond,
406
- uc=None,
407
- batch_size: int = 16,
408
- num_steps=None,
409
- randn=None,
410
- shape: Union[None, Tuple, List] = None,
411
- return_rgb=False,
412
- mask=None,
413
- init_im=None,
414
- multikwargs=None,
415
- ):
416
- if randn is None:
417
- randn = torch.randn(batch_size, *shape)
418
-
419
- samples, rgb_list = self.sampler(self.denoiser, self.model, randn.to(self.device), cond, uc=uc, num_steps=num_steps, multikwargs=multikwargs)
420
- if return_rgb:
421
- return samples, rgb_list
422
- return samples
423
-
424
- @torch.no_grad()
425
- def log_conditionings(self, batch: Dict, n: int, reference: bool = True) -> Dict:
426
- """
427
- Defines heuristics to log different conditionings.
428
- These can be lists of strings (text-to-image), tensors, ints, ...
429
- """
430
- image_h, image_w = batch[self.input_key].shape[2:]
431
- log = dict()
432
-
433
- for embedder in self.conditioner.embedders:
434
- if reference:
435
- check = (embedder.input_keys[0] in self.log_keys)
436
- else:
437
- check = (embedder.input_key in self.log_keys)
438
- if (
439
- (self.log_keys is None) or check
440
- ) and not self.no_cond_log:
441
- if reference:
442
- x = batch[embedder.input_keys[0]][:n]
443
- else:
444
- x = batch[embedder.input_key][:n]
445
- if isinstance(x, torch.Tensor):
446
- if x.dim() == 1:
447
- # class-conditional, convert integer to string
448
- x = [str(x[i].item()) for i in range(x.shape[0])]
449
- xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
450
- elif x.dim() == 2:
451
- # size and crop cond and the like
452
- x = [
453
- "x".join([str(xx) for xx in x[i].tolist()])
454
- for i in range(x.shape[0])
455
- ]
456
- xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
457
- else:
458
- raise NotImplementedError()
459
- elif isinstance(x, (List, ListConfig)):
460
- if isinstance(x[0], str):
461
- # strings
462
- xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
463
- else:
464
- raise NotImplementedError()
465
- else:
466
- raise NotImplementedError()
467
- if reference:
468
- log[embedder.input_keys[0]] = xc
469
- else:
470
- log[embedder.input_key] = xc
471
- return log
472
-
473
- @torch.no_grad()
474
- def log_images(
475
- self,
476
- batch: Dict,
477
- N: int = 8,
478
- sample: bool = True,
479
- ucg_keys: List[str] = None,
480
- **kwargs,
481
- ) -> Dict:
482
- log = dict()
483
-
484
- x, xr, pose, mask, mask_ref, depth, drop_im = self.get_input(batch)
485
-
486
- if xr is not None:
487
- conditioner_input_keys = [e.input_keys for e in self.conditioner.embedders]
488
- else:
489
- conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
490
-
491
- if ucg_keys:
492
- assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
493
- "Each defined ucg key for sampling must be in the provided conditioner input keys,"
494
- f"but we have {ucg_keys} vs. {conditioner_input_keys}"
495
- )
496
- else:
497
- ucg_keys = conditioner_input_keys
498
-
499
- c, uc = self.conditioner.get_unconditional_conditioning(
500
- batch,
501
- force_uc_zero_embeddings=ucg_keys
502
- if len(self.conditioner.embedders) > 0
503
- else [],
504
- )
505
-
506
- N = min(x.shape[0], N)
507
- x = x.to(self.device)[:N]
508
- zr = None
509
- if xr is not None:
510
- xr = xr.to(self.device)[:N]
511
- b, n = xr.shape[0], xr.shape[1]
512
- log["reference"] = rearrange(xr, "b n ... -> (b n) ...", b=b, n=n)
513
- zr = rearrange(self.encode_first_stage(rearrange(xr, "b n ... -> (b n) ...", b=b, n=n)), "(b n) ... -> b n ...", b=b, n=n)
514
-
515
- log["inputs"] = x
516
- b = x.shape[0]
517
- if mask is not None:
518
- log["mask"] = mask
519
- if depth is not None:
520
- log["depth"] = depth
521
- z = self.encode_first_stage(x)
522
-
523
- if uc is not None:
524
- if xr is not None:
525
- zr = torch.cat([torch.zeros_like(zr), zr])
526
- drop_im = torch.cat([drop_im, drop_im])
527
- if isinstance(pose, list):
528
- pose = pose[:N]*2
529
- else:
530
- pose = torch.cat([pose[:N]] * 2)
531
-
532
- sampling_kwargs = {'input_ref':zr}
533
- sampling_kwargs['pose'] = pose
534
- sampling_kwargs['mask_ref'] = None
535
- sampling_kwargs['drop_im'] = drop_im
536
-
537
- log["reconstructions"] = self.decode_first_stage(z)
538
- log.update(self.log_conditionings(batch, N, reference=True if xr is not None else False))
539
-
540
- for k in c:
541
- if isinstance(c[k], torch.Tensor):
542
- if xr is not None:
543
- c[k], uc[k] = map(lambda y: y[k][:(n+1)*N].to(self.device), (c, uc))
544
- else:
545
- c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
546
- if sample:
547
- with self.ema_scope("Plotting"):
548
- samples, rgb_list = self.sample(
549
- c, shape=z.shape[1:], uc=uc, batch_size=N, return_rgb=True, **sampling_kwargs
550
- )
551
- samples = self.decode_first_stage(samples)
552
- log["samples"] = samples
553
- if len(rgb_list) > 0:
554
- size = int(math.sqrt(rgb_list[0].size(1)))
555
- log["predicted_rgb"] = rgb_list[0].reshape(-1, size, size, 3).permute(0, 3, 1, 2)
556
- return log
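save_activations and register_activation_hooks above capture per-block outputs by mutating a defaultdict from inside a PyTorch forward hook. A self-contained sketch of that pattern on a toy module (plain PyTorch, independent of this repo):

    import collections
    from functools import partial

    import torch
    import torch.nn as nn

    def save_out(store, name, module, inp, out):
        # forward hook: stash a detached copy of this module's output
        store[name].append(out.detach())

    store = collections.defaultdict(list)
    net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    handles = [m.register_forward_hook(partial(save_out, store, name))
               for name, m in net.named_modules() if isinstance(m, nn.Linear)]

    net(torch.randn(2, 8))
    print({k: v[0].shape for k, v in store.items()})     # {'0': torch.Size([2, 16]), '2': torch.Size([2, 4])}

    for h in handles:                                     # mirrors remove_activation_hooks
        h.remove()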
 
sgm/modules/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- from .encoders.modules import GeneralConditioner
2
-
3
- UNCONDITIONAL_CONFIG = {
4
- "target": "sgm.modules.GeneralConditioner",
5
- "params": {"emb_models": []},
6
- }
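UNCONDITIONAL_CONFIG follows the target/params convention that sgm.util.instantiate_from_config resolves (that utility is not reproduced in this view). A minimal sketch of the resolution step, assuming the usual importlib-based lookup used by these codebases:

    import importlib

    def get_obj_from_str(path: str):
        module, cls = path.rsplit(".", 1)
        return getattr(importlib.import_module(module), cls)

    def instantiate_from_config(config: dict):
        # resolve config["target"] to a class and construct it with config["params"]
        return get_obj_from_str(config["target"])(**config.get("params", dict()))

    layer = instantiate_from_config(
        {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 4}}
    )
    print(layer)                                          # Linear(in_features=8, out_features=4, bias=True)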
 
sgm/modules/attention.py DELETED
@@ -1,1202 +0,0 @@
1
- import logging
2
- import math
3
- import itertools
4
- from inspect import isfunction
5
- from typing import Any, Optional
6
- import numpy as np
7
- import torch
8
- import torch.nn.functional as F
9
- from einops import rearrange, repeat
10
- from packaging import version
11
- from torch import nn
12
- from .diffusionmodules.util import checkpoint
13
- from torch.autograd import Function
14
- from torch.cuda.amp import custom_bwd, custom_fwd
15
-
16
- from ..modules.diffusionmodules.util import zero_module
17
- from ..modules.nerfsd_pytorch3d import NerfSDModule, VolRender
18
-
19
- logpy = logging.getLogger(__name__)
20
-
21
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
22
- SDP_IS_AVAILABLE = True
23
- from torch.backends.cuda import SDPBackend, sdp_kernel
24
-
25
- BACKEND_MAP = {
26
- SDPBackend.MATH: {
27
- "enable_math": True,
28
- "enable_flash": False,
29
- "enable_mem_efficient": False,
30
- },
31
- SDPBackend.FLASH_ATTENTION: {
32
- "enable_math": False,
33
- "enable_flash": True,
34
- "enable_mem_efficient": False,
35
- },
36
- SDPBackend.EFFICIENT_ATTENTION: {
37
- "enable_math": False,
38
- "enable_flash": False,
39
- "enable_mem_efficient": True,
40
- },
41
- None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
42
- }
43
- else:
44
- from contextlib import nullcontext
45
-
46
- SDP_IS_AVAILABLE = False
47
- sdp_kernel = nullcontext
48
- BACKEND_MAP = {}
49
- logpy.warn(
50
- f"No SDP backend available, likely because you are running in pytorch "
51
- f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
52
- f"You might want to consider upgrading."
53
- )
54
-
55
- try:
56
- import xformers
57
- import xformers.ops
58
-
59
- XFORMERS_IS_AVAILABLE = True
60
- except:
61
- XFORMERS_IS_AVAILABLE = False
62
- logpy.warn("no module 'xformers'. Processing without...")
63
-
64
-
65
- def exists(val):
66
- return val is not None
67
-
68
-
69
- def uniq(arr):
70
- return {el: True for el in arr}.keys()
71
-
72
-
73
- def default(val, d):
74
- if exists(val):
75
- return val
76
- return d() if isfunction(d) else d
77
-
78
-
79
- def max_neg_value(t):
80
- return -torch.finfo(t.dtype).max
81
-
82
-
83
- def init_(tensor):
84
- dim = tensor.shape[-1]
85
- std = 1 / math.sqrt(dim)
86
- tensor.uniform_(-std, std)
87
- return tensor
88
-
89
-
90
- # feedforward
91
- class GEGLU(nn.Module):
92
- def __init__(self, dim_in, dim_out):
93
- super().__init__()
94
- self.proj = nn.Linear(dim_in, dim_out * 2)
95
-
96
- def forward(self, x):
97
- x, gate = self.proj(x).chunk(2, dim=-1)
98
- return x * F.gelu(gate)
99
-
100
-
101
- class FeedForward(nn.Module):
102
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
103
- super().__init__()
104
- inner_dim = int(dim * mult)
105
- dim_out = default(dim_out, dim)
106
- project_in = (
107
- nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
108
- if not glu
109
- else GEGLU(dim, inner_dim)
110
- )
111
-
112
- self.net = nn.Sequential(
113
- project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
114
- )
115
-
116
- def forward(self, x):
117
- return self.net(x)
118
-
119
-
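GEGLU and FeedForward above implement the gated-GELU feed-forward used throughout the transformer blocks: a single linear projection produces both the hidden activation and its gate, which are then multiplied. A minimal sketch of the gating step (plain PyTorch):

    import torch
    import torch.nn.functional as F

    dim_in, dim_out = 8, 32
    proj = torch.nn.Linear(dim_in, dim_out * 2)           # one projection, split in two
    x = torch.randn(2, 16, dim_in)                        # (batch, tokens, dim_in)
    h, gate = proj(x).chunk(2, dim=-1)
    out = h * F.gelu(gate)                                # gated GELU, shape (2, 16, 32)
    print(out.shape)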
120
- def Normalize(in_channels):
121
- return torch.nn.GroupNorm(
122
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
123
- )
124
-
125
-
126
- class LinearAttention(nn.Module):
127
- def __init__(self, dim, heads=4, dim_head=32):
128
- super().__init__()
129
- self.heads = heads
130
- hidden_dim = dim_head * heads
131
- self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
132
- self.to_out = nn.Conv2d(hidden_dim, dim, 1)
133
-
134
- def forward(self, x):
135
- b, c, h, w = x.shape
136
- qkv = self.to_qkv(x)
137
- q, k, v = rearrange(
138
- qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
139
- )
140
- k = k.softmax(dim=-1)
141
- context = torch.einsum("bhdn,bhen->bhde", k, v)
142
- out = torch.einsum("bhde,bhdn->bhen", context, q)
143
- out = rearrange(
144
- out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
145
- )
146
- return self.to_out(out)
147
-
148
-
149
- class SpatialSelfAttention(nn.Module):
150
- def __init__(self, in_channels):
151
- super().__init__()
152
- self.in_channels = in_channels
153
-
154
- self.norm = Normalize(in_channels)
155
- self.q = torch.nn.Conv2d(
156
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
157
- )
158
- self.k = torch.nn.Conv2d(
159
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
160
- )
161
- self.v = torch.nn.Conv2d(
162
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
163
- )
164
- self.proj_out = torch.nn.Conv2d(
165
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
- )
167
-
168
- def forward(self, x):
169
- h_ = x
170
- h_ = self.norm(h_)
171
- q = self.q(h_)
172
- k = self.k(h_)
173
- v = self.v(h_)
174
-
175
- # compute attention
176
- b, c, h, w = q.shape
177
- q = rearrange(q, "b c h w -> b (h w) c")
178
- k = rearrange(k, "b c h w -> b c (h w)")
179
- w_ = torch.einsum("bij,bjk->bik", q, k)
180
-
181
- w_ = w_ * (int(c) ** (-0.5))
182
- w_ = torch.nn.functional.softmax(w_, dim=2)
183
-
184
- # attend to values
185
- v = rearrange(v, "b c h w -> b c (h w)")
186
- w_ = rearrange(w_, "b i j -> b j i")
187
- h_ = torch.einsum("bij,bjk->bik", v, w_)
188
- h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
189
- h_ = self.proj_out(h_)
190
-
191
- return x + h_
192
-
193
-
194
- class _TruncExp(Function): # pylint: disable=abstract-method
195
- # Implementation from torch-ngp:
196
- # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
197
- @staticmethod
198
- @custom_fwd(cast_inputs=torch.float32)
199
- def forward(ctx, x): # pylint: disable=arguments-differ
200
- ctx.save_for_backward(x)
201
- return torch.exp(x)
202
-
203
- @staticmethod
204
- @custom_bwd
205
- def backward(ctx, g): # pylint: disable=arguments-differ
206
- x = ctx.saved_tensors[0]
207
- return g * torch.exp(x.clamp(-15, 15))
208
-
209
-
210
- trunc_exp = _TruncExp.apply
211
- """Same as torch.exp, but with the backward pass clipped to prevent vanishing/exploding
212
- gradients."""
213
-
214
-
215
- class CrossAttention(nn.Module):
216
- def __init__(
217
- self,
218
- query_dim,
219
- context_dim=None,
220
- heads=8,
221
- dim_head=64,
222
- dropout=0.0,
223
- backend=None,
224
- ):
225
- super().__init__()
226
- inner_dim = dim_head * heads
227
- context_dim = default(context_dim, query_dim)
228
-
229
- self.scale = dim_head**-0.5
230
- self.heads = heads
231
-
232
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
233
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
234
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
235
-
236
- self.to_out = nn.Sequential(
237
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
238
- )
239
- self.backend = backend
240
-
241
- def forward(
242
- self,
243
- x,
244
- context=None,
245
- mask=None,
246
- additional_tokens=None,
247
- n_times_crossframe_attn_in_self=0,
248
- ):
249
- h = self.heads
250
-
251
- if additional_tokens is not None:
252
- # get the number of masked tokens at the beginning of the output sequence
253
- n_tokens_to_mask = additional_tokens.shape[1]
254
- # add additional token
255
- x = torch.cat([additional_tokens, x], dim=1)
256
-
257
- q = self.to_q(x)
258
- context = default(context, x)
259
- k = self.to_k(context)
260
- v = self.to_v(context)
261
-
262
- if n_times_crossframe_attn_in_self:
263
- # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
264
- assert x.shape[0] % n_times_crossframe_attn_in_self == 0
265
- n_cp = x.shape[0] // n_times_crossframe_attn_in_self
266
- k = repeat(
267
- k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
268
- )
269
- v = repeat(
270
- v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
271
- )
272
-
273
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
274
-
275
- ## old
276
- """
277
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
278
- del q, k
279
-
280
- if exists(mask):
281
- mask = rearrange(mask, 'b ... -> b (...)')
282
- max_neg_value = -torch.finfo(sim.dtype).max
283
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
284
- sim.masked_fill_(~mask, max_neg_value)
285
-
286
- # attention, what we cannot get enough of
287
- sim = sim.softmax(dim=-1)
288
-
289
- out = einsum('b i j, b j d -> b i d', sim, v)
290
- """
291
- ## new
292
- with sdp_kernel(**BACKEND_MAP[self.backend]):
293
- # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
294
- out = F.scaled_dot_product_attention(
295
- q, k, v, attn_mask=mask
296
- ) # scale is dim_head ** -0.5 per default
297
-
298
- del q, k, v
299
- out = rearrange(out, "b h n d -> b n (h d)", h=h)
300
-
301
- if additional_tokens is not None:
302
- # remove additional token
303
- out = out[:, n_tokens_to_mask:]
304
- return self.to_out(out)
305
-
306
-
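Both attention classes above reduce to the same pattern: project to queries, keys, and values, split into heads, run a fused attention kernel, and merge the heads back. A standalone sketch of that pattern with PyTorch 2.x's scaled_dot_product_attention (the to_q/to_k/to_v projections are omitted here for brevity):

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    def attend(q, k, v, heads):
        # q: (b, n_q, heads*d); k, v: (b, n_kv, heads*d)
        q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=heads) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)     # scale defaults to d ** -0.5
        return rearrange(out, "b h n d -> b n (h d)")

    x = torch.randn(2, 64, 8 * 64)                        # 64 image tokens
    ctx = torch.randn(2, 77, 8 * 64)                      # 77 context tokens
    print(attend(x, ctx, ctx, heads=8).shape)             # torch.Size([2, 64, 512])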
307
- class MemoryEfficientCrossAttention(nn.Module):
308
- # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
309
- def __init__(
310
- self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, add_lora=False, **kwargs
311
- ):
312
- super().__init__()
313
- logpy.debug(
314
- f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
315
- f"context_dim is {context_dim} and using {heads} heads with a "
316
- f"dimension of {dim_head}."
317
- )
318
- inner_dim = dim_head * heads
319
- context_dim = default(context_dim, query_dim)
320
-
321
- self.heads = heads
322
- self.dim_head = dim_head
323
- self.add_lora = add_lora
324
-
325
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
326
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
327
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
328
-
329
- self.to_out = nn.Sequential(
330
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
331
- )
332
- if add_lora:
333
- r = 32
334
- self.to_q_attn3_down = nn.Linear(query_dim, r, bias=False)
335
- self.to_q_attn3_up = zero_module(nn.Linear(r, inner_dim, bias=False))
336
- self.to_k_attn3_down = nn.Linear(context_dim, r, bias=False)
337
- self.to_k_attn3_up = zero_module(nn.Linear(r, inner_dim, bias=False))
338
- self.to_v_attn3_down = nn.Linear(context_dim, r, bias=False)
339
- self.to_v_attn3_up = zero_module(nn.Linear(r, inner_dim, bias=False))
340
- self.to_o_attn3_down = nn.Linear(inner_dim, r, bias=False)
341
- self.to_o_attn3_up = zero_module(nn.Linear(r, query_dim, bias=False))
342
- self.dropoutq = nn.Dropout(0.1)
343
- self.dropoutk = nn.Dropout(0.1)
344
- self.dropoutv = nn.Dropout(0.1)
345
- self.dropouto = nn.Dropout(0.1)
346
-
347
- nn.init.normal_(self.to_q_attn3_down.weight, std=1 / r)
348
- nn.init.normal_(self.to_k_attn3_down.weight, std=1 / r)
349
- nn.init.normal_(self.to_v_attn3_down.weight, std=1 / r)
350
- nn.init.normal_(self.to_o_attn3_down.weight, std=1 / r)
351
-
352
- self.attention_op: Optional[Any] = None
353
-
354
- def forward(
355
- self,
356
- x,
357
- context=None,
358
- mask=None,
359
- additional_tokens=None,
360
- n_times_crossframe_attn_in_self=0,
361
- ):
362
- if additional_tokens is not None:
363
- # get the number of masked tokens at the beginning of the output sequence
364
- n_tokens_to_mask = additional_tokens.shape[1]
365
- # add additional token
366
- x = torch.cat([additional_tokens, x], dim=1)
367
-
368
- context_k = context # b, n, c, h, w
369
-
370
- q = self.to_q(x)
371
- context = default(context, x)
372
- context_k = default(context_k, x)
373
- k = self.to_k(context_k)
374
- v = self.to_v(context_k)
375
- if self.add_lora:
376
- q += self.dropoutq(self.to_q_attn3_up(self.to_q_attn3_down(x)))
377
- k += self.dropoutk(self.to_k_attn3_up(self.to_k_attn3_down(context_k)))
378
- v += self.dropoutv(self.to_v_attn3_up(self.to_v_attn3_down(context_k)))
379
-
380
- if n_times_crossframe_attn_in_self:
381
- # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
382
- assert x.shape[0] % n_times_crossframe_attn_in_self == 0
383
- # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
384
- k = repeat(
385
- k[::n_times_crossframe_attn_in_self],
386
- "b ... -> (b n) ...",
387
- n=n_times_crossframe_attn_in_self,
388
- )
389
- v = repeat(
390
- v[::n_times_crossframe_attn_in_self],
391
- "b ... -> (b n) ...",
392
- n=n_times_crossframe_attn_in_self,
393
- )
394
-
395
- b, _, _ = q.shape
396
- q, k, v = map(
397
- lambda t: t.unsqueeze(3)
398
- .reshape(b, t.shape[1], self.heads, self.dim_head)
399
- .permute(0, 2, 1, 3)
400
- .reshape(b * self.heads, t.shape[1], self.dim_head)
401
- .contiguous(),
402
- (q, k, v),
403
- )
404
-
405
- attn_bias = None
406
-
407
- # actually compute the attention, what we cannot get enough of
408
- out = xformers.ops.memory_efficient_attention(
409
- q, k, v, attn_bias=attn_bias, op=self.attention_op
410
- )
411
-
412
- # TODO: Use this directly in the attention operation, as a bias
413
- if exists(mask):
414
- raise NotImplementedError
415
- out = (
416
- out.unsqueeze(0)
417
- .reshape(b, self.heads, out.shape[1], self.dim_head)
418
- .permute(0, 2, 1, 3)
419
- .reshape(b, out.shape[1], self.heads * self.dim_head)
420
- )
421
- if additional_tokens is not None:
422
- # remove additional token
423
- out = out[:, n_tokens_to_mask:]
424
- final = self.to_out(out)
425
- if self.add_lora:
426
- final += self.dropouto(self.to_o_attn3_up(self.to_o_attn3_down(out)))
427
- return final
428
-
429
-
430
- class BasicTransformerBlock(nn.Module):
431
- ATTENTION_MODES = {
432
- "softmax": CrossAttention, # vanilla attention
433
- "softmax-xformers": MemoryEfficientCrossAttention, # ampere
434
- }
435
-
436
- def __init__(
437
- self,
438
- dim,
439
- n_heads,
440
- d_head,
441
- dropout=0.0,
442
- context_dim=None,
443
- gated_ff=True,
444
- checkpoint=True,
445
- disable_self_attn=False,
446
- attn_mode="softmax",
447
- sdp_backend=None,
448
- image_cross=False,
449
- far=2,
450
- num_samples=32,
451
- add_lora=False,
452
- rgb_predict=False,
453
- mode='pixel-nerf',
454
- average=False,
455
- num_freqs=16,
456
- use_prev_weights_imp_sample=False,
457
- imp_sample_next_step=False,
458
- stratified=False,
459
- imp_sampling_percent=0.9,
460
- near_plane=0.
461
- ):
462
-
463
- super().__init__()
464
- assert attn_mode in self.ATTENTION_MODES
465
- self.add_lora = add_lora
466
- self.image_cross = image_cross
467
- self.rgb_predict = rgb_predict
468
- self.use_prev_weights_imp_sample = use_prev_weights_imp_sample
469
- self.imp_sample_next_step = imp_sample_next_step
470
- self.rendered_feat = None
471
- if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
472
- logpy.warn(
473
- f"Attention mode '{attn_mode}' is not available. Falling "
474
- f"back to native attention. This is not a problem in "
475
- f"Pytorch >= 2.0. FYI, you are running with PyTorch "
476
- f"version {torch.__version__}."
477
- )
478
- attn_mode = "softmax"
479
- elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
480
- logpy.warn(
481
- "We do not support vanilla attention anymore, as it is too "
482
- "expensive. Sorry."
483
- )
484
- if not XFORMERS_IS_AVAILABLE:
485
- assert (
486
- False
487
- ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
488
- else:
489
- logpy.info("Falling back to xformers efficient attention.")
490
- attn_mode = "softmax-xformers"
491
- attn_cls = self.ATTENTION_MODES[attn_mode]
492
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
493
- assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
494
- else:
495
- assert sdp_backend is None
496
- self.disable_self_attn = disable_self_attn
497
- self.attn1 = attn_cls(
498
- query_dim=dim,
499
- heads=n_heads,
500
- dim_head=d_head,
501
- dropout=dropout,
502
- add_lora=self.add_lora,
503
- context_dim=context_dim if self.disable_self_attn else None,
504
- backend=sdp_backend,
505
- ) # is a self-attention if not self.disable_self_attn
506
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
507
- self.attn2 = attn_cls(
508
- query_dim=dim,
509
- context_dim=context_dim,
510
- heads=n_heads,
511
- dim_head=d_head,
512
- dropout=dropout,
513
- add_lora=self.add_lora,
514
- backend=sdp_backend,
515
- ) # is self-attn if context is none
516
- if image_cross:
517
- self.pose_emb_layers = nn.Linear(2*dim, dim, bias=False)
518
- nn.init.eye_(self.pose_emb_layers.weight)
519
- self.pose_featurenerf = NerfSDModule(mode=mode,
520
- out_channels=dim,
521
- far_plane=far,
522
- num_samples=num_samples,
523
- rgb_predict=rgb_predict,
524
- average=average,
525
- num_freqs=num_freqs,
526
- stratified=stratified,
527
- imp_sampling_percent=imp_sampling_percent,
528
- near_plane=near_plane,
529
- )
530
-
531
- self.renderer = VolRender()
532
-
533
- self.norm1 = nn.LayerNorm(dim)
534
- self.norm2 = nn.LayerNorm(dim)
535
- self.norm3 = nn.LayerNorm(dim)
536
- self.checkpoint = checkpoint
537
- if self.checkpoint:
538
- logpy.debug(f"{self.__class__.__name__} is using checkpointing")
539
-
540
- def forward(
541
- self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
542
- ):
543
- kwargs = {"x": x}
544
-
545
- if context is not None:
546
- kwargs.update({"context": context})
547
-
548
- if context_ref is not None:
549
- kwargs.update({"context_ref": context_ref})
550
-
551
- if pose is not None:
552
- kwargs.update({"pose": pose})
553
-
554
- if mask_ref is not None:
555
- kwargs.update({"mask_ref": mask_ref})
556
-
557
- if prev_weights is not None:
558
- kwargs.update({"prev_weights": prev_weights})
559
-
560
- if additional_tokens is not None:
561
- kwargs.update({"additional_tokens": additional_tokens})
562
-
563
- if n_times_crossframe_attn_in_self:
564
- kwargs.update(
565
- {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
566
- )
567
-
568
- # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
569
- return checkpoint(
570
- self._forward, (x, context, context_ref, pose, mask_ref, prev_weights), self.parameters(), self.checkpoint
571
- )
572
-
573
- def reference_attn(self, x, context_ref, context, pose, prev_weights, mask_ref):
574
- feats, sigmas, dists, _, predicted_rgb, sigmas_uniform, dists_uniform = self.pose_featurenerf(pose,
575
- context_ref,
576
- mask_ref,
577
- prev_weights=prev_weights if self.use_prev_weights_imp_sample else None,
578
- imp_sample_next_step=self.imp_sample_next_step)
579
-
580
- b, hw, d = feats.size()[:3]
581
- feats = rearrange(feats, "b hw d ... -> b (hw d) ...")
582
-
583
- feats = (
584
- self.attn2(
585
- self.norm2(feats), context=context,
586
- )
587
- + feats
588
- )
589
-
590
- feats = rearrange(feats, "b (hw d) ... -> b hw d ...", hw=hw, d=d)
591
-
592
- sigmas_ = trunc_exp(sigmas)
593
- if sigmas_uniform is not None:
594
- sigmas_uniform = trunc_exp(sigmas_uniform)
595
-
596
- context_ref, fg_mask, alphas, weights_uniform, predicted_rgb = self.renderer(feats, sigmas_, dists, densities_uniform=sigmas_uniform, dists_uniform=dists_uniform, return_weights_uniform=True, rgb=F.sigmoid(predicted_rgb) if predicted_rgb is not None else None)
597
- if self.use_prev_weights_imp_sample:
598
- prev_weights = weights_uniform
599
-
600
- return context_ref, fg_mask, prev_weights, alphas, predicted_rgb
601
-
602
- def _forward(
603
- self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
604
- ):
605
- fg_mask = None
606
- weights = None
607
- alphas = None
608
- predicted_rgb = None
609
- xref = None
610
-
611
- x = (
612
- self.attn1(
613
- self.norm1(x),
614
- context=context if self.disable_self_attn else None,
615
- additional_tokens=additional_tokens,
616
- n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
617
- if not self.disable_self_attn
618
- else 0,
619
- )
620
- + x
621
- )
622
- x = (
623
- self.attn2(
624
- self.norm2(x), context=context, additional_tokens=additional_tokens
625
- )
626
- + x
627
- )
628
- with torch.amp.autocast(device_type='cuda', dtype=torch.float32):
629
- if context_ref is not None:
630
- xref, fg_mask, weights, alphas, predicted_rgb = self.reference_attn(x,
631
- rearrange(context_ref, "(b n) ... -> b n ...", b=x.size(0), n=context_ref.size(0) // x.size(0)),
632
- context,
633
- pose,
634
- prev_weights,
635
- mask_ref)
636
- x = self.pose_emb_layers(torch.cat([x, xref], -1))
637
-
638
- x = self.ff(self.norm3(x)) + x
639
- return x, fg_mask, weights, alphas, predicted_rgb
640
-
641
-
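reference_attn above passes per-ray features and densities to VolRender (defined in sgm/modules/nerfsd_pytorch3d.py, which this commit also deletes). As a rough sketch of what a renderer of this kind conventionally computes, namely alpha compositing of samples along each ray (the actual VolRender may differ), in plain PyTorch:

    import torch

    def composite(feats, densities, dists):
        # feats: (b, rays, samples, c); densities, dists: (b, rays, samples)
        alphas = 1.0 - torch.exp(-densities * dists)                      # per-sample opacity
        trans = torch.cumprod(
            torch.cat([torch.ones_like(alphas[..., :1]), 1.0 - alphas + 1e-10], dim=-1),
            dim=-1,
        )[..., :-1]                                                        # transmittance up to each sample
        weights = alphas * trans
        rendered = (weights.unsqueeze(-1) * feats).sum(dim=-2)            # (b, rays, c)
        fg_mask = weights.sum(dim=-1)                                     # accumulated opacity per ray
        return rendered, fg_mask, weights

    feats = torch.randn(1, 4096, 32, 320)                 # 64x64 rays, 32 samples, 320-dim features
    sigma = torch.rand(1, 4096, 32)
    dist = torch.full((1, 4096, 32), 2.0 / 32)            # roughly far_plane / num_samples
    rendered, fg, w = composite(feats, sigma, dist)
    print(rendered.shape, fg.shape, w.shape)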
642
- class BasicTransformerSingleLayerBlock(nn.Module):
643
- ATTENTION_MODES = {
644
- "softmax": CrossAttention, # vanilla attention
645
- "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
646
- # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
647
- }
648
-
649
- def __init__(
650
- self,
651
- dim,
652
- n_heads,
653
- d_head,
654
- dropout=0.0,
655
- context_dim=None,
656
- gated_ff=True,
657
- checkpoint=True,
658
- attn_mode="softmax",
659
- ):
660
- super().__init__()
661
- assert attn_mode in self.ATTENTION_MODES
662
- attn_cls = self.ATTENTION_MODES[attn_mode]
663
- self.attn1 = attn_cls(
664
- query_dim=dim,
665
- heads=n_heads,
666
- dim_head=d_head,
667
- dropout=dropout,
668
- context_dim=context_dim,
669
- )
670
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
671
- self.norm1 = nn.LayerNorm(dim)
672
- self.norm2 = nn.LayerNorm(dim)
673
- self.checkpoint = checkpoint
674
-
675
- def forward(self, x, context=None):
676
- return checkpoint(
677
- self._forward, (x, context), self.parameters(), self.checkpoint
678
- )
679
-
680
- def _forward(self, x, context=None):
681
- x = self.attn1(self.norm1(x), context=context) + x
682
- x = self.ff(self.norm2(x)) + x
683
- return x
684
-
685
-
686
- class SpatialTransformer(nn.Module):
687
- """
688
- Transformer block for image-like data.
689
- First, project the input (aka embedding)
690
- and reshape to b, t, d.
691
- Then apply standard transformer action.
692
- Finally, reshape to image
693
- NEW: use_linear for more efficiency instead of the 1x1 convs
694
- """
695
-
696
- def __init__(
697
- self,
698
- in_channels,
699
- n_heads,
700
- d_head,
701
- depth=1,
702
- dropout=0.0,
703
- context_dim=None,
704
- disable_self_attn=False,
705
- use_linear=False,
706
- attn_type="softmax",
707
- use_checkpoint=True,
708
- # sdp_backend=SDPBackend.FLASH_ATTENTION
709
- sdp_backend=None,
710
- image_cross=True,
711
- rgb_predict=False,
712
- far=2,
713
- num_samples=32,
714
- add_lora=False,
715
- mode='feature-nerf',
716
- average=False,
717
- num_freqs=16,
718
- use_prev_weights_imp_sample=False,
719
- stratified=False,
720
- poscontrol_interval=4,
721
- imp_sampling_percent=0.9,
722
- near_plane=0.
723
- ):
724
- super().__init__()
725
- logpy.debug(
726
- f"constructing {self.__class__.__name__} of depth {depth} w/ "
727
- f"{in_channels} channels and {n_heads} heads."
728
- )
729
- from omegaconf import ListConfig
730
-
731
- if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
732
- context_dim = [context_dim]
733
- if exists(context_dim) and isinstance(context_dim, list):
734
- if depth != len(context_dim):
735
- logpy.warn(
736
- f"{self.__class__.__name__}: Found context dims "
737
- f"{context_dim} of depth {len(context_dim)}, which does not "
738
- f"match the specified 'depth' of {depth}. Setting context_dim "
739
- f"to {depth * [context_dim[0]]} now."
740
- )
741
- # depth does not match context dims.
742
- assert all(
743
- map(lambda x: x == context_dim[0], context_dim)
744
- ), "need homogenous context_dim to match depth automatically"
745
- context_dim = depth * [context_dim[0]]
746
- elif context_dim is None:
747
- context_dim = [None] * depth
748
- self.in_channels = in_channels
749
- inner_dim = n_heads * d_head
750
- self.norm = Normalize(in_channels)
751
-
752
- self.image_cross = image_cross
753
- self.poscontrol_interval = poscontrol_interval
754
-
755
- if not use_linear:
756
- self.proj_in = nn.Conv2d(
757
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0
758
- )
759
- else:
760
- self.proj_in = nn.Linear(in_channels, inner_dim)
761
-
762
- self.transformer_blocks = nn.ModuleList(
763
- [
764
- BasicTransformerBlock(
765
- inner_dim,
766
- n_heads,
767
- d_head,
768
- dropout=dropout,
769
- context_dim=context_dim[d],
770
- disable_self_attn=disable_self_attn,
771
- attn_mode=attn_type,
772
- checkpoint=use_checkpoint,
773
- sdp_backend=sdp_backend,
774
- image_cross=self.image_cross and (d % poscontrol_interval == 0),
775
- far=far,
776
- num_samples=num_samples,
777
- add_lora=add_lora and self.image_cross and (d % poscontrol_interval == 0),
778
- rgb_predict=rgb_predict,
779
- mode=mode,
780
- average=average,
781
- num_freqs=num_freqs,
782
- use_prev_weights_imp_sample=use_prev_weights_imp_sample,
783
- imp_sample_next_step=(use_prev_weights_imp_sample and self.image_cross and (d % poscontrol_interval == 0) and depth >= poscontrol_interval and d < (depth // poscontrol_interval) * poscontrol_interval ),
784
- stratified=stratified,
785
- imp_sampling_percent=imp_sampling_percent,
786
- near_plane=near_plane,
787
- )
788
- for d in range(depth)
789
- ]
790
- )
791
- if not use_linear:
792
- self.proj_out = zero_module(
793
- nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
794
- )
795
- else:
796
- # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
797
- self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
798
- self.use_linear = use_linear
799
-
800
- def forward(self, x, xr, context=None, contextr=None, pose=None, mask_ref=None, prev_weights=None):
801
- # note: if no context is given, cross-attention defaults to self-attention
802
- if xr is None:
803
- if not isinstance(context, list):
804
- context = [context]
805
- b, c, h, w = x.shape
806
- x_in = x
807
- x = self.norm(x)
808
- if not self.use_linear:
809
- x = self.proj_in(x)
810
- x = rearrange(x, "b c h w -> b (h w) c").contiguous()
811
- if self.use_linear:
812
- x = self.proj_in(x)
813
- for i, block in enumerate(self.transformer_blocks):
814
- if i > 0 and len(context) == 1:
815
- i = 0 # use same context for each block
816
- x, _, _, _, _ = block(x, context=context[i])
817
- if self.use_linear:
818
- x = self.proj_out(x)
819
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
820
- if not self.use_linear:
821
- x = self.proj_out(x)
822
- return x + x_in, None, None, None, None, None
823
- else:
824
- if not isinstance(context, list):
825
- context = [context]
826
- contextr = [contextr]
827
- b, c, h, w = x.shape
828
- b1, _, _, _ = xr.shape
829
- x_in = x
830
- xr_in = xr
831
- fg_masks = []
832
- alphas = []
833
- rgbs = []
834
-
835
- x = self.norm(x)
836
- with torch.no_grad():
837
- xr = self.norm(xr)
838
-
839
- if not self.use_linear:
840
- x = self.proj_in(x)
841
- with torch.no_grad():
842
- xr = self.proj_in(xr)
843
-
844
- x = rearrange(x, "b c h w -> b (h w) c").contiguous()
845
- xr = rearrange(xr, "b1 c h w -> b1 (h w) c").contiguous()
846
- if self.use_linear:
847
- x = self.proj_in(x)
848
- with torch.no_grad():
849
- xr = self.proj_in(xr)
850
-
851
- prev_weights = None
852
- counter = 0
853
- for i, block in enumerate(self.transformer_blocks):
854
- if i > 0 and len(context) == 1:
855
- i = 0 # use same context for each block
856
- if self.image_cross and (counter % self.poscontrol_interval == 0):
857
- with torch.no_grad():
858
- xr, _, _, _, _ = block(xr, context=contextr[i])
859
- x, fg_mask, weights, alpha, rgb = block(x, context=context[i], context_ref=xr.detach(), pose=pose, mask_ref=mask_ref, prev_weights=prev_weights)
860
- prev_weights = weights
861
- fg_masks.append(fg_mask)
862
- if alpha is not None:
863
- alphas.append(alpha)
864
- if rgb is not None:
865
- rgbs.append(rgb)
866
- else:
867
- with torch.no_grad():
868
- xr, _, _, _, _ = block(xr, context=contextr[i])
869
- x, _, _, _, _ = block(x, context=context[i])
870
- counter += 1
871
- if self.use_linear:
872
- x = self.proj_out(x)
873
- with torch.no_grad():
874
- xr = self.proj_out(xr)
875
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
876
- xr = rearrange(xr, "b1 (h w) c -> b1 c h w", h=h, w=w).contiguous()
877
- if not self.use_linear:
878
- x = self.proj_out(x)
879
- with torch.no_grad():
880
- xr = self.proj_out(xr)
881
- if len(fg_masks) > 0:
882
- if len(rgbs) <= 0:
883
- rgbs = None
884
- if len(alphas) <= 0:
885
- alphas = None
886
- return x + x_in, (xr + xr_in).detach(), fg_masks, prev_weights, alphas, rgbs
887
- else:
888
- return x + x_in, (xr + xr_in).detach(), None, prev_weights, None, None
889
-
890
-
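SpatialTransformer.forward above flattens each feature map into a token sequence before the transformer blocks and folds it back into an image afterwards. The round trip, as a minimal einops sketch:

    import torch
    from einops import rearrange

    x = torch.randn(2, 320, 32, 32)                            # (b, c, h, w) feature map
    tokens = rearrange(x, "b c h w -> b (h w) c")              # (2, 1024, 320) token sequence
    y = rearrange(tokens, "b (h w) c -> b c h w", h=32, w=32)  # back to the image layout
    assert torch.equal(x, y)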
891
- def benchmark_attn():
892
- # Lets define a helpful benchmarking function:
893
- # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
894
- device = "cuda" if torch.cuda.is_available() else "cpu"
895
- import torch.nn.functional as F
896
- import torch.utils.benchmark as benchmark
897
-
898
- def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
899
- t0 = benchmark.Timer(
900
- stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
901
- )
902
- return t0.blocked_autorange().mean * 1e6
903
-
904
- # Lets define the hyper-parameters of our input
905
- batch_size = 32
906
- max_sequence_len = 1024
907
- num_heads = 32
908
- embed_dimension = 32
909
-
910
- dtype = torch.float16
911
-
912
- query = torch.rand(
913
- batch_size,
914
- num_heads,
915
- max_sequence_len,
916
- embed_dimension,
917
- device=device,
918
- dtype=dtype,
919
- )
920
- key = torch.rand(
921
- batch_size,
922
- num_heads,
923
- max_sequence_len,
924
- embed_dimension,
925
- device=device,
926
- dtype=dtype,
927
- )
928
- value = torch.rand(
929
- batch_size,
930
- num_heads,
931
- max_sequence_len,
932
- embed_dimension,
933
- device=device,
934
- dtype=dtype,
935
- )
936
-
937
- print(f"q/k/v shape:", query.shape, key.shape, value.shape)
938
-
939
- # Lets explore the speed of each of the 3 implementations
940
- from torch.backends.cuda import SDPBackend, sdp_kernel
941
-
942
- # Helpful arguments mapper
943
- backend_map = {
944
- SDPBackend.MATH: {
945
- "enable_math": True,
946
- "enable_flash": False,
947
- "enable_mem_efficient": False,
948
- },
949
- SDPBackend.FLASH_ATTENTION: {
950
- "enable_math": False,
951
- "enable_flash": True,
952
- "enable_mem_efficient": False,
953
- },
954
- SDPBackend.EFFICIENT_ATTENTION: {
955
- "enable_math": False,
956
- "enable_flash": False,
957
- "enable_mem_efficient": True,
958
- },
959
- }
960
-
961
- from torch.profiler import ProfilerActivity, profile, record_function
962
-
963
- activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
964
-
965
- print(
966
- f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
967
- )
968
- with profile(
969
- activities=activities, record_shapes=False, profile_memory=True
970
- ) as prof:
971
- with record_function("Default detailed stats"):
972
- for _ in range(25):
973
- o = F.scaled_dot_product_attention(query, key, value)
974
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
975
-
976
- print(
977
- f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
978
- )
979
- with sdp_kernel(**backend_map[SDPBackend.MATH]):
980
- with profile(
981
- activities=activities, record_shapes=False, profile_memory=True
982
- ) as prof:
983
- with record_function("Math implmentation stats"):
984
- for _ in range(25):
985
- o = F.scaled_dot_product_attention(query, key, value)
986
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
987
-
988
- with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
989
- try:
990
- print(
991
- f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
992
- )
993
- except RuntimeError:
994
- print("FlashAttention is not supported. See warnings for reasons.")
995
- with profile(
996
- activities=activities, record_shapes=False, profile_memory=True
997
- ) as prof:
998
- with record_function("FlashAttention stats"):
999
- for _ in range(25):
1000
- o = F.scaled_dot_product_attention(query, key, value)
1001
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1002
-
1003
- with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
1004
- try:
1005
- print(
1006
- f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
1007
- )
1008
- except RuntimeError:
1009
- print("EfficientAttention is not supported. See warnings for reasons.")
1010
- with profile(
1011
- activities=activities, record_shapes=False, profile_memory=True
1012
- ) as prof:
1013
- with record_function("EfficientAttention stats"):
1014
- for _ in range(25):
1015
- o = F.scaled_dot_product_attention(query, key, value)
1016
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1017
-
1018
-
1019
- def run_model(model, x, context):
1020
- return model(x, context)
1021
-
1022
-
1023
- def benchmark_transformer_blocks():
1024
- device = "cuda" if torch.cuda.is_available() else "cpu"
1025
- import torch.utils.benchmark as benchmark
1026
-
1027
- def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
1028
- t0 = benchmark.Timer(
1029
- stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
1030
- )
1031
- return t0.blocked_autorange().mean * 1e6
1032
-
1033
- checkpoint = True
1034
- compile = False
1035
-
1036
- batch_size = 32
1037
- h, w = 64, 64
1038
- context_len = 77
1039
- embed_dimension = 1024
1040
- context_dim = 1024
1041
- d_head = 64
1042
-
1043
- transformer_depth = 4
1044
-
1045
- n_heads = embed_dimension // d_head
1046
-
1047
- dtype = torch.float16
1048
-
1049
- model_native = SpatialTransformer(
1050
- embed_dimension,
1051
- n_heads,
1052
- d_head,
1053
- context_dim=context_dim,
1054
- use_linear=True,
1055
- use_checkpoint=checkpoint,
1056
- attn_type="softmax",
1057
- depth=transformer_depth,
1058
- sdp_backend=SDPBackend.FLASH_ATTENTION,
1059
- ).to(device)
1060
- model_efficient_attn = SpatialTransformer(
1061
- embed_dimension,
1062
- n_heads,
1063
- d_head,
1064
- context_dim=context_dim,
1065
- use_linear=True,
1066
- depth=transformer_depth,
1067
- use_checkpoint=checkpoint,
1068
- attn_type="softmax-xformers",
1069
- ).to(device)
1070
- if not checkpoint and compile:
1071
- print("compiling models")
1072
- model_native = torch.compile(model_native)
1073
- model_efficient_attn = torch.compile(model_efficient_attn)
1074
-
1075
- x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
1076
- c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
1077
-
1078
- from torch.profiler import ProfilerActivity, profile, record_function
1079
-
1080
- activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
1081
-
1082
- with torch.autocast("cuda"):
1083
- print(
1084
- f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
1085
- )
1086
- print(
1087
- f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
1088
- )
1089
-
1090
- print(75 * "+")
1091
- print("NATIVE")
1092
- print(75 * "+")
1093
- torch.cuda.reset_peak_memory_stats()
1094
- with profile(
1095
- activities=activities, record_shapes=False, profile_memory=True
1096
- ) as prof:
1097
- with record_function("NativeAttention stats"):
1098
- for _ in range(25):
1099
- model_native(x, c)
1100
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1101
- print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
1102
-
1103
- print(75 * "+")
1104
- print("Xformers")
1105
- print(75 * "+")
1106
- torch.cuda.reset_peak_memory_stats()
1107
- with profile(
1108
- activities=activities, record_shapes=False, profile_memory=True
1109
- ) as prof:
1110
- with record_function("xformers stats"):
1111
- for _ in range(25):
1112
- model_efficient_attn(x, c)
1113
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1114
- print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
1115
-
1116
-
1117
- def test01():
1118
- # conv1x1 vs linear
1119
- from ..util import count_params
1120
-
1121
- conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
1122
- print(count_params(conv))
1123
- linear = torch.nn.Linear(3, 32).cuda()
1124
- print(count_params(linear))
1125
-
1126
- print(conv.weight.shape)
1127
-
1128
- # use same initialization
1129
- linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
1130
- linear.bias = torch.nn.Parameter(conv.bias)
1131
-
1132
- print(linear.weight.shape)
1133
-
1134
- x = torch.randn(11, 3, 64, 64).cuda()
1135
-
1136
- xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
1137
- print(xr.shape)
1138
- out_linear = linear(xr)
1139
- print(out_linear.mean(), out_linear.shape)
1140
-
1141
- out_conv = conv(x)
1142
- print(out_conv.mean(), out_conv.shape)
1143
- print("done with test01.\n")
1144
-
1145
-
1146
- def test02():
1147
- # try cosine flash attention
1148
- import time
1149
-
1150
- torch.backends.cuda.matmul.allow_tf32 = True
1151
- torch.backends.cudnn.allow_tf32 = True
1152
- torch.backends.cudnn.benchmark = True
1153
- print("testing cosine flash attention...")
1154
- DIM = 1024
1155
- SEQLEN = 4096
1156
- BS = 16
1157
-
1158
- print(" softmax (vanilla) first...")
1159
- model = BasicTransformerBlock(
1160
- dim=DIM,
1161
- n_heads=16,
1162
- d_head=64,
1163
- dropout=0.0,
1164
- context_dim=None,
1165
- attn_mode="softmax",
1166
- ).cuda()
1167
- try:
1168
- x = torch.randn(BS, SEQLEN, DIM).cuda()
1169
- tic = time.time()
1170
- y = model(x)
1171
- toc = time.time()
1172
- print(y.shape, toc - tic)
1173
- except RuntimeError as e:
1174
- # likely oom
1175
- print(str(e))
1176
-
1177
- print("\n now flash-cosine...")
1178
- model = BasicTransformerBlock(
1179
- dim=DIM,
1180
- n_heads=16,
1181
- d_head=64,
1182
- dropout=0.0,
1183
- context_dim=None,
1184
- attn_mode="flash-cosine",
1185
- ).cuda()
1186
- x = torch.randn(BS, SEQLEN, DIM).cuda()
1187
- tic = time.time()
1188
- y = model(x)
1189
- toc = time.time()
1190
- print(y.shape, toc - tic)
1191
- print("done with test02.\n")
1192
-
1193
-
1194
- if __name__ == "__main__":
1195
- # test01()
1196
- # test02()
1197
- # test03()
1198
-
1199
- # benchmark_attn()
1200
- benchmark_transformer_blocks()
1201
-
1202
- print("done.")
sgm/modules/autoencoding/__init__.py DELETED
File without changes
sgm/modules/autoencoding/lpips/__init__.py DELETED
File without changes
sgm/modules/autoencoding/lpips/loss.py DELETED
File without changes
sgm/modules/autoencoding/lpips/loss/LICENSE DELETED
@@ -1,23 +0,0 @@
1
- Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2
- All rights reserved.
3
-
4
- Redistribution and use in source and binary forms, with or without
5
- modification, are permitted provided that the following conditions are met:
6
-
7
- * Redistributions of source code must retain the above copyright notice, this
8
- list of conditions and the following disclaimer.
9
-
10
- * Redistributions in binary form must reproduce the above copyright notice,
11
- this list of conditions and the following disclaimer in the documentation
12
- and/or other materials provided with the distribution.
13
-
14
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
sgm/modules/autoencoding/lpips/loss/__init__.py DELETED
File without changes
sgm/modules/autoencoding/lpips/loss/lpips.py DELETED
@@ -1,147 +0,0 @@
1
- """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
-
3
- from collections import namedtuple
4
-
5
- import torch
6
- import torch.nn as nn
7
- from torchvision import models
8
-
9
- from ..util import get_ckpt_path
10
-
11
-
12
- class LPIPS(nn.Module):
13
- # Learned perceptual metric
14
- def __init__(self, use_dropout=True):
15
- super().__init__()
16
- self.scaling_layer = ScalingLayer()
17
- self.chns = [64, 128, 256, 512, 512] # vgg16 features
18
- self.net = vgg16(pretrained=True, requires_grad=False)
19
- self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
20
- self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
21
- self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
22
- self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
23
- self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
24
- self.load_from_pretrained()
25
- for param in self.parameters():
26
- param.requires_grad = False
27
-
28
- def load_from_pretrained(self, name="vgg_lpips"):
29
- ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
30
- self.load_state_dict(
31
- torch.load(ckpt, map_location=torch.device("cpu")), strict=False
32
- )
33
- print("loaded pretrained LPIPS loss from {}".format(ckpt))
34
-
35
- @classmethod
36
- def from_pretrained(cls, name="vgg_lpips"):
37
- if name != "vgg_lpips":
38
- raise NotImplementedError
39
- model = cls()
40
- ckpt = get_ckpt_path(name)
41
- model.load_state_dict(
42
- torch.load(ckpt, map_location=torch.device("cpu")), strict=False
43
- )
44
- return model
45
-
46
- def forward(self, input, target):
47
- in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
48
- outs0, outs1 = self.net(in0_input), self.net(in1_input)
49
- feats0, feats1, diffs = {}, {}, {}
50
- lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
51
- for kk in range(len(self.chns)):
52
- feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
53
- outs1[kk]
54
- )
55
- diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
56
-
57
- res = [
58
- spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
59
- for kk in range(len(self.chns))
60
- ]
61
- val = res[0]
62
- for l in range(1, len(self.chns)):
63
- val += res[l]
64
- return val
65
-
66
-
67
- class ScalingLayer(nn.Module):
68
- def __init__(self):
69
- super(ScalingLayer, self).__init__()
70
- self.register_buffer(
71
- "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
72
- )
73
- self.register_buffer(
74
- "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
75
- )
76
-
77
- def forward(self, inp):
78
- return (inp - self.shift) / self.scale
79
-
80
-
81
- class NetLinLayer(nn.Module):
82
- """A single linear layer which does a 1x1 conv"""
83
-
84
- def __init__(self, chn_in, chn_out=1, use_dropout=False):
85
- super(NetLinLayer, self).__init__()
86
- layers = (
87
- [
88
- nn.Dropout(),
89
- ]
90
- if (use_dropout)
91
- else []
92
- )
93
- layers += [
94
- nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
95
- ]
96
- self.model = nn.Sequential(*layers)
97
-
98
-
99
- class vgg16(torch.nn.Module):
100
- def __init__(self, requires_grad=False, pretrained=True):
101
- super(vgg16, self).__init__()
102
- vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
103
- self.slice1 = torch.nn.Sequential()
104
- self.slice2 = torch.nn.Sequential()
105
- self.slice3 = torch.nn.Sequential()
106
- self.slice4 = torch.nn.Sequential()
107
- self.slice5 = torch.nn.Sequential()
108
- self.N_slices = 5
109
- for x in range(4):
110
- self.slice1.add_module(str(x), vgg_pretrained_features[x])
111
- for x in range(4, 9):
112
- self.slice2.add_module(str(x), vgg_pretrained_features[x])
113
- for x in range(9, 16):
114
- self.slice3.add_module(str(x), vgg_pretrained_features[x])
115
- for x in range(16, 23):
116
- self.slice4.add_module(str(x), vgg_pretrained_features[x])
117
- for x in range(23, 30):
118
- self.slice5.add_module(str(x), vgg_pretrained_features[x])
119
- if not requires_grad:
120
- for param in self.parameters():
121
- param.requires_grad = False
122
-
123
- def forward(self, X):
124
- h = self.slice1(X)
125
- h_relu1_2 = h
126
- h = self.slice2(h)
127
- h_relu2_2 = h
128
- h = self.slice3(h)
129
- h_relu3_3 = h
130
- h = self.slice4(h)
131
- h_relu4_3 = h
132
- h = self.slice5(h)
133
- h_relu5_3 = h
134
- vgg_outputs = namedtuple(
135
- "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
136
- )
137
- out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
138
- return out
139
-
140
-
141
- def normalize_tensor(x, eps=1e-10):
142
- norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
143
- return x / (norm_factor + eps)
144
-
145
-
146
- def spatial_average(x, keepdim=True):
147
- return x.mean([2, 3], keepdim=keepdim)
sgm/modules/autoencoding/lpips/model/LICENSE DELETED
@@ -1,58 +0,0 @@
1
- Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
2
- All rights reserved.
3
-
4
- Redistribution and use in source and binary forms, with or without
5
- modification, are permitted provided that the following conditions are met:
6
-
7
- * Redistributions of source code must retain the above copyright notice, this
8
- list of conditions and the following disclaimer.
9
-
10
- * Redistributions in binary form must reproduce the above copyright notice,
11
- this list of conditions and the following disclaimer in the documentation
12
- and/or other materials provided with the distribution.
13
-
14
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24
-
25
-
26
- --------------------------- LICENSE FOR pix2pix --------------------------------
27
- BSD License
28
-
29
- For pix2pix software
30
- Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
31
- All rights reserved.
32
-
33
- Redistribution and use in source and binary forms, with or without
34
- modification, are permitted provided that the following conditions are met:
35
-
36
- * Redistributions of source code must retain the above copyright notice, this
37
- list of conditions and the following disclaimer.
38
-
39
- * Redistributions in binary form must reproduce the above copyright notice,
40
- this list of conditions and the following disclaimer in the documentation
41
- and/or other materials provided with the distribution.
42
-
43
- ----------------------------- LICENSE FOR DCGAN --------------------------------
44
- BSD License
45
-
46
- For dcgan.torch software
47
-
48
- Copyright (c) 2015, Facebook, Inc. All rights reserved.
49
-
50
- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
51
-
52
- Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
53
-
54
- Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
55
-
56
- Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
57
-
58
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
sgm/modules/autoencoding/lpips/model/__init__.py DELETED
File without changes
sgm/modules/autoencoding/lpips/model/model.py DELETED
@@ -1,88 +0,0 @@
1
- import functools
2
-
3
- import torch.nn as nn
4
-
5
- from ..util import ActNorm
6
-
7
-
8
- def weights_init(m):
9
- classname = m.__class__.__name__
10
- if classname.find("Conv") != -1:
11
- nn.init.normal_(m.weight.data, 0.0, 0.02)
12
- elif classname.find("BatchNorm") != -1:
13
- nn.init.normal_(m.weight.data, 1.0, 0.02)
14
- nn.init.constant_(m.bias.data, 0)
15
-
16
-
17
- class NLayerDiscriminator(nn.Module):
18
- """Defines a PatchGAN discriminator as in Pix2Pix
19
- --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
- """
21
-
22
- def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
23
- """Construct a PatchGAN discriminator
24
- Parameters:
25
- input_nc (int) -- the number of channels in input images
26
- ndf (int) -- the number of filters in the last conv layer
27
- n_layers (int) -- the number of conv layers in the discriminator
28
- norm_layer -- normalization layer
29
- """
30
- super(NLayerDiscriminator, self).__init__()
31
- if not use_actnorm:
32
- norm_layer = nn.BatchNorm2d
33
- else:
34
- norm_layer = ActNorm
35
- if (
36
- type(norm_layer) == functools.partial
37
- ): # no need to use bias as BatchNorm2d has affine parameters
38
- use_bias = norm_layer.func != nn.BatchNorm2d
39
- else:
40
- use_bias = norm_layer != nn.BatchNorm2d
41
-
42
- kw = 4
43
- padw = 1
44
- sequence = [
45
- nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
46
- nn.LeakyReLU(0.2, True),
47
- ]
48
- nf_mult = 1
49
- nf_mult_prev = 1
50
- for n in range(1, n_layers): # gradually increase the number of filters
51
- nf_mult_prev = nf_mult
52
- nf_mult = min(2**n, 8)
53
- sequence += [
54
- nn.Conv2d(
55
- ndf * nf_mult_prev,
56
- ndf * nf_mult,
57
- kernel_size=kw,
58
- stride=2,
59
- padding=padw,
60
- bias=use_bias,
61
- ),
62
- norm_layer(ndf * nf_mult),
63
- nn.LeakyReLU(0.2, True),
64
- ]
65
-
66
- nf_mult_prev = nf_mult
67
- nf_mult = min(2**n_layers, 8)
68
- sequence += [
69
- nn.Conv2d(
70
- ndf * nf_mult_prev,
71
- ndf * nf_mult,
72
- kernel_size=kw,
73
- stride=1,
74
- padding=padw,
75
- bias=use_bias,
76
- ),
77
- norm_layer(ndf * nf_mult),
78
- nn.LeakyReLU(0.2, True),
79
- ]
80
-
81
- sequence += [
82
- nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
83
- ] # output 1 channel prediction map
84
- self.main = nn.Sequential(*sequence)
85
-
86
- def forward(self, input):
87
- """Standard forward."""
88
- return self.main(input)
sgm/modules/autoencoding/lpips/util.py DELETED
@@ -1,128 +0,0 @@
1
- import hashlib
2
- import os
3
-
4
- import requests
5
- import torch
6
- import torch.nn as nn
7
- from tqdm import tqdm
8
-
9
- URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
10
-
11
- CKPT_MAP = {"vgg_lpips": "vgg.pth"}
12
-
13
- MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
14
-
15
-
16
- def download(url, local_path, chunk_size=1024):
17
- os.makedirs(os.path.split(local_path)[0], exist_ok=True)
18
- with requests.get(url, stream=True) as r:
19
- total_size = int(r.headers.get("content-length", 0))
20
- with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
21
- with open(local_path, "wb") as f:
22
- for data in r.iter_content(chunk_size=chunk_size):
23
- if data:
24
- f.write(data)
25
- pbar.update(chunk_size)
26
-
27
-
28
- def md5_hash(path):
29
- with open(path, "rb") as f:
30
- content = f.read()
31
- return hashlib.md5(content).hexdigest()
32
-
33
-
34
- def get_ckpt_path(name, root, check=False):
35
- assert name in URL_MAP
36
- path = os.path.join(root, CKPT_MAP[name])
37
- if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
38
- print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
39
- download(URL_MAP[name], path)
40
- md5 = md5_hash(path)
41
- assert md5 == MD5_MAP[name], md5
42
- return path
43
-
44
-
45
- class ActNorm(nn.Module):
46
- def __init__(
47
- self, num_features, logdet=False, affine=True, allow_reverse_init=False
48
- ):
49
- assert affine
50
- super().__init__()
51
- self.logdet = logdet
52
- self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
53
- self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
54
- self.allow_reverse_init = allow_reverse_init
55
-
56
- self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
57
-
58
- def initialize(self, input):
59
- with torch.no_grad():
60
- flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
61
- mean = (
62
- flatten.mean(1)
63
- .unsqueeze(1)
64
- .unsqueeze(2)
65
- .unsqueeze(3)
66
- .permute(1, 0, 2, 3)
67
- )
68
- std = (
69
- flatten.std(1)
70
- .unsqueeze(1)
71
- .unsqueeze(2)
72
- .unsqueeze(3)
73
- .permute(1, 0, 2, 3)
74
- )
75
-
76
- self.loc.data.copy_(-mean)
77
- self.scale.data.copy_(1 / (std + 1e-6))
78
-
79
- def forward(self, input, reverse=False):
80
- if reverse:
81
- return self.reverse(input)
82
- if len(input.shape) == 2:
83
- input = input[:, :, None, None]
84
- squeeze = True
85
- else:
86
- squeeze = False
87
-
88
- _, _, height, width = input.shape
89
-
90
- if self.training and self.initialized.item() == 0:
91
- self.initialize(input)
92
- self.initialized.fill_(1)
93
-
94
- h = self.scale * (input + self.loc)
95
-
96
- if squeeze:
97
- h = h.squeeze(-1).squeeze(-1)
98
-
99
- if self.logdet:
100
- log_abs = torch.log(torch.abs(self.scale))
101
- logdet = height * width * torch.sum(log_abs)
102
- logdet = logdet * torch.ones(input.shape[0]).to(input)
103
- return h, logdet
104
-
105
- return h
106
-
107
- def reverse(self, output):
108
- if self.training and self.initialized.item() == 0:
109
- if not self.allow_reverse_init:
110
- raise RuntimeError(
111
- "Initializing ActNorm in reverse direction is "
112
- "disabled by default. Use allow_reverse_init=True to enable."
113
- )
114
- else:
115
- self.initialize(output)
116
- self.initialized.fill_(1)
117
-
118
- if len(output.shape) == 2:
119
- output = output[:, :, None, None]
120
- squeeze = True
121
- else:
122
- squeeze = False
123
-
124
- h = output / self.scale - self.loc
125
-
126
- if squeeze:
127
- h = h.squeeze(-1).squeeze(-1)
128
- return h
sgm/modules/autoencoding/lpips/vqperceptual.py DELETED
@@ -1,17 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
-
4
-
5
- def hinge_d_loss(logits_real, logits_fake):
6
- loss_real = torch.mean(F.relu(1.0 - logits_real))
7
- loss_fake = torch.mean(F.relu(1.0 + logits_fake))
8
- d_loss = 0.5 * (loss_real + loss_fake)
9
- return d_loss
10
-
11
-
12
- def vanilla_d_loss(logits_real, logits_fake):
13
- d_loss = 0.5 * (
14
- torch.mean(torch.nn.functional.softplus(-logits_real))
15
- + torch.mean(torch.nn.functional.softplus(logits_fake))
16
- )
17
- return d_loss
sgm/modules/autoencoding/regularizers/__init__.py DELETED
@@ -1,31 +0,0 @@
1
- from abc import abstractmethod
2
- from typing import Any, Tuple
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- from ....modules.distributions.distributions import \
9
- DiagonalGaussianDistribution
10
- from .base import AbstractRegularizer
11
-
12
-
13
- class DiagonalGaussianRegularizer(AbstractRegularizer):
14
- def __init__(self, sample: bool = True):
15
- super().__init__()
16
- self.sample = sample
17
-
18
- def get_trainable_parameters(self) -> Any:
19
- yield from ()
20
-
21
- def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
22
- log = dict()
23
- posterior = DiagonalGaussianDistribution(z)
24
- if self.sample:
25
- z = posterior.sample()
26
- else:
27
- z = posterior.mode()
28
- kl_loss = posterior.kl()
29
- kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
30
- log["kl_loss"] = kl_loss
31
- return z, log
sgm/modules/autoencoding/regularizers/base.py DELETED
@@ -1,40 +0,0 @@
1
- from abc import abstractmethod
2
- from typing import Any, Tuple
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- from torch import nn
7
-
8
-
9
- class AbstractRegularizer(nn.Module):
10
- def __init__(self):
11
- super().__init__()
12
-
13
- def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
14
- raise NotImplementedError()
15
-
16
- @abstractmethod
17
- def get_trainable_parameters(self) -> Any:
18
- raise NotImplementedError()
19
-
20
-
21
- class IdentityRegularizer(AbstractRegularizer):
22
- def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
23
- return z, dict()
24
-
25
- def get_trainable_parameters(self) -> Any:
26
- yield from ()
27
-
28
-
29
- def measure_perplexity(
30
- predicted_indices: torch.Tensor, num_centroids: int
31
- ) -> Tuple[torch.Tensor, torch.Tensor]:
32
- # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
33
- # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
34
- encodings = (
35
- F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
36
- )
37
- avg_probs = encodings.mean(0)
38
- perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
39
- cluster_use = torch.sum(avg_probs > 0)
40
- return perplexity, cluster_use
sgm/modules/autoencoding/regularizers/quantize.py DELETED
@@ -1,487 +0,0 @@
1
- import logging
2
- from abc import abstractmethod
3
- from typing import Dict, Iterator, Literal, Optional, Tuple, Union
4
-
5
- import numpy as np
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from einops import rearrange
10
- from torch import einsum
11
-
12
- from .base import AbstractRegularizer, measure_perplexity
13
-
14
- logpy = logging.getLogger(__name__)
15
-
16
-
17
- class AbstractQuantizer(AbstractRegularizer):
18
- def __init__(self):
19
- super().__init__()
20
- # Define these in your init
21
- # shape (N,)
22
- self.used: Optional[torch.Tensor]
23
- self.re_embed: int
24
- self.unknown_index: Union[Literal["random"], int]
25
-
26
- def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
27
- assert self.used is not None, "You need to define used indices for remap"
28
- ishape = inds.shape
29
- assert len(ishape) > 1
30
- inds = inds.reshape(ishape[0], -1)
31
- used = self.used.to(inds)
32
- match = (inds[:, :, None] == used[None, None, ...]).long()
33
- new = match.argmax(-1)
34
- unknown = match.sum(2) < 1
35
- if self.unknown_index == "random":
36
- new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
37
- device=new.device
38
- )
39
- else:
40
- new[unknown] = self.unknown_index
41
- return new.reshape(ishape)
42
-
43
- def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
44
- assert self.used is not None, "You need to define used indices for remap"
45
- ishape = inds.shape
46
- assert len(ishape) > 1
47
- inds = inds.reshape(ishape[0], -1)
48
- used = self.used.to(inds)
49
- if self.re_embed > self.used.shape[0]: # extra token
50
- inds[inds >= self.used.shape[0]] = 0 # simply set to zero
51
- back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
52
- return back.reshape(ishape)
53
-
54
- @abstractmethod
55
- def get_codebook_entry(
56
- self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
57
- ) -> torch.Tensor:
58
- raise NotImplementedError()
59
-
60
- def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
61
- yield from self.parameters()
62
-
63
-
64
- class GumbelQuantizer(AbstractQuantizer):
65
- """
66
- credit to @karpathy:
67
- https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
68
- Gumbel Softmax trick quantizer
69
- Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
70
- https://arxiv.org/abs/1611.01144
71
- """
72
-
73
- def __init__(
74
- self,
75
- num_hiddens: int,
76
- embedding_dim: int,
77
- n_embed: int,
78
- straight_through: bool = True,
79
- kl_weight: float = 5e-4,
80
- temp_init: float = 1.0,
81
- remap: Optional[str] = None,
82
- unknown_index: str = "random",
83
- loss_key: str = "loss/vq",
84
- ) -> None:
85
- super().__init__()
86
-
87
- self.loss_key = loss_key
88
- self.embedding_dim = embedding_dim
89
- self.n_embed = n_embed
90
-
91
- self.straight_through = straight_through
92
- self.temperature = temp_init
93
- self.kl_weight = kl_weight
94
-
95
- self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
96
- self.embed = nn.Embedding(n_embed, embedding_dim)
97
-
98
- self.remap = remap
99
- if self.remap is not None:
100
- self.register_buffer("used", torch.tensor(np.load(self.remap)))
101
- self.re_embed = self.used.shape[0]
102
- else:
103
- self.used = None
104
- self.re_embed = n_embed
105
- if unknown_index == "extra":
106
- self.unknown_index = self.re_embed
107
- self.re_embed = self.re_embed + 1
108
- else:
109
- assert unknown_index == "random" or isinstance(
110
- unknown_index, int
111
- ), "unknown index needs to be 'random', 'extra' or any integer"
112
- self.unknown_index = unknown_index # "random" or "extra" or integer
113
- if self.remap is not None:
114
- logpy.info(
115
- f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
116
- f"Using {self.unknown_index} for unknown indices."
117
- )
118
-
119
- def forward(
120
- self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False
121
- ) -> Tuple[torch.Tensor, Dict]:
122
- # force hard = True when we are in eval mode, as we must quantize.
123
- # actually, always true seems to work
124
- hard = self.straight_through if self.training else True
125
- temp = self.temperature if temp is None else temp
126
- out_dict = {}
127
- logits = self.proj(z)
128
- if self.remap is not None:
129
- # continue only with used logits
130
- full_zeros = torch.zeros_like(logits)
131
- logits = logits[:, self.used, ...]
132
-
133
- soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
134
- if self.remap is not None:
135
- # go back to all entries but unused set to zero
136
- full_zeros[:, self.used, ...] = soft_one_hot
137
- soft_one_hot = full_zeros
138
- z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
139
-
140
- # + kl divergence to the prior loss
141
- qy = F.softmax(logits, dim=1)
142
- diff = (
143
- self.kl_weight
144
- * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
145
- )
146
- out_dict[self.loss_key] = diff
147
-
148
- ind = soft_one_hot.argmax(dim=1)
149
- out_dict["indices"] = ind
150
- if self.remap is not None:
151
- ind = self.remap_to_used(ind)
152
-
153
- if return_logits:
154
- out_dict["logits"] = logits
155
-
156
- return z_q, out_dict
157
-
158
- def get_codebook_entry(self, indices, shape):
159
- # TODO: shape not yet optional
160
- b, h, w, c = shape
161
- assert b * h * w == indices.shape[0]
162
- indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
163
- if self.remap is not None:
164
- indices = self.unmap_to_all(indices)
165
- one_hot = (
166
- F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
167
- )
168
- z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
169
- return z_q
170
-
171
-
172
- class VectorQuantizer(AbstractQuantizer):
173
- """
174
- ____________________________________________
175
- Discretization bottleneck part of the VQ-VAE.
176
- Inputs:
177
- - n_e : number of embeddings
178
- - e_dim : dimension of embedding
179
- - beta : commitment cost used in loss term,
180
- beta * ||z_e(x)-sg[e]||^2
181
- _____________________________________________
182
- """
183
-
184
- def __init__(
185
- self,
186
- n_e: int,
187
- e_dim: int,
188
- beta: float = 0.25,
189
- remap: Optional[str] = None,
190
- unknown_index: str = "random",
191
- sane_index_shape: bool = False,
192
- log_perplexity: bool = False,
193
- embedding_weight_norm: bool = False,
194
- loss_key: str = "loss/vq",
195
- ):
196
- super().__init__()
197
- self.n_e = n_e
198
- self.e_dim = e_dim
199
- self.beta = beta
200
- self.loss_key = loss_key
201
-
202
- if not embedding_weight_norm:
203
- self.embedding = nn.Embedding(self.n_e, self.e_dim)
204
- self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
205
- else:
206
- self.embedding = torch.nn.utils.weight_norm(
207
- nn.Embedding(self.n_e, self.e_dim), dim=1
208
- )
209
-
210
- self.remap = remap
211
- if self.remap is not None:
212
- self.register_buffer("used", torch.tensor(np.load(self.remap)))
213
- self.re_embed = self.used.shape[0]
214
- else:
215
- self.used = None
216
- self.re_embed = n_e
217
- if unknown_index == "extra":
218
- self.unknown_index = self.re_embed
219
- self.re_embed = self.re_embed + 1
220
- else:
221
- assert unknown_index == "random" or isinstance(
222
- unknown_index, int
223
- ), "unknown index needs to be 'random', 'extra' or any integer"
224
- self.unknown_index = unknown_index # "random" or "extra" or integer
225
- if self.remap is not None:
226
- logpy.info(
227
- f"Remapping {self.n_e} indices to {self.re_embed} indices. "
228
- f"Using {self.unknown_index} for unknown indices."
229
- )
230
-
231
- self.sane_index_shape = sane_index_shape
232
- self.log_perplexity = log_perplexity
233
-
234
- def forward(
235
- self,
236
- z: torch.Tensor,
237
- ) -> Tuple[torch.Tensor, Dict]:
238
- do_reshape = z.ndim == 4
239
- if do_reshape:
240
- # # reshape z -> (batch, height, width, channel) and flatten
241
- z = rearrange(z, "b c h w -> b h w c").contiguous()
242
-
243
- else:
244
- assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
245
- z = z.contiguous()
246
-
247
- z_flattened = z.view(-1, self.e_dim)
248
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
249
-
250
- d = (
251
- torch.sum(z_flattened**2, dim=1, keepdim=True)
252
- + torch.sum(self.embedding.weight**2, dim=1)
253
- - 2
254
- * torch.einsum(
255
- "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
256
- )
257
- )
258
-
259
- min_encoding_indices = torch.argmin(d, dim=1)
260
- z_q = self.embedding(min_encoding_indices).view(z.shape)
261
- loss_dict = {}
262
- if self.log_perplexity:
263
- perplexity, cluster_usage = measure_perplexity(
264
- min_encoding_indices.detach(), self.n_e
265
- )
266
- loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})
267
-
268
- # compute loss for embedding
269
- loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
270
- (z_q - z.detach()) ** 2
271
- )
272
- loss_dict[self.loss_key] = loss
273
-
274
- # preserve gradients
275
- z_q = z + (z_q - z).detach()
276
-
277
- # reshape back to match original input shape
278
- if do_reshape:
279
- z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
280
-
281
- if self.remap is not None:
282
- min_encoding_indices = min_encoding_indices.reshape(
283
- z.shape[0], -1
284
- ) # add batch axis
285
- min_encoding_indices = self.remap_to_used(min_encoding_indices)
286
- min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
287
-
288
- if self.sane_index_shape:
289
- if do_reshape:
290
- min_encoding_indices = min_encoding_indices.reshape(
291
- z_q.shape[0], z_q.shape[2], z_q.shape[3]
292
- )
293
- else:
294
- min_encoding_indices = rearrange(
295
- min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
296
- )
297
-
298
- loss_dict["min_encoding_indices"] = min_encoding_indices
299
-
300
- return z_q, loss_dict
301
-
302
- def get_codebook_entry(
303
- self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
304
- ) -> torch.Tensor:
305
- # shape specifying (batch, height, width, channel)
306
- if self.remap is not None:
307
- assert shape is not None, "Need to give shape for remap"
308
- indices = indices.reshape(shape[0], -1) # add batch axis
309
- indices = self.unmap_to_all(indices)
310
- indices = indices.reshape(-1) # flatten again
311
-
312
- # get quantized latent vectors
313
- z_q = self.embedding(indices)
314
-
315
- if shape is not None:
316
- z_q = z_q.view(shape)
317
- # reshape back to match original input shape
318
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
319
-
320
- return z_q
321
-
322
-
323
- class EmbeddingEMA(nn.Module):
324
- def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
325
- super().__init__()
326
- self.decay = decay
327
- self.eps = eps
328
- weight = torch.randn(num_tokens, codebook_dim)
329
- self.weight = nn.Parameter(weight, requires_grad=False)
330
- self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
331
- self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
332
- self.update = True
333
-
334
- def forward(self, embed_id):
335
- return F.embedding(embed_id, self.weight)
336
-
337
- def cluster_size_ema_update(self, new_cluster_size):
338
- self.cluster_size.data.mul_(self.decay).add_(
339
- new_cluster_size, alpha=1 - self.decay
340
- )
341
-
342
- def embed_avg_ema_update(self, new_embed_avg):
343
- self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
344
-
345
- def weight_update(self, num_tokens):
346
- n = self.cluster_size.sum()
347
- smoothed_cluster_size = (
348
- (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
349
- )
350
- # normalize embedding average with smoothed cluster size
351
- embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
352
- self.weight.data.copy_(embed_normalized)
353
-
354
-
355
- class EMAVectorQuantizer(AbstractQuantizer):
356
- def __init__(
357
- self,
358
- n_embed: int,
359
- embedding_dim: int,
360
- beta: float,
361
- decay: float = 0.99,
362
- eps: float = 1e-5,
363
- remap: Optional[str] = None,
364
- unknown_index: str = "random",
365
- loss_key: str = "loss/vq",
366
- ):
367
- super().__init__()
368
- self.codebook_dim = embedding_dim
369
- self.num_tokens = n_embed
370
- self.beta = beta
371
- self.loss_key = loss_key
372
-
373
- self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
374
-
375
- self.remap = remap
376
- if self.remap is not None:
377
- self.register_buffer("used", torch.tensor(np.load(self.remap)))
378
- self.re_embed = self.used.shape[0]
379
- else:
380
- self.used = None
381
- self.re_embed = n_embed
382
- if unknown_index == "extra":
383
- self.unknown_index = self.re_embed
384
- self.re_embed = self.re_embed + 1
385
- else:
386
- assert unknown_index == "random" or isinstance(
387
- unknown_index, int
388
- ), "unknown index needs to be 'random', 'extra' or any integer"
389
- self.unknown_index = unknown_index # "random" or "extra" or integer
390
- if self.remap is not None:
391
- logpy.info(
392
- f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
393
- f"Using {self.unknown_index} for unknown indices."
394
- )
395
-
396
- def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
397
- # reshape z -> (batch, height, width, channel) and flatten
398
- # z, 'b c h w -> b h w c'
399
- z = rearrange(z, "b c h w -> b h w c")
400
- z_flattened = z.reshape(-1, self.codebook_dim)
401
-
402
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
403
- d = (
404
- z_flattened.pow(2).sum(dim=1, keepdim=True)
405
- + self.embedding.weight.pow(2).sum(dim=1)
406
- - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
407
- ) # 'n d -> d n'
408
-
409
- encoding_indices = torch.argmin(d, dim=1)
410
-
411
- z_q = self.embedding(encoding_indices).view(z.shape)
412
- encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
413
- avg_probs = torch.mean(encodings, dim=0)
414
- perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
415
-
416
- if self.training and self.embedding.update:
417
- # EMA cluster size
418
- encodings_sum = encodings.sum(0)
419
- self.embedding.cluster_size_ema_update(encodings_sum)
420
- # EMA embedding average
421
- embed_sum = encodings.transpose(0, 1) @ z_flattened
422
- self.embedding.embed_avg_ema_update(embed_sum)
423
- # normalize embed_avg and update weight
424
- self.embedding.weight_update(self.num_tokens)
425
-
426
- # compute loss for embedding
427
- loss = self.beta * F.mse_loss(z_q.detach(), z)
428
-
429
- # preserve gradients
430
- z_q = z + (z_q - z).detach()
431
-
432
- # reshape back to match original input shape
433
- # z_q, 'b h w c -> b c h w'
434
- z_q = rearrange(z_q, "b h w c -> b c h w")
435
-
436
- out_dict = {
437
- self.loss_key: loss,
438
- "encodings": encodings,
439
- "encoding_indices": encoding_indices,
440
- "perplexity": perplexity,
441
- }
442
-
443
- return z_q, out_dict
444
-
445
-
446
- class VectorQuantizerWithInputProjection(VectorQuantizer):
447
- def __init__(
448
- self,
449
- input_dim: int,
450
- n_codes: int,
451
- codebook_dim: int,
452
- beta: float = 1.0,
453
- output_dim: Optional[int] = None,
454
- **kwargs,
455
- ):
456
- super().__init__(n_codes, codebook_dim, beta, **kwargs)
457
- self.proj_in = nn.Linear(input_dim, codebook_dim)
458
- self.output_dim = output_dim
459
- if output_dim is not None:
460
- self.proj_out = nn.Linear(codebook_dim, output_dim)
461
- else:
462
- self.proj_out = nn.Identity()
463
-
464
- def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
465
- rearr = False
466
- in_shape = z.shape
467
-
468
- if z.ndim > 3:
469
- rearr = self.output_dim is not None
470
- z = rearrange(z, "b c ... -> b (...) c")
471
- z = self.proj_in(z)
472
- z_q, loss_dict = super().forward(z)
473
-
474
- z_q = self.proj_out(z_q)
475
- if rearr:
476
- if len(in_shape) == 4:
477
- z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
478
- elif len(in_shape) == 5:
479
- z_q = rearrange(
480
- z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2]
481
- )
482
- else:
483
- raise NotImplementedError(
484
- f"rearranging not available for {len(in_shape)}-dimensional input."
485
- )
486
-
487
- return z_q, loss_dict
sgm/modules/diffusionmodules/__init__.py DELETED
File without changes
sgm/modules/diffusionmodules/denoiser.py DELETED
@@ -1,79 +0,0 @@
1
- import torch.nn as nn
2
- import torch
3
- from ...util import append_dims, instantiate_from_config
4
-
5
-
6
- class Denoiser(nn.Module):
7
- def __init__(self, weighting_config, scaling_config):
8
- super().__init__()
9
-
10
- self.weighting = instantiate_from_config(weighting_config)
11
- self.scaling = instantiate_from_config(scaling_config)
12
-
13
- def possibly_quantize_sigma(self, sigma):
14
- return sigma
15
-
16
- def possibly_quantize_c_noise(self, c_noise):
17
- return c_noise
18
-
19
- def w(self, sigma):
20
- return self.weighting(sigma)
21
-
22
- def __call__(self, network, input, sigma, cond, sigmas_ref=None, **kwargs):
23
- sigma = self.possibly_quantize_sigma(sigma)
24
- sigma_shape = sigma.shape
25
- sigma = append_dims(sigma, input.ndim)
26
- if sigmas_ref is not None:
27
- if kwargs is not None:
28
- kwargs['sigmas_ref'] = sigmas_ref
29
- else:
30
- kwargs = {'sigmas_ref': sigmas_ref}
31
-
32
- if kwargs['input_ref'] is not None:
33
- noise = torch.randn_like(kwargs['input_ref'])
34
- kwargs['input_ref'] = kwargs['input_ref'] + noise * append_dims(sigmas_ref, kwargs['input_ref'].ndim)
35
-
36
- if 'input_ref' in kwargs and kwargs['input_ref'] is not None and 'sigmas_ref' in kwargs:
37
- _, _, c_in_ref, c_noise_ref = self.scaling(append_dims(kwargs['sigmas_ref'], kwargs['input_ref'].ndim))
38
- kwargs['input_ref'] = kwargs['input_ref']*c_in_ref
39
- kwargs['sigmas_ref'] = self.possibly_quantize_c_noise(kwargs['sigmas_ref'])
40
-
41
- c_skip, c_out, c_in, c_noise = self.scaling(sigma)
42
- c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
43
- predict, fg_mask_list, alphas_list, rgb_list = network(input * c_in, c_noise, cond, **kwargs)
44
- return predict * c_out + input * c_skip, fg_mask_list, alphas_list, rgb_list
45
-
46
-
47
- class DiscreteDenoiser(Denoiser):
48
- def __init__(
49
- self,
50
- weighting_config,
51
- scaling_config,
52
- num_idx,
53
- discretization_config,
54
- do_append_zero=False,
55
- quantize_c_noise=True,
56
- flip=True,
57
- ):
58
- super().__init__(weighting_config, scaling_config)
59
- sigmas = instantiate_from_config(discretization_config)(
60
- num_idx, do_append_zero=do_append_zero, flip=flip
61
- )
62
- self.register_buffer("sigmas", sigmas)
63
- self.quantize_c_noise = quantize_c_noise
64
-
65
- def sigma_to_idx(self, sigma):
66
- dists = sigma - self.sigmas[:, None]
67
- return dists.abs().argmin(dim=0).view(sigma.shape)
68
-
69
- def idx_to_sigma(self, idx):
70
- return self.sigmas[idx]
71
-
72
- def possibly_quantize_sigma(self, sigma):
73
- return self.idx_to_sigma(self.sigma_to_idx(sigma))
74
-
75
- def possibly_quantize_c_noise(self, c_noise):
76
- if self.quantize_c_noise:
77
- return self.sigma_to_idx(c_noise)
78
- else:
79
- return c_noise
sgm/modules/diffusionmodules/denoiser_scaling.py DELETED
@@ -1,41 +0,0 @@
1
- import torch
2
- from abc import ABC, abstractmethod
3
- from typing import Tuple
4
-
5
-
6
- class DenoiserScaling(ABC):
7
- @abstractmethod
8
- def __call__(
9
- self, sigma: torch.Tensor
10
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
11
- pass
12
-
13
-
14
- class EDMScaling:
15
- def __init__(self, sigma_data=0.5):
16
- self.sigma_data = sigma_data
17
-
18
- def __call__(self, sigma):
19
- c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
20
- c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
21
- c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
22
- c_noise = 0.25 * sigma.log()
23
- return c_skip, c_out, c_in, c_noise
24
-
25
-
26
- class EpsScaling:
27
- def __call__(self, sigma):
28
- c_skip = torch.ones_like(sigma, device=sigma.device)
29
- c_out = -sigma
30
- c_in = 1 / (sigma**2 + 1.0) ** 0.5
31
- c_noise = sigma.clone()
32
- return c_skip, c_out, c_in, c_noise
33
-
34
-
35
- class VScaling:
36
- def __call__(self, sigma):
37
- c_skip = 1.0 / (sigma**2 + 1.0)
38
- c_out = -sigma / (sigma**2 + 1.0) ** 0.5
39
- c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
40
- c_noise = sigma.clone()
41
- return c_skip, c_out, c_in, c_noise
sgm/modules/diffusionmodules/denoiser_weighting.py DELETED
@@ -1,24 +0,0 @@
1
- import torch
2
-
3
-
4
- class UnitWeighting:
5
- def __call__(self, sigma):
6
- return torch.ones_like(sigma, device=sigma.device)
7
-
8
-
9
- class EDMWeighting:
10
- def __init__(self, sigma_data=0.5):
11
- self.sigma_data = sigma_data
12
-
13
- def __call__(self, sigma):
14
- return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
15
-
16
-
17
- class VWeighting(EDMWeighting):
18
- def __init__(self):
19
- super().__init__(sigma_data=1.0)
20
-
21
-
22
- class EpsWeighting:
23
- def __call__(self, sigma):
24
- return sigma**-2.0
sgm/modules/diffusionmodules/discretizer.py DELETED
@@ -1,69 +0,0 @@
1
- from abc import abstractmethod
2
- from functools import partial
3
-
4
- import numpy as np
5
- import torch
6
-
7
- from ...modules.diffusionmodules.util import make_beta_schedule
8
- from ...util import append_zero
9
-
10
-
11
- def generate_roughly_equally_spaced_steps(
12
- num_substeps: int, max_step: int
13
- ) -> np.ndarray:
14
- return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
15
-
16
-
17
- class Discretization:
18
- def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
19
- sigmas = self.get_sigmas(n, device=device)
20
- sigmas = append_zero(sigmas) if do_append_zero else sigmas
21
- return sigmas if not flip else torch.flip(sigmas, (0,))
22
-
23
- @abstractmethod
24
- def get_sigmas(self, n, device):
25
- pass
26
-
27
-
28
- class EDMDiscretization(Discretization):
29
- def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
30
- self.sigma_min = sigma_min
31
- self.sigma_max = sigma_max
32
- self.rho = rho
33
-
34
- def get_sigmas(self, n, device="cpu"):
35
- ramp = torch.linspace(0, 1, n, device=device)
36
- min_inv_rho = self.sigma_min ** (1 / self.rho)
37
- max_inv_rho = self.sigma_max ** (1 / self.rho)
38
- sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
39
- return sigmas
40
-
41
-
42
- class LegacyDDPMDiscretization(Discretization):
43
- def __init__(
44
- self,
45
- linear_start=0.00085,
46
- linear_end=0.0120,
47
- num_timesteps=1000,
48
- ):
49
- super().__init__()
50
- self.num_timesteps = num_timesteps
51
- betas = make_beta_schedule(
52
- "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
53
- )
54
- alphas = 1.0 - betas
55
- self.alphas_cumprod = np.cumprod(alphas, axis=0)
56
- self.to_torch = partial(torch.tensor, dtype=torch.float32)
57
-
58
- def get_sigmas(self, n, device="cpu"):
59
- if n < self.num_timesteps:
60
- timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
61
- alphas_cumprod = self.alphas_cumprod[timesteps]
62
- elif n == self.num_timesteps:
63
- alphas_cumprod = self.alphas_cumprod
64
- else:
65
- raise ValueError
66
-
67
- to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
68
- sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
69
- return torch.flip(sigmas, (0,))
sgm/modules/diffusionmodules/guiders.py DELETED
@@ -1,167 +0,0 @@
1
- import logging
2
- from abc import ABC, abstractmethod
3
- from typing import Dict, List, Literal, Optional, Tuple, Union
4
-
5
- import torch
6
- from einops import rearrange, repeat
7
-
8
- from ...util import append_dims, default
9
-
10
- logpy = logging.getLogger(__name__)
11
-
12
-
13
- class Guider(ABC):
14
- @abstractmethod
15
- def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
16
- pass
17
-
18
- def prepare_inputs(
19
- self, x: torch.Tensor, s: float, c: Dict, uc: Dict
20
- ) -> Tuple[torch.Tensor, float, Dict]:
21
- pass
22
-
23
-
24
- class VanillaCFG(Guider):
25
- def __init__(self, scale: float):
26
- self.scale = scale
27
-
28
-     def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
-         x_u, x_c = x.chunk(2)
-         x_pred = x_u + self.scale * (x_c - x_u)
-         return x_pred
-
-     def prepare_inputs(self, x, s, c, uc):
-         c_out = dict()
-
-         for k in c:
-             if k in ["vector", "crossattn", "concat"]:
-                 c_out[k] = torch.cat((uc[k], c[k]), 0)
-             else:
-                 assert c[k] == uc[k]
-                 c_out[k] = c[k]
-         return torch.cat([x] * 2), torch.cat([s] * 2), c_out
-
-
- class IdentityGuider(Guider):
-     def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
-         return x
-
-     def prepare_inputs(
-         self, x: torch.Tensor, s: float, c: Dict, uc: Dict
-     ) -> Tuple[torch.Tensor, float, Dict]:
-         c_out = dict()
-
-         for k in c:
-             c_out[k] = c[k]
-
-         return x, s, c_out
-
-
- class LinearPredictionGuider(Guider):
-     def __init__(
-         self,
-         max_scale: float,
-         num_frames: int,
-         min_scale: float = 1.0,
-         additional_cond_keys: Optional[Union[List[str], str]] = None,
-     ):
-         self.min_scale = min_scale
-         self.max_scale = max_scale
-         self.num_frames = num_frames
-         self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
-
-         additional_cond_keys = default(additional_cond_keys, [])
-         if isinstance(additional_cond_keys, str):
-             additional_cond_keys = [additional_cond_keys]
-         self.additional_cond_keys = additional_cond_keys
-
-     def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
-         x_u, x_c = x.chunk(2)
-
-         x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
-         x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
-         scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
-         scale = append_dims(scale, x_u.ndim).to(x_u.device)
-
-         return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
-
-     def prepare_inputs(
-         self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
-     ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
-         c_out = dict()
-
-         for k in c:
-             if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
-                 c_out[k] = torch.cat((uc[k], c[k]), 0)
-             else:
-                 assert c[k] == uc[k]
-                 c_out[k] = c[k]
-         return torch.cat([x] * 2), torch.cat([s] * 2), c_out
-
-
- class ScheduledCFGImgTextRef(Guider):
-     """
-     From InstructPix2Pix
-     """
-
-     def __init__(self, scale: float, scale_im: float):
-         self.scale = scale
-         self.scale_im = scale_im
-
-     def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
-         x_u, x_ic, x_c = x.chunk(3)
-         x_pred = x_u + self.scale * (x_c - x_ic) + self.scale_im * (x_ic - x_u)
-         return x_pred
-
-     def prepare_inputs(self, x, s, c, uc):
-         c_out = dict()
-
-         for k in c:
-             if k in ["vector", "crossattn", "concat"]:
-                 b = uc[k].shape[0]
-                 if k == "crossattn":
-                     uc1, uc2 = uc[k].split([x.size(0), b - x.size(0)])
-                     c1, c2 = c[k].split([x.size(0), b - x.size(0)])
-                     c_out[k] = torch.cat((uc1, uc1, c1, uc2, c2, c2), 0)
-                 else:
-                     uc1, uc2 = uc[k].split([x.size(0), b - x.size(0)])
-                     c1, c2 = c[k].split([x.size(0), b - x.size(0)])
-                     c_out[k] = torch.cat((uc1, uc1, c1, uc2, c2, c2), 0)
-             else:
-                 assert c[k] == uc[k]
-                 c_out[k] = c[k]
-         return torch.cat([x] * 3), torch.cat([s] * 3), c_out
-
-
- class VanillaCFGImgRef(Guider):
-     """
-     implements parallelized CFG
-     """
-
-     def __init__(self, scale: float):
-         self.scale = scale
-
-     def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
-         x_u, x_c = x.chunk(2)
-         x_pred = x_u + self.scale * (x_c - x_u)
-         return x_pred
-
-     def prepare_inputs(self, x, s, c, uc):
-         c_out = dict()
-
-         for k in c:
-             if k in ["vector", "crossattn", "concat"]:
-                 b = uc[k].shape[0]
-                 if k == "crossattn":
-                     uc1, uc2 = uc[k].split([x.size(0), b - x.size(0)])
-                     c1, c2 = c[k].split([x.size(0), b - x.size(0)])
-                     c_out[k] = torch.cat((uc1, c1, uc2, c2), 0)
-                 else:
-                     uc1, uc2 = uc[k].split([x.size(0), b - x.size(0)])
-                     c1, c2 = c[k].split([x.size(0), b - x.size(0)])
-                     c_out[k] = torch.cat((uc1, c1, uc2, c2), 0)
-             else:
-                 assert c[k] == uc[k]
-                 c_out[k] = c[k]
-         return torch.cat([x] * 2), torch.cat([s] * 2), c_out
-
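Note on the guiders deleted above: each one reduces to a linear blend of the unconditional and conditional denoiser outputs. The snippet below is a standalone, hedged sketch of that combination only; it is not code from this repo, and the tensor shapes and scale value are invented for the example.

import torch

def cfg_combine(x_uncond: torch.Tensor, x_cond: torch.Tensor, scale: float) -> torch.Tensor:
    # x_pred = x_u + scale * (x_c - x_u): push the unconditional prediction
    # toward the conditional one by a factor of `scale`.
    return x_uncond + scale * (x_cond - x_uncond)

# The guiders run the denoiser on a doubled batch (uncond half, cond half)
# and split the output back in two before blending.
denoiser_out = torch.randn(4, 4, 64, 64)   # stand-in for a denoiser output on a 2x batch
x_u, x_c = denoiser_out.chunk(2)
x_pred = cfg_combine(x_u, x_c, scale=7.5)  # scale value is illustrative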
sgm/modules/diffusionmodules/loss.py DELETED
@@ -1,216 +0,0 @@
1
- from typing import Dict, List, Optional, Tuple, Union
2
- import math
3
- import torch
4
- import torch.nn as nn
5
-
6
- from ...modules.autoencoding.lpips.loss.lpips import LPIPS
7
- from ...modules.encoders.modules import GeneralConditioner
8
- from ...util import append_dims, instantiate_from_config
9
- from .denoiser import Denoiser
10
-
11
-
12
- class StandardDiffusionLoss(nn.Module):
13
- def __init__(
14
- self,
15
- sigma_sampler_config: dict,
16
- loss_weighting_config: dict,
17
- loss_type: str = "l2",
18
- offset_noise_level: float = 0.0,
19
- batch2model_keys: Optional[Union[str, List[str]]] = None,
20
- ):
21
- super().__init__()
22
-
23
- assert loss_type in ["l2", "l1", "lpips"]
24
-
25
- self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
26
- self.loss_weighting = instantiate_from_config(loss_weighting_config)
27
-
28
- self.loss_type = loss_type
29
- self.offset_noise_level = offset_noise_level
30
-
31
- if loss_type == "lpips":
32
- self.lpips = LPIPS().eval()
33
-
34
- if not batch2model_keys:
35
- batch2model_keys = []
36
-
37
- if isinstance(batch2model_keys, str):
38
- batch2model_keys = [batch2model_keys]
39
-
40
- self.batch2model_keys = set(batch2model_keys)
41
-
42
- def get_noised_input(
43
- self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
44
- ) -> torch.Tensor:
45
- noised_input = input + noise * sigmas_bc
46
- return noised_input
47
-
48
- def forward(
49
- self,
50
- network: nn.Module,
51
- denoiser: Denoiser,
52
- conditioner: GeneralConditioner,
53
- input: torch.Tensor,
54
- batch: Dict,
55
- ) -> torch.Tensor:
56
- cond = conditioner(batch)
57
- return self._forward(network, denoiser, cond, input, batch)
58
-
59
- def _forward(
60
- self,
61
- network: nn.Module,
62
- denoiser: Denoiser,
63
- cond: Dict,
64
- input: torch.Tensor,
65
- batch: Dict,
66
- ) -> Tuple[torch.Tensor, Dict]:
67
- additional_model_inputs = {
68
- key: batch[key] for key in self.batch2model_keys.intersection(batch)
69
- }
70
- sigmas = self.sigma_sampler(input.shape[0]).to(input)
71
-
72
- noise = torch.randn_like(input)
73
- if self.offset_noise_level > 0.0:
74
- offset_shape = (
75
- (input.shape[0], 1, input.shape[2])
76
- if self.n_frames is not None
77
- else (input.shape[0], input.shape[1])
78
- )
79
- noise = noise + self.offset_noise_level * append_dims(
80
- torch.randn(offset_shape, device=input.device),
81
- input.ndim,
82
- )
83
- sigmas_bc = append_dims(sigmas, input.ndim)
84
- noised_input = self.get_noised_input(sigmas_bc, noise, input)
85
-
86
- model_output = denoiser(
87
- network, noised_input, sigmas, cond, **additional_model_inputs
88
- )
89
- w = append_dims(self.loss_weighting(sigmas), input.ndim)
90
- return self.get_loss(model_output, input, w)
91
-
92
- def get_loss(self, model_output, target, w):
93
- if self.loss_type == "l2":
94
- return torch.mean(
95
- (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
96
- )
97
- elif self.loss_type == "l1":
98
- return torch.mean(
99
- (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
100
- )
101
- elif self.loss_type == "lpips":
102
- loss = self.lpips(model_output, target).reshape(-1)
103
- return loss
104
- else:
105
- raise NotImplementedError(f"Unknown loss type {self.loss_type}")
106
-
107
-
108
- class StandardDiffusionLossImgRef(nn.Module):
109
- def __init__(
110
- self,
111
- sigma_sampler_config: dict,
112
- sigma_sampler_config_ref: dict,
113
- type: str = "l2",
114
- offset_noise_level: float = 0.0,
115
- batch2model_keys: Optional[Union[str, List[str]]] = None,
116
- ):
117
- super().__init__()
118
-
119
- assert type in ["l2", "l1", "lpips"]
120
-
121
- self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
122
- self.sigma_sampler_ref = None
123
- if sigma_sampler_config_ref is not None:
124
- self.sigma_sampler_ref = instantiate_from_config(sigma_sampler_config_ref)
125
-
126
- self.type = type
127
- self.offset_noise_level = offset_noise_level
128
-
129
- if type == "lpips":
130
- self.lpips = LPIPS().eval()
131
-
132
- if not batch2model_keys:
133
- batch2model_keys = []
134
-
135
- if isinstance(batch2model_keys, str):
136
- batch2model_keys = [batch2model_keys]
137
-
138
- self.batch2model_keys = set(batch2model_keys)
139
-
140
- def __call__(self, network, denoiser, conditioner, input, input_rgb, input_ref, pose, mask, mask_ref, opacity, batch):
141
- cond = conditioner(batch)
142
- additional_model_inputs = {
143
- key: batch[key] for key in self.batch2model_keys.intersection(batch)
144
- }
145
-
146
- sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
147
- noise = torch.randn_like(input)
148
- if self.offset_noise_level > 0.0:
149
- noise = noise + self.offset_noise_level * append_dims(
150
- torch.randn(input.shape[0], device=input.device), input.ndim
151
- )
152
-
153
- additional_model_inputs['pose'] = pose
154
- additional_model_inputs['mask_ref'] = mask_ref
155
-
156
- noised_input = input + noise * append_dims(sigmas, input.ndim)
157
- if self.sigma_sampler_ref is not None:
158
- sigmas_ref = self.sigma_sampler_ref(input.shape[0]).to(input.device)
159
- if input_ref is not None:
160
- noise = torch.randn_like(input_ref)
161
- if self.offset_noise_level > 0.0:
162
- noise = noise + self.offset_noise_level * append_dims(
163
- torch.randn(input_ref.shape[0], device=input_ref.device), input_ref.ndim
164
- )
165
- input_ref = input_ref + noise * append_dims(sigmas_ref, input_ref.ndim)
166
- additional_model_inputs['sigmas_ref'] = sigmas_ref
167
-
168
- additional_model_inputs['input_ref'] = input_ref
169
-
170
- model_output, fg_mask_list, alphas, predicted_rgb_list = denoiser(
171
- network, noised_input, sigmas, cond, **additional_model_inputs
172
- )
173
-
174
- w = append_dims(denoiser.w(sigmas), input.ndim)
175
- return self.get_loss(model_output, fg_mask_list, predicted_rgb_list, input, input_rgb, w, mask, mask_ref, opacity, alphas)
176
-
177
- def get_loss(self, model_output, fg_mask_list, predicted_rgb_list, target, target_rgb, w, mask, mask_ref, opacity, alphas_list):
178
- loss_rgb = []
179
- loss_fg = []
180
- loss_bg = []
181
- with torch.amp.autocast(device_type='cuda', dtype=torch.float32):
182
- if self.type == "l2":
183
- loss = (w * (model_output - target) ** 2)
184
- if mask is not None:
185
- loss_l2 = (loss*mask).sum([1, 2, 3])/(mask.sum([1, 2, 3]) + 1e-6)
186
- else:
187
- loss_l2 = torch.mean(loss.reshape(target.shape[0], -1), 1)
188
- if len(fg_mask_list) > 0 and len(alphas_list) > 0:
189
- for fg_mask, alphas in zip(fg_mask_list, alphas_list):
190
- size = int(math.sqrt(fg_mask.size(1)))
191
- opacity = torch.nn.functional.interpolate(opacity, size=size, antialias=True, mode='bilinear').detach()
192
- fg_mask = torch.clamp(fg_mask.reshape(-1, size*size), 0., 1.)
193
- loss_fg_ = ((fg_mask - opacity.reshape(-1, size*size))**2).mean(1) #torch.nn.functional.binary_cross_entropy(rgb, torch.clip(mask.reshape(-1, size*size), 0., 1.), reduce=False)
194
- loss_bg_ = (alphas - opacity.reshape(-1, size*size, 1, 1)).abs()*(1-opacity.reshape(-1, size*size, 1, 1)) #alpahs : b hw d 1
195
- loss_bg_ = (loss_bg_*((opacity.reshape(-1, size*size, 1, 1) < 0.1)*1)).mean([1, 2, 3])
196
- loss_fg.append(loss_fg_)
197
- loss_bg.append(loss_bg_)
198
- loss_fg = torch.stack(loss_fg, 1)
199
- loss_bg = torch.stack(loss_bg, 1)
200
-
201
- if len(predicted_rgb_list) > 0:
202
- for rgb in predicted_rgb_list:
203
- size = int(math.sqrt(rgb.size(1)))
204
- mask_ = torch.nn.functional.interpolate(mask, size=size, antialias=True, mode='bilinear').detach()
205
- loss_rgb_ = ((torch.nn.functional.interpolate(target_rgb*0.5+0.5, size=size, antialias=True, mode='bilinear').detach() - rgb.reshape(-1, size, size, 3).permute(0, 3, 1, 2)) ** 2)
206
- loss_rgb.append((loss_rgb_*mask_).sum([1, 2, 3])/(mask.sum([1, 2, 3]) + 1e-6))
207
- loss_rgb = torch.stack(loss_rgb, 1)
208
- # print(loss_l2, loss_fg, loss_bg, loss_rgb)
209
- return loss_l2, loss_fg, loss_bg, loss_rgb
210
- elif self.type == "l1":
211
- return torch.mean(
212
- (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
213
- ), loss_rgb
214
- elif self.type == "lpips":
215
- loss = self.lpips(model_output, target).reshape(-1)
216
- return loss, loss_rgb
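For orientation on the deleted loss module above: the core pattern is to sample a noise level, noise the clean latent, run the denoiser, and average a weighted squared error per sample. The sketch below is a hedged, self-contained illustration of only that pattern; weighted_l2, the shapes, and the dummy denoiser output are made up for the example and are not the removed code.

import torch

def weighted_l2(pred: torch.Tensor, target: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Per-sample mean of w * (pred - target)^2, matching the "l2" branch above.
    return torch.mean((w * (pred - target) ** 2).reshape(target.shape[0], -1), 1)

x = torch.randn(2, 4, 32, 32)                    # clean latents (illustrative shape)
sigmas = torch.rand(2).view(-1, 1, 1, 1) + 0.1   # broadcastable noise levels
noised = x + torch.randn_like(x) * sigmas        # noised_input = input + noise * sigma
pred = x + 0.1 * torch.randn_like(x)             # stand-in for a denoiser output
loss = weighted_l2(pred, x, w=torch.ones_like(sigmas))  # -> 2 per-sample loss values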
sgm/modules/diffusionmodules/loss_weighting.py DELETED
@@ -1,32 +0,0 @@
- from abc import ABC, abstractmethod
-
- import torch
-
-
- class DiffusionLossWeighting(ABC):
-     @abstractmethod
-     def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
-         pass
-
-
- class UnitWeighting(DiffusionLossWeighting):
-     def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
-         return torch.ones_like(sigma, device=sigma.device)
-
-
- class EDMWeighting(DiffusionLossWeighting):
-     def __init__(self, sigma_data: float = 0.5):
-         self.sigma_data = sigma_data
-
-     def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
-         return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
-
-
- class VWeighting(EDMWeighting):
-     def __init__(self):
-         super().__init__(sigma_data=1.0)
-
-
- class EpsWeighting(DiffusionLossWeighting):
-     def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
-         return sigma**-2.0
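For reference, the EDM weighting removed above is a simple closed-form function of sigma; the quick numerical check below is only an illustration (sigma values chosen arbitrarily), not part of this repo.

import torch

sigma_data = 0.5
sigma = torch.tensor([0.1, 1.0, 10.0])
w = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
# Heavy weight at small sigma (~104 here) and roughly 1 / sigma_data**2 = 4 at large sigma.
print(w)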
sgm/modules/diffusionmodules/model.py DELETED
@@ -1,748 +0,0 @@
1
- # pytorch_diffusion + derived encoder decoder
2
- import logging
3
- import math
4
- from typing import Any, Callable, Optional
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- from einops import rearrange
10
- from packaging import version
11
-
12
- logpy = logging.getLogger(__name__)
13
-
14
- try:
15
- import xformers
16
- import xformers.ops
17
-
18
- XFORMERS_IS_AVAILABLE = True
19
- except:
20
- XFORMERS_IS_AVAILABLE = False
21
- logpy.warning("no module 'xformers'. Processing without...")
22
-
23
- from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
24
-
25
-
26
- def get_timestep_embedding(timesteps, embedding_dim):
27
- """
28
- This matches the implementation in Denoising Diffusion Probabilistic Models:
29
- From Fairseq.
30
- Build sinusoidal embeddings.
31
- This matches the implementation in tensor2tensor, but differs slightly
32
- from the description in Section 3.5 of "Attention Is All You Need".
33
- """
34
- assert len(timesteps.shape) == 1
35
-
36
- half_dim = embedding_dim // 2
37
- emb = math.log(10000) / (half_dim - 1)
38
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
39
- emb = emb.to(device=timesteps.device)
40
- emb = timesteps.float()[:, None] * emb[None, :]
41
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
42
- if embedding_dim % 2 == 1: # zero pad
43
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
44
- return emb
45
-
46
-
47
- def nonlinearity(x):
48
- # swish
49
- return x * torch.sigmoid(x)
50
-
51
-
52
- def Normalize(in_channels, num_groups=32):
53
- return torch.nn.GroupNorm(
54
- num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
55
- )
56
-
57
-
58
- class Upsample(nn.Module):
59
- def __init__(self, in_channels, with_conv):
60
- super().__init__()
61
- self.with_conv = with_conv
62
- if self.with_conv:
63
- self.conv = torch.nn.Conv2d(
64
- in_channels, in_channels, kernel_size=3, stride=1, padding=1
65
- )
66
-
67
- def forward(self, x):
68
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
69
- if self.with_conv:
70
- x = self.conv(x)
71
- return x
72
-
73
-
74
- class Downsample(nn.Module):
75
- def __init__(self, in_channels, with_conv):
76
- super().__init__()
77
- self.with_conv = with_conv
78
- if self.with_conv:
79
- # no asymmetric padding in torch conv, must do it ourselves
80
- self.conv = torch.nn.Conv2d(
81
- in_channels, in_channels, kernel_size=3, stride=2, padding=0
82
- )
83
-
84
- def forward(self, x):
85
- if self.with_conv:
86
- pad = (0, 1, 0, 1)
87
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
88
- x = self.conv(x)
89
- else:
90
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
91
- return x
92
-
93
-
94
- class ResnetBlock(nn.Module):
95
- def __init__(
96
- self,
97
- *,
98
- in_channels,
99
- out_channels=None,
100
- conv_shortcut=False,
101
- dropout,
102
- temb_channels=512,
103
- ):
104
- super().__init__()
105
- self.in_channels = in_channels
106
- out_channels = in_channels if out_channels is None else out_channels
107
- self.out_channels = out_channels
108
- self.use_conv_shortcut = conv_shortcut
109
-
110
- self.norm1 = Normalize(in_channels)
111
- self.conv1 = torch.nn.Conv2d(
112
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
113
- )
114
- if temb_channels > 0:
115
- self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
116
- self.norm2 = Normalize(out_channels)
117
- self.dropout = torch.nn.Dropout(dropout)
118
- self.conv2 = torch.nn.Conv2d(
119
- out_channels, out_channels, kernel_size=3, stride=1, padding=1
120
- )
121
- if self.in_channels != self.out_channels:
122
- if self.use_conv_shortcut:
123
- self.conv_shortcut = torch.nn.Conv2d(
124
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
125
- )
126
- else:
127
- self.nin_shortcut = torch.nn.Conv2d(
128
- in_channels, out_channels, kernel_size=1, stride=1, padding=0
129
- )
130
-
131
- def forward(self, x, temb):
132
- h = x
133
- h = self.norm1(h)
134
- h = nonlinearity(h)
135
- h = self.conv1(h)
136
-
137
- if temb is not None:
138
- h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
139
-
140
- h = self.norm2(h)
141
- h = nonlinearity(h)
142
- h = self.dropout(h)
143
- h = self.conv2(h)
144
-
145
- if self.in_channels != self.out_channels:
146
- if self.use_conv_shortcut:
147
- x = self.conv_shortcut(x)
148
- else:
149
- x = self.nin_shortcut(x)
150
-
151
- return x + h
152
-
153
-
154
- class LinAttnBlock(LinearAttention):
155
- """to match AttnBlock usage"""
156
-
157
- def __init__(self, in_channels):
158
- super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
159
-
160
-
161
- class AttnBlock(nn.Module):
162
- def __init__(self, in_channels):
163
- super().__init__()
164
- self.in_channels = in_channels
165
-
166
- self.norm = Normalize(in_channels)
167
- self.q = torch.nn.Conv2d(
168
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
169
- )
170
- self.k = torch.nn.Conv2d(
171
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
172
- )
173
- self.v = torch.nn.Conv2d(
174
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
175
- )
176
- self.proj_out = torch.nn.Conv2d(
177
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
178
- )
179
-
180
- def attention(self, h_: torch.Tensor) -> torch.Tensor:
181
- h_ = self.norm(h_)
182
- q = self.q(h_)
183
- k = self.k(h_)
184
- v = self.v(h_)
185
-
186
- b, c, h, w = q.shape
187
- q, k, v = map(
188
- lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
189
- )
190
- h_ = torch.nn.functional.scaled_dot_product_attention(
191
- q, k, v
192
- ) # scale is dim ** -0.5 per default
193
- # compute attention
194
-
195
- return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
196
-
197
- def forward(self, x, **kwargs):
198
- h_ = x
199
- h_ = self.attention(h_)
200
- h_ = self.proj_out(h_)
201
- return x + h_
202
-
203
-
204
- class MemoryEfficientAttnBlock(nn.Module):
205
- """
206
- Uses xformers efficient implementation,
207
- see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
208
- Note: this is a single-head self-attention operation
209
- """
210
-
211
- #
212
- def __init__(self, in_channels):
213
- super().__init__()
214
- self.in_channels = in_channels
215
-
216
- self.norm = Normalize(in_channels)
217
- self.q = torch.nn.Conv2d(
218
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
219
- )
220
- self.k = torch.nn.Conv2d(
221
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
222
- )
223
- self.v = torch.nn.Conv2d(
224
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
225
- )
226
- self.proj_out = torch.nn.Conv2d(
227
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
228
- )
229
- self.attention_op: Optional[Any] = None
230
-
231
- def attention(self, h_: torch.Tensor) -> torch.Tensor:
232
- h_ = self.norm(h_)
233
- q = self.q(h_)
234
- k = self.k(h_)
235
- v = self.v(h_)
236
-
237
- # compute attention
238
- B, C, H, W = q.shape
239
- q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
240
-
241
- q, k, v = map(
242
- lambda t: t.unsqueeze(3)
243
- .reshape(B, t.shape[1], 1, C)
244
- .permute(0, 2, 1, 3)
245
- .reshape(B * 1, t.shape[1], C)
246
- .contiguous(),
247
- (q, k, v),
248
- )
249
- out = xformers.ops.memory_efficient_attention(
250
- q, k, v, attn_bias=None, op=self.attention_op
251
- )
252
-
253
- out = (
254
- out.unsqueeze(0)
255
- .reshape(B, 1, out.shape[1], C)
256
- .permute(0, 2, 1, 3)
257
- .reshape(B, out.shape[1], C)
258
- )
259
- return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
260
-
261
- def forward(self, x, **kwargs):
262
- h_ = x
263
- h_ = self.attention(h_)
264
- h_ = self.proj_out(h_)
265
- return x + h_
266
-
267
-
268
- class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
269
- def forward(self, x, context=None, mask=None, **unused_kwargs):
270
- b, c, h, w = x.shape
271
- x = rearrange(x, "b c h w -> b (h w) c")
272
- out = super().forward(x, context=context, mask=mask)
273
- out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
274
- return x + out
275
-
276
-
277
- def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
278
- assert attn_type in [
279
- "vanilla",
280
- "vanilla-xformers",
281
- "memory-efficient-cross-attn",
282
- "linear",
283
- "none",
284
- ], f"attn_type {attn_type} unknown"
285
- if (
286
- version.parse(torch.__version__) < version.parse("2.0.0")
287
- and attn_type != "none"
288
- ):
289
- assert XFORMERS_IS_AVAILABLE, (
290
- f"We do not support vanilla attention in {torch.__version__} anymore, "
291
- f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
292
- )
293
- attn_type = "vanilla-xformers"
294
- logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
295
- if attn_type == "vanilla":
296
- assert attn_kwargs is None
297
- return AttnBlock(in_channels)
298
- elif attn_type == "vanilla-xformers":
299
- logpy.info(
300
- f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
301
- )
302
- return MemoryEfficientAttnBlock(in_channels)
303
- elif type == "memory-efficient-cross-attn":
304
- attn_kwargs["query_dim"] = in_channels
305
- return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
306
- elif attn_type == "none":
307
- return nn.Identity(in_channels)
308
- else:
309
- return LinAttnBlock(in_channels)
310
-
311
-
312
- class Model(nn.Module):
313
- def __init__(
314
- self,
315
- *,
316
- ch,
317
- out_ch,
318
- ch_mult=(1, 2, 4, 8),
319
- num_res_blocks,
320
- attn_resolutions,
321
- dropout=0.0,
322
- resamp_with_conv=True,
323
- in_channels,
324
- resolution,
325
- use_timestep=True,
326
- use_linear_attn=False,
327
- attn_type="vanilla",
328
- ):
329
- super().__init__()
330
- if use_linear_attn:
331
- attn_type = "linear"
332
- self.ch = ch
333
- self.temb_ch = self.ch * 4
334
- self.num_resolutions = len(ch_mult)
335
- self.num_res_blocks = num_res_blocks
336
- self.resolution = resolution
337
- self.in_channels = in_channels
338
-
339
- self.use_timestep = use_timestep
340
- if self.use_timestep:
341
- # timestep embedding
342
- self.temb = nn.Module()
343
- self.temb.dense = nn.ModuleList(
344
- [
345
- torch.nn.Linear(self.ch, self.temb_ch),
346
- torch.nn.Linear(self.temb_ch, self.temb_ch),
347
- ]
348
- )
349
-
350
- # downsampling
351
- self.conv_in = torch.nn.Conv2d(
352
- in_channels, self.ch, kernel_size=3, stride=1, padding=1
353
- )
354
-
355
- curr_res = resolution
356
- in_ch_mult = (1,) + tuple(ch_mult)
357
- self.down = nn.ModuleList()
358
- for i_level in range(self.num_resolutions):
359
- block = nn.ModuleList()
360
- attn = nn.ModuleList()
361
- block_in = ch * in_ch_mult[i_level]
362
- block_out = ch * ch_mult[i_level]
363
- for i_block in range(self.num_res_blocks):
364
- block.append(
365
- ResnetBlock(
366
- in_channels=block_in,
367
- out_channels=block_out,
368
- temb_channels=self.temb_ch,
369
- dropout=dropout,
370
- )
371
- )
372
- block_in = block_out
373
- if curr_res in attn_resolutions:
374
- attn.append(make_attn(block_in, attn_type=attn_type))
375
- down = nn.Module()
376
- down.block = block
377
- down.attn = attn
378
- if i_level != self.num_resolutions - 1:
379
- down.downsample = Downsample(block_in, resamp_with_conv)
380
- curr_res = curr_res // 2
381
- self.down.append(down)
382
-
383
- # middle
384
- self.mid = nn.Module()
385
- self.mid.block_1 = ResnetBlock(
386
- in_channels=block_in,
387
- out_channels=block_in,
388
- temb_channels=self.temb_ch,
389
- dropout=dropout,
390
- )
391
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
392
- self.mid.block_2 = ResnetBlock(
393
- in_channels=block_in,
394
- out_channels=block_in,
395
- temb_channels=self.temb_ch,
396
- dropout=dropout,
397
- )
398
-
399
- # upsampling
400
- self.up = nn.ModuleList()
401
- for i_level in reversed(range(self.num_resolutions)):
402
- block = nn.ModuleList()
403
- attn = nn.ModuleList()
404
- block_out = ch * ch_mult[i_level]
405
- skip_in = ch * ch_mult[i_level]
406
- for i_block in range(self.num_res_blocks + 1):
407
- if i_block == self.num_res_blocks:
408
- skip_in = ch * in_ch_mult[i_level]
409
- block.append(
410
- ResnetBlock(
411
- in_channels=block_in + skip_in,
412
- out_channels=block_out,
413
- temb_channels=self.temb_ch,
414
- dropout=dropout,
415
- )
416
- )
417
- block_in = block_out
418
- if curr_res in attn_resolutions:
419
- attn.append(make_attn(block_in, attn_type=attn_type))
420
- up = nn.Module()
421
- up.block = block
422
- up.attn = attn
423
- if i_level != 0:
424
- up.upsample = Upsample(block_in, resamp_with_conv)
425
- curr_res = curr_res * 2
426
- self.up.insert(0, up) # prepend to get consistent order
427
-
428
- # end
429
- self.norm_out = Normalize(block_in)
430
- self.conv_out = torch.nn.Conv2d(
431
- block_in, out_ch, kernel_size=3, stride=1, padding=1
432
- )
433
-
434
- def forward(self, x, t=None, context=None):
435
- # assert x.shape[2] == x.shape[3] == self.resolution
436
- if context is not None:
437
- # assume aligned context, cat along channel axis
438
- x = torch.cat((x, context), dim=1)
439
- if self.use_timestep:
440
- # timestep embedding
441
- assert t is not None
442
- temb = get_timestep_embedding(t, self.ch)
443
- temb = self.temb.dense[0](temb)
444
- temb = nonlinearity(temb)
445
- temb = self.temb.dense[1](temb)
446
- else:
447
- temb = None
448
-
449
- # downsampling
450
- hs = [self.conv_in(x)]
451
- for i_level in range(self.num_resolutions):
452
- for i_block in range(self.num_res_blocks):
453
- h = self.down[i_level].block[i_block](hs[-1], temb)
454
- if len(self.down[i_level].attn) > 0:
455
- h = self.down[i_level].attn[i_block](h)
456
- hs.append(h)
457
- if i_level != self.num_resolutions - 1:
458
- hs.append(self.down[i_level].downsample(hs[-1]))
459
-
460
- # middle
461
- h = hs[-1]
462
- h = self.mid.block_1(h, temb)
463
- h = self.mid.attn_1(h)
464
- h = self.mid.block_2(h, temb)
465
-
466
- # upsampling
467
- for i_level in reversed(range(self.num_resolutions)):
468
- for i_block in range(self.num_res_blocks + 1):
469
- h = self.up[i_level].block[i_block](
470
- torch.cat([h, hs.pop()], dim=1), temb
471
- )
472
- if len(self.up[i_level].attn) > 0:
473
- h = self.up[i_level].attn[i_block](h)
474
- if i_level != 0:
475
- h = self.up[i_level].upsample(h)
476
-
477
- # end
478
- h = self.norm_out(h)
479
- h = nonlinearity(h)
480
- h = self.conv_out(h)
481
- return h
482
-
483
- def get_last_layer(self):
484
- return self.conv_out.weight
485
-
486
-
487
- class Encoder(nn.Module):
488
- def __init__(
489
- self,
490
- *,
491
- ch,
492
- out_ch,
493
- ch_mult=(1, 2, 4, 8),
494
- num_res_blocks,
495
- attn_resolutions,
496
- dropout=0.0,
497
- resamp_with_conv=True,
498
- in_channels,
499
- resolution,
500
- z_channels,
501
- double_z=True,
502
- use_linear_attn=False,
503
- attn_type="vanilla",
504
- **ignore_kwargs,
505
- ):
506
- super().__init__()
507
- if use_linear_attn:
508
- attn_type = "linear"
509
- self.ch = ch
510
- self.temb_ch = 0
511
- self.num_resolutions = len(ch_mult)
512
- self.num_res_blocks = num_res_blocks
513
- self.resolution = resolution
514
- self.in_channels = in_channels
515
-
516
- # downsampling
517
- self.conv_in = torch.nn.Conv2d(
518
- in_channels, self.ch, kernel_size=3, stride=1, padding=1
519
- )
520
-
521
- curr_res = resolution
522
- in_ch_mult = (1,) + tuple(ch_mult)
523
- self.in_ch_mult = in_ch_mult
524
- self.down = nn.ModuleList()
525
- for i_level in range(self.num_resolutions):
526
- block = nn.ModuleList()
527
- attn = nn.ModuleList()
528
- block_in = ch * in_ch_mult[i_level]
529
- block_out = ch * ch_mult[i_level]
530
- for i_block in range(self.num_res_blocks):
531
- block.append(
532
- ResnetBlock(
533
- in_channels=block_in,
534
- out_channels=block_out,
535
- temb_channels=self.temb_ch,
536
- dropout=dropout,
537
- )
538
- )
539
- block_in = block_out
540
- if curr_res in attn_resolutions:
541
- attn.append(make_attn(block_in, attn_type=attn_type))
542
- down = nn.Module()
543
- down.block = block
544
- down.attn = attn
545
- if i_level != self.num_resolutions - 1:
546
- down.downsample = Downsample(block_in, resamp_with_conv)
547
- curr_res = curr_res // 2
548
- self.down.append(down)
549
-
550
- # middle
551
- self.mid = nn.Module()
552
- self.mid.block_1 = ResnetBlock(
553
- in_channels=block_in,
554
- out_channels=block_in,
555
- temb_channels=self.temb_ch,
556
- dropout=dropout,
557
- )
558
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
559
- self.mid.block_2 = ResnetBlock(
560
- in_channels=block_in,
561
- out_channels=block_in,
562
- temb_channels=self.temb_ch,
563
- dropout=dropout,
564
- )
565
-
566
- # end
567
- self.norm_out = Normalize(block_in)
568
- self.conv_out = torch.nn.Conv2d(
569
- block_in,
570
- 2 * z_channels if double_z else z_channels,
571
- kernel_size=3,
572
- stride=1,
573
- padding=1,
574
- )
575
-
576
- def forward(self, x):
577
- # timestep embedding
578
- temb = None
579
-
580
- # downsampling
581
- hs = [self.conv_in(x)]
582
- for i_level in range(self.num_resolutions):
583
- for i_block in range(self.num_res_blocks):
584
- h = self.down[i_level].block[i_block](hs[-1], temb)
585
- if len(self.down[i_level].attn) > 0:
586
- h = self.down[i_level].attn[i_block](h)
587
- hs.append(h)
588
- if i_level != self.num_resolutions - 1:
589
- hs.append(self.down[i_level].downsample(hs[-1]))
590
-
591
- # middle
592
- h = hs[-1]
593
- h = self.mid.block_1(h, temb)
594
- h = self.mid.attn_1(h)
595
- h = self.mid.block_2(h, temb)
596
-
597
- # end
598
- h = self.norm_out(h)
599
- h = nonlinearity(h)
600
- h = self.conv_out(h)
601
- return h
602
-
603
-
604
- class Decoder(nn.Module):
605
- def __init__(
606
- self,
607
- *,
608
- ch,
609
- out_ch,
610
- ch_mult=(1, 2, 4, 8),
611
- num_res_blocks,
612
- attn_resolutions,
613
- dropout=0.0,
614
- resamp_with_conv=True,
615
- in_channels,
616
- resolution,
617
- z_channels,
618
- give_pre_end=False,
619
- tanh_out=False,
620
- use_linear_attn=False,
621
- attn_type="vanilla",
622
- **ignorekwargs,
623
- ):
624
- super().__init__()
625
- if use_linear_attn:
626
- attn_type = "linear"
627
- self.ch = ch
628
- self.temb_ch = 0
629
- self.num_resolutions = len(ch_mult)
630
- self.num_res_blocks = num_res_blocks
631
- self.resolution = resolution
632
- self.in_channels = in_channels
633
- self.give_pre_end = give_pre_end
634
- self.tanh_out = tanh_out
635
-
636
- # compute in_ch_mult, block_in and curr_res at lowest res
637
- in_ch_mult = (1,) + tuple(ch_mult)
638
- block_in = ch * ch_mult[self.num_resolutions - 1]
639
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
640
- self.z_shape = (1, z_channels, curr_res, curr_res)
641
- logpy.info(
642
- "Working with z of shape {} = {} dimensions.".format(
643
- self.z_shape, np.prod(self.z_shape)
644
- )
645
- )
646
-
647
- make_attn_cls = self._make_attn()
648
- make_resblock_cls = self._make_resblock()
649
- make_conv_cls = self._make_conv()
650
- # z to block_in
651
- self.conv_in = torch.nn.Conv2d(
652
- z_channels, block_in, kernel_size=3, stride=1, padding=1
653
- )
654
-
655
- # middle
656
- self.mid = nn.Module()
657
- self.mid.block_1 = make_resblock_cls(
658
- in_channels=block_in,
659
- out_channels=block_in,
660
- temb_channels=self.temb_ch,
661
- dropout=dropout,
662
- )
663
- self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
664
- self.mid.block_2 = make_resblock_cls(
665
- in_channels=block_in,
666
- out_channels=block_in,
667
- temb_channels=self.temb_ch,
668
- dropout=dropout,
669
- )
670
-
671
- # upsampling
672
- self.up = nn.ModuleList()
673
- for i_level in reversed(range(self.num_resolutions)):
674
- block = nn.ModuleList()
675
- attn = nn.ModuleList()
676
- block_out = ch * ch_mult[i_level]
677
- for i_block in range(self.num_res_blocks + 1):
678
- block.append(
679
- make_resblock_cls(
680
- in_channels=block_in,
681
- out_channels=block_out,
682
- temb_channels=self.temb_ch,
683
- dropout=dropout,
684
- )
685
- )
686
- block_in = block_out
687
- if curr_res in attn_resolutions:
688
- attn.append(make_attn_cls(block_in, attn_type=attn_type))
689
- up = nn.Module()
690
- up.block = block
691
- up.attn = attn
692
- if i_level != 0:
693
- up.upsample = Upsample(block_in, resamp_with_conv)
694
- curr_res = curr_res * 2
695
- self.up.insert(0, up) # prepend to get consistent order
696
-
697
- # end
698
- self.norm_out = Normalize(block_in)
699
- self.conv_out = make_conv_cls(
700
- block_in, out_ch, kernel_size=3, stride=1, padding=1
701
- )
702
-
703
- def _make_attn(self) -> Callable:
704
- return make_attn
705
-
706
- def _make_resblock(self) -> Callable:
707
- return ResnetBlock
708
-
709
- def _make_conv(self) -> Callable:
710
- return torch.nn.Conv2d
711
-
712
- def get_last_layer(self, **kwargs):
713
- return self.conv_out.weight
714
-
715
- def forward(self, z, **kwargs):
716
- # assert z.shape[1:] == self.z_shape[1:]
717
- self.last_z_shape = z.shape
718
-
719
- # timestep embedding
720
- temb = None
721
-
722
- # z to block_in
723
- h = self.conv_in(z)
724
-
725
- # middle
726
- h = self.mid.block_1(h, temb, **kwargs)
727
- h = self.mid.attn_1(h, **kwargs)
728
- h = self.mid.block_2(h, temb, **kwargs)
729
-
730
- # upsampling
731
- for i_level in reversed(range(self.num_resolutions)):
732
- for i_block in range(self.num_res_blocks + 1):
733
- h = self.up[i_level].block[i_block](h, temb, **kwargs)
734
- if len(self.up[i_level].attn) > 0:
735
- h = self.up[i_level].attn[i_block](h, **kwargs)
736
- if i_level != 0:
737
- h = self.up[i_level].upsample(h)
738
-
739
- # end
740
- if self.give_pre_end:
741
- return h
742
-
743
- h = self.norm_out(h)
744
- h = nonlinearity(h)
745
- h = self.conv_out(h, **kwargs)
746
- if self.tanh_out:
747
- h = torch.tanh(h)
748
- return h
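A small aside on the autoencoder backbone deleted above: each resolution level except the last halves the spatial grid, so the latent resolution is resolution // 2 ** (num_resolutions - 1). The check below is illustrative only; the config values are assumptions, not taken from this repo.

# Quick shape check (illustrative numbers) for the Encoder/Decoder above.
resolution = 256
ch_mult = (1, 2, 4, 4)                       # typical SD-style multiplier list (assumed)
num_resolutions = len(ch_mult)
latent_res = resolution // 2 ** (num_resolutions - 1)
print(latent_res)                            # -> 32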
sgm/modules/diffusionmodules/openaimodel.py DELETED
@@ -1,1352 +0,0 @@
1
- import logging
2
- import math
3
- from abc import abstractmethod
4
- from functools import partial
5
- from typing import Iterable, List, Optional, Tuple, Union
6
-
7
- import numpy as np
8
- import torch as th
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- from einops import rearrange
12
- import torch
13
- from torch.profiler import profile, record_function, ProfilerActivity
14
- from ...modules.attention import SpatialTransformer
15
- from ...modules.diffusionmodules.util import (
16
- avg_pool_nd,
17
- checkpoint,
18
- conv_nd,
19
- linear,
20
- normalization,
21
- timestep_embedding,
22
- zero_module,
23
- )
24
- from ...util import default, exists
25
-
26
-
27
- logpy = logging.getLogger(__name__)
28
-
29
-
30
- class AttentionPool2d(nn.Module):
31
- """
32
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
33
- """
34
-
35
- def __init__(
36
- self,
37
- spacial_dim: int,
38
- embed_dim: int,
39
- num_heads_channels: int,
40
- output_dim: int = None,
41
- ):
42
- super().__init__()
43
- self.positional_embedding = nn.Parameter(
44
- th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
45
- )
46
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
47
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
48
- self.num_heads = embed_dim // num_heads_channels
49
- self.attention = QKVAttention(self.num_heads)
50
-
51
- def forward(self, x):
52
- b, c, *_spatial = x.shape
53
- x = x.reshape(b, c, -1) # NC(HW)
54
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
55
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
56
- x = self.qkv_proj(x)
57
- x = self.attention(x)
58
- x = self.c_proj(x)
59
- return x[:, :, 0]
60
-
61
-
62
- class TimestepBlock(nn.Module):
63
- """
64
- Any module where forward() takes timestep embeddings as a second argument.
65
- """
66
-
67
- @abstractmethod
68
- def forward(self, x, emb):
69
- """
70
- Apply the module to `x` given `emb` timestep embeddings.
71
- """
72
-
73
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
74
- """
75
- A sequential module that passes timestep embeddings to the children that
76
- support it as an extra input.
77
- """
78
-
79
- def forward(
80
- self,
81
- x,
82
- emb,
83
- context=None,
84
- xr=None,
85
- embr=None,
86
- contextr=None,
87
- pose=None,
88
- mask_ref=None,
89
- prev_weights=None,
90
- ):
91
- weights = None
92
- fg_mask = None
93
- alphas = None
94
- predicted_rgb = None
95
- for layer in self:
96
- if isinstance(layer, TimestepBlock):
97
- x = layer(x, emb)
98
- if xr is not None:
99
- with torch.no_grad():
100
- xr = layer(xr, embr)
101
- xr = xr.detach()
102
- elif isinstance(layer, SpatialTransformer):
103
- x, xr, fg_mask, weights, alphas, predicted_rgb = layer(x, xr, context, contextr, pose, mask_ref, prev_weights=prev_weights)
104
- else:
105
- x = layer(x)
106
- if xr is not None:
107
- with torch.no_grad():
108
- xr = layer(xr)
109
- xr = xr.detach()
110
-
111
- return x, xr, fg_mask, weights, alphas, predicted_rgb
112
-
113
-
114
- class Upsample(nn.Module):
115
- """
116
- An upsampling layer with an optional convolution.
117
- :param channels: channels in the inputs and outputs.
118
- :param use_conv: a bool determining if a convolution is applied.
119
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
120
- upsampling occurs in the inner-two dimensions.
121
- """
122
-
123
- def __init__(
124
- self,
125
- channels: int,
126
- use_conv: bool,
127
- dims: int = 2,
128
- out_channels: Optional[int] = None,
129
- padding: int = 1,
130
- third_up: bool = False,
131
- kernel_size: int = 3,
132
- scale_factor: int = 2,
133
- ):
134
- super().__init__()
135
- self.channels = channels
136
- self.out_channels = out_channels or channels
137
- self.use_conv = use_conv
138
- self.dims = dims
139
- self.third_up = third_up
140
- self.scale_factor = scale_factor
141
- if use_conv:
142
- self.conv = conv_nd(
143
- dims, self.channels, self.out_channels, kernel_size, padding=padding
144
- )
145
-
146
- def forward(self, x: th.Tensor) -> th.Tensor:
147
- assert x.shape[1] == self.channels
148
-
149
- if self.dims == 3:
150
- t_factor = 1 if not self.third_up else self.scale_factor
151
- x = F.interpolate(
152
- x,
153
- (
154
- t_factor * x.shape[2],
155
- x.shape[3] * self.scale_factor,
156
- x.shape[4] * self.scale_factor,
157
- ),
158
- mode="nearest",
159
- )
160
- else:
161
- x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
162
- if self.use_conv:
163
- x = self.conv(x)
164
- return x
165
-
166
-
167
- class TransposedUpsample(nn.Module):
168
- "Learned 2x upsampling without padding"
169
-
170
- def __init__(self, channels, out_channels=None, ks=5):
171
- super().__init__()
172
- self.channels = channels
173
- self.out_channels = out_channels or channels
174
-
175
- self.up = nn.ConvTranspose2d(
176
- self.channels, self.out_channels, kernel_size=ks, stride=2
177
- )
178
-
179
- def forward(self, x):
180
- return self.up(x)
181
-
182
-
183
- class Downsample(nn.Module):
184
- """
185
- A downsampling layer with an optional convolution.
186
- :param channels: channels in the inputs and outputs.
187
- :param use_conv: a bool determining if a convolution is applied.
188
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
189
- downsampling occurs in the inner-two dimensions.
190
- """
191
-
192
- def __init__(
193
- self,
194
- channels: int,
195
- use_conv: bool,
196
- dims: int = 2,
197
- out_channels: Optional[int] = None,
198
- padding: int = 1,
199
- third_down: bool = False,
200
- ):
201
- super().__init__()
202
- self.channels = channels
203
- self.out_channels = out_channels or channels
204
- self.use_conv = use_conv
205
- self.dims = dims
206
- stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
207
- if use_conv:
208
- logpy.info(f"Building a Downsample layer with {dims} dims.")
209
- logpy.info(
210
- f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
211
- f"kernel-size: 3, stride: {stride}, padding: {padding}"
212
- )
213
- if dims == 3:
214
- logpy.info(f" --> Downsampling third axis (time): {third_down}")
215
- self.op = conv_nd(
216
- dims,
217
- self.channels,
218
- self.out_channels,
219
- 3,
220
- stride=stride,
221
- padding=padding,
222
- )
223
- else:
224
- assert self.channels == self.out_channels
225
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
226
-
227
- def forward(self, x: th.Tensor) -> th.Tensor:
228
- assert x.shape[1] == self.channels
229
-
230
- return self.op(x)
231
-
232
-
233
- class ResBlock(TimestepBlock):
234
- """
235
- A residual block that can optionally change the number of channels.
236
- :param channels: the number of input channels.
237
- :param emb_channels: the number of timestep embedding channels.
238
- :param dropout: the rate of dropout.
239
- :param out_channels: if specified, the number of out channels.
240
- :param use_conv: if True and out_channels is specified, use a spatial
241
- convolution instead of a smaller 1x1 convolution to change the
242
- channels in the skip connection.
243
- :param dims: determines if the signal is 1D, 2D, or 3D.
244
- :param use_checkpoint: if True, use gradient checkpointing on this module.
245
- :param up: if True, use this block for upsampling.
246
- :param down: if True, use this block for downsampling.
247
- """
248
-
249
- def __init__(
250
- self,
251
- channels: int,
252
- emb_channels: int,
253
- dropout: float,
254
- out_channels: Optional[int] = None,
255
- use_conv: bool = False,
256
- use_scale_shift_norm: bool = False,
257
- dims: int = 2,
258
- use_checkpoint: bool = False,
259
- up: bool = False,
260
- down: bool = False,
261
- kernel_size: int = 3,
262
- exchange_temb_dims: bool = False,
263
- skip_t_emb: bool = False,
264
- ):
265
- super().__init__()
266
- self.channels = channels
267
- self.emb_channels = emb_channels
268
- self.dropout = dropout
269
- self.out_channels = out_channels or channels
270
- self.use_conv = use_conv
271
- self.use_checkpoint = use_checkpoint
272
- self.use_scale_shift_norm = use_scale_shift_norm
273
- self.exchange_temb_dims = exchange_temb_dims
274
-
275
- if isinstance(kernel_size, Iterable):
276
- padding = [k // 2 for k in kernel_size]
277
- else:
278
- padding = kernel_size // 2
279
-
280
- self.in_layers = nn.Sequential(
281
- normalization(channels),
282
- nn.SiLU(),
283
- conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
284
- )
285
-
286
- self.updown = up or down
287
-
288
- if up:
289
- self.h_upd = Upsample(channels, False, dims)
290
- self.x_upd = Upsample(channels, False, dims)
291
- elif down:
292
- self.h_upd = Downsample(channels, False, dims)
293
- self.x_upd = Downsample(channels, False, dims)
294
- else:
295
- self.h_upd = self.x_upd = nn.Identity()
296
-
297
- self.skip_t_emb = skip_t_emb
298
- self.emb_out_channels = (
299
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels
300
- )
301
- if self.skip_t_emb:
302
- logpy.info(f"Skipping timestep embedding in {self.__class__.__name__}")
303
- assert not self.use_scale_shift_norm
304
- self.emb_layers = None
305
- self.exchange_temb_dims = False
306
- else:
307
- self.emb_layers = nn.Sequential(
308
- nn.SiLU(),
309
- linear(
310
- emb_channels,
311
- self.emb_out_channels,
312
- ),
313
- )
314
-
315
- self.out_layers = nn.Sequential(
316
- normalization(self.out_channels),
317
- nn.SiLU(),
318
- nn.Dropout(p=dropout),
319
- zero_module(
320
- conv_nd(
321
- dims,
322
- self.out_channels,
323
- self.out_channels,
324
- kernel_size,
325
- padding=padding,
326
- )
327
- ),
328
- )
329
-
330
- if self.out_channels == channels:
331
- self.skip_connection = nn.Identity()
332
- elif use_conv:
333
- self.skip_connection = conv_nd(
334
- dims, channels, self.out_channels, kernel_size, padding=padding
335
- )
336
- else:
337
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
338
-
339
- def forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:
340
- """
341
- Apply the block to a Tensor, conditioned on a timestep embedding.
342
- :param x: an [N x C x ...] Tensor of features.
343
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
344
- :return: an [N x C x ...] Tensor of outputs.
345
- """
346
- return checkpoint(
347
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
348
- )
349
-
350
- def _forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:
351
- if self.updown:
352
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
353
- h = in_rest(x)
354
- h = self.h_upd(h)
355
- x = self.x_upd(x)
356
- h = in_conv(h)
357
- else:
358
- h = self.in_layers(x)
359
-
360
- if self.skip_t_emb:
361
- emb_out = th.zeros_like(h)
362
- else:
363
- emb_out = self.emb_layers(emb).type(h.dtype)
364
- while len(emb_out.shape) < len(h.shape):
365
- emb_out = emb_out[..., None]
366
- if self.use_scale_shift_norm:
367
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
368
- scale, shift = th.chunk(emb_out, 2, dim=1)
369
- h = out_norm(h) * (1 + scale) + shift
370
- h = out_rest(h)
371
- else:
372
- if self.exchange_temb_dims:
373
- emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
374
- h = h + emb_out
375
- h = self.out_layers(h)
376
- return self.skip_connection(x) + h
377
-
378
-
379
- class AttentionBlock(nn.Module):
380
- """
381
- An attention block that allows spatial positions to attend to each other.
382
- Originally ported from here, but adapted to the N-d case.
383
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
384
- """
385
-
386
- def __init__(
387
- self,
388
- channels: int,
389
- num_heads: int = 1,
390
- num_head_channels: int = -1,
391
- use_checkpoint: bool = False,
392
- use_new_attention_order: bool = False,
393
- ):
394
- super().__init__()
395
- self.channels = channels
396
- if num_head_channels == -1:
397
- self.num_heads = num_heads
398
- else:
399
- assert (
400
- channels % num_head_channels == 0
401
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
402
- self.num_heads = channels // num_head_channels
403
- self.use_checkpoint = use_checkpoint
404
- self.norm = normalization(channels)
405
- self.qkv = conv_nd(1, channels, channels * 3, 1)
406
- if use_new_attention_order:
407
- # split qkv before split heads
408
- self.attention = QKVAttention(self.num_heads)
409
- else:
410
- # split heads before split qkv
411
- self.attention = QKVAttentionLegacy(self.num_heads)
412
-
413
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
414
-
415
- def forward(self, x, **kwargs):
416
- # TODO add crossframe attention and use mixed checkpoint
417
- return checkpoint(
418
- self._forward, (x,), self.parameters(), True
419
- ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
420
-
421
- def _forward(self, x: th.Tensor) -> th.Tensor:
422
- b, c, *spatial = x.shape
423
- x = x.reshape(b, c, -1)
424
- qkv = self.qkv(self.norm(x))
425
- h = self.attention(qkv)
426
- h = self.proj_out(h)
427
- return (x + h).reshape(b, c, *spatial)
428
-
429
-
430
- def count_flops_attn(model, _x, y):
431
- """
432
- A counter for the `thop` package to count the operations in an
433
- attention operation.
434
- Meant to be used like:
435
- macs, params = thop.profile(
436
- model,
437
- inputs=(inputs, timestamps),
438
- custom_ops={QKVAttention: QKVAttention.count_flops},
439
- )
440
- """
441
- b, c, *spatial = y[0].shape
442
- num_spatial = int(np.prod(spatial))
443
- # We perform two matmuls with the same number of ops.
444
- # The first computes the weight matrix, the second computes
445
- # the combination of the value vectors.
446
- matmul_ops = 2 * b * (num_spatial**2) * c
447
- model.total_ops += th.DoubleTensor([matmul_ops])
448
-
449
-
450
- class QKVAttentionLegacy(nn.Module):
451
- """
452
- A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
453
- """
454
-
455
- def __init__(self, n_heads: int):
456
- super().__init__()
457
- self.n_heads = n_heads
458
-
459
- def forward(self, qkv: th.Tensor) -> th.Tensor:
460
- """
461
- Apply QKV attention.
462
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
463
- :return: an [N x (H * C) x T] tensor after attention.
464
- """
465
- bs, width, length = qkv.shape
466
- assert width % (3 * self.n_heads) == 0
467
- ch = width // (3 * self.n_heads)
468
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
469
- scale = 1 / math.sqrt(math.sqrt(ch))
470
- weight = th.einsum(
471
- "bct,bcs->bts", q * scale, k * scale
472
- ) # More stable with f16 than dividing afterwards
473
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
474
- a = th.einsum("bts,bcs->bct", weight, v)
475
- return a.reshape(bs, -1, length)
476
-
477
- @staticmethod
478
- def count_flops(model, _x, y):
479
- return count_flops_attn(model, _x, y)
480
-
481
-
482
- class QKVAttention(nn.Module):
483
- """
484
- A module which performs QKV attention and splits in a different order.
485
- """
486
-
487
- def __init__(self, n_heads: int):
488
- super().__init__()
489
- self.n_heads = n_heads
490
-
491
- def forward(self, qkv: th.Tensor) -> th.Tensor:
492
- """
493
- Apply QKV attention.
494
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
495
- :return: an [N x (H * C) x T] tensor after attention.
496
- """
497
- bs, width, length = qkv.shape
498
- assert width % (3 * self.n_heads) == 0
499
- ch = width // (3 * self.n_heads)
500
- q, k, v = qkv.chunk(3, dim=1)
501
- scale = 1 / math.sqrt(math.sqrt(ch))
502
- weight = th.einsum(
503
- "bct,bcs->bts",
504
- (q * scale).view(bs * self.n_heads, ch, length),
505
- (k * scale).view(bs * self.n_heads, ch, length),
506
- ) # More stable with f16 than dividing afterwards
507
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
508
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
509
- return a.reshape(bs, -1, length)
510
-
511
- @staticmethod
512
- def count_flops(model, _x, y):
513
- return count_flops_attn(model, _x, y)
514
-
515
-
516
- class Timestep(nn.Module):
517
- def __init__(self, dim: int):
518
- super().__init__()
519
- self.dim = dim
520
-
521
- def forward(self, t: th.Tensor) -> th.Tensor:
522
- return timestep_embedding(t, self.dim)
523
-
524
-
525
- class UNetModel(nn.Module):
526
- """
527
- The full UNet model with attention and timestep embedding.
528
- :param in_channels: channels in the input Tensor.
529
- :param model_channels: base channel count for the model.
530
- :param out_channels: channels in the output Tensor.
531
- :param num_res_blocks: number of residual blocks per downsample.
532
- :param attention_resolutions: a collection of downsample rates at which
533
- attention will take place. May be a set, list, or tuple.
534
- For example, if this contains 4, then at 4x downsampling, attention
535
- will be used.
536
- :param dropout: the dropout probability.
537
- :param channel_mult: channel multiplier for each level of the UNet.
538
- :param conv_resample: if True, use learned convolutions for upsampling and
539
- downsampling.
540
- :param dims: determines if the signal is 1D, 2D, or 3D.
541
- :param num_classes: if specified (as an int), then this model will be
542
- class-conditional with `num_classes` classes.
543
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
544
- :param num_heads: the number of attention heads in each attention layer.
545
- :param num_heads_channels: if specified, ignore num_heads and instead use
546
- a fixed channel width per attention head.
547
- :param num_heads_upsample: works with num_heads to set a different number
548
- of heads for upsampling. Deprecated.
549
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
550
- :param resblock_updown: use residual blocks for up/downsampling.
551
- :param use_new_attention_order: use a different attention pattern for potentially
552
- increased efficiency.
553
- """
554
-
555
- def __init__(
556
- self,
557
- in_channels: int,
558
- model_channels: int,
559
- out_channels: int,
560
- num_res_blocks: int,
561
- attention_resolutions: int,
562
- dropout: float = 0.0,
563
- channel_mult: Union[List, Tuple] = (1, 2, 4, 8),
564
- conv_resample: bool = True,
565
- dims: int = 2,
566
- num_classes: Optional[Union[int, str]] = None,
567
- use_checkpoint: bool = False,
568
- num_heads: int = -1,
569
- num_head_channels: int = -1,
570
- num_heads_upsample: int = -1,
571
- use_scale_shift_norm: bool = False,
572
- resblock_updown: bool = False,
573
- transformer_depth: int = 1,
574
- context_dim: Optional[int] = None,
575
- disable_self_attentions: Optional[List[bool]] = None,
576
- num_attention_blocks: Optional[List[int]] = None,
577
- disable_middle_self_attn: bool = False,
578
- use_linear_in_transformer: bool = False,
579
- spatial_transformer_attn_type: str = "softmax",
580
- adm_in_channels: Optional[int] = None,
581
- use_fairscale_checkpoint=False,
582
- offload_to_cpu=False,
583
- transformer_depth_middle: Optional[int] = None,
584
- ## new args
585
- image_cross_blocks: Union[List, Tuple] = None,
586
- rgb: bool = False,
587
- far: float = 2.,
588
- num_samples: float = 32,
589
- not_add_context_in_triplane: bool = False,
590
- rgb_predict: bool = False,
591
- add_lora: bool = False,
592
- mode: str = 'feature-nerf',
593
- average: bool = False,
594
- num_freqs: int = 16,
595
- use_prev_weights_imp_sample: bool = False,
596
- stratified: bool = False,
597
- poscontrol_interval: int = 4,
598
- imp_sampling_percent: float = 0.9,
599
- near_plane: float = 0.
600
- ):
601
- super().__init__()
602
- from omegaconf.listconfig import ListConfig
603
-
604
- if num_heads_upsample == -1:
605
- num_heads_upsample = num_heads
606
-
607
- if num_heads == -1:
608
- assert (
609
- num_head_channels != -1
610
- ), "Either num_heads or num_head_channels has to be set"
611
-
612
- if num_head_channels == -1:
613
- assert (
614
- num_heads != -1
615
- ), "Either num_heads or num_head_channels has to be set"
616
-
617
- self.in_channels = in_channels
618
- self.model_channels = model_channels
619
- self.out_channels = out_channels
620
- self.rgb = rgb
621
- self.rgb_predict = rgb_predict
622
- if image_cross_blocks is None:
623
- image_cross_blocks = []
624
- if isinstance(transformer_depth, int):
625
- transformer_depth = len(channel_mult) * [transformer_depth]
626
- elif isinstance(transformer_depth, ListConfig):
627
- transformer_depth = list(transformer_depth)
628
- transformer_depth_middle = default(
629
- transformer_depth_middle, transformer_depth[-1]
630
- )
631
-
632
- if isinstance(num_res_blocks, int):
633
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
634
- else:
635
- if len(num_res_blocks) != len(channel_mult):
636
- raise ValueError(
637
- "provide num_res_blocks either as an int (globally constant) or "
638
- "as a list/tuple (per-level) with the same length as channel_mult"
639
- )
640
- self.num_res_blocks = num_res_blocks
641
- if disable_self_attentions is not None:
642
- assert len(disable_self_attentions) == len(channel_mult)
643
- if num_attention_blocks is not None:
644
- assert len(num_attention_blocks) == len(self.num_res_blocks)
645
- assert all(
646
- map(
647
- lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
648
- range(len(num_attention_blocks)),
649
- )
650
- )
651
- logpy.info(
652
- f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
653
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
654
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
655
- f"attention will still not be set."
656
- )
657
-
658
- self.attention_resolutions = attention_resolutions
659
- self.dropout = dropout
660
- self.channel_mult = channel_mult
661
- self.conv_resample = conv_resample
662
- self.num_classes = num_classes
663
- self.use_checkpoint = use_checkpoint
664
- self.num_heads = num_heads
665
- self.num_head_channels = num_head_channels
666
- self.num_heads_upsample = num_heads_upsample
667
-
668
- assert use_fairscale_checkpoint != use_checkpoint or not (
669
- use_checkpoint or use_fairscale_checkpoint
670
- )
671
-
672
- self.use_fairscale_checkpoint = False
673
- checkpoint_wrapper_fn = (
674
- partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
675
- if self.use_fairscale_checkpoint
676
- else lambda x: x
677
- )
678
-
679
- time_embed_dim = model_channels * 4
680
- self.time_embed = checkpoint_wrapper_fn(
681
- nn.Sequential(
682
- linear(model_channels, time_embed_dim),
683
- nn.SiLU(),
684
- linear(time_embed_dim, time_embed_dim),
685
- )
686
- )
687
-
688
- if self.num_classes is not None:
689
- if isinstance(self.num_classes, int):
690
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
691
- elif self.num_classes == "continuous":
692
- logpy.info("setting up linear c_adm embedding layer")
693
- self.label_emb = nn.Linear(1, time_embed_dim)
694
- elif self.num_classes == "timestep":
695
- self.label_emb = checkpoint_wrapper_fn(
696
- nn.Sequential(
697
- Timestep(model_channels),
698
- nn.Sequential(
699
- linear(model_channels, time_embed_dim),
700
- nn.SiLU(),
701
- linear(time_embed_dim, time_embed_dim),
702
- ),
703
- )
704
- )
705
- elif self.num_classes == "sequential":
706
- assert adm_in_channels is not None
707
- self.label_emb = nn.Sequential(
708
- nn.Sequential(
709
- linear(adm_in_channels, time_embed_dim),
710
- nn.SiLU(),
711
- linear(time_embed_dim, time_embed_dim),
712
- )
713
- )
714
- else:
715
- raise ValueError
716
-
717
- self.input_blocks = nn.ModuleList(
718
- [
719
- TimestepEmbedSequential(
720
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
721
- )
722
- ]
723
- )
724
- self._feature_size = model_channels
725
- input_block_chans = [model_channels]
726
- ch = model_channels
727
- ds = 1
728
- id_attention = 0
729
- for level, mult in enumerate(channel_mult):
730
- for nr in range(self.num_res_blocks[level]):
731
- layers = [
732
- checkpoint_wrapper_fn(
733
- ResBlock(
734
- ch,
735
- time_embed_dim,
736
- dropout,
737
- out_channels=mult * model_channels,
738
- dims=dims,
739
- use_checkpoint=use_checkpoint,
740
- use_scale_shift_norm=use_scale_shift_norm,
741
- )
742
- )
743
- ]
744
- ch = mult * model_channels
745
- if ds in attention_resolutions:
746
- if num_head_channels == -1:
747
- dim_head = ch // num_heads
748
- else:
749
- num_heads = ch // num_head_channels
750
- dim_head = num_head_channels
751
-
752
- if context_dim is not None and exists(disable_self_attentions):
753
- disabled_sa = disable_self_attentions[level]
754
- else:
755
- disabled_sa = False
756
-
757
- if (
758
- not exists(num_attention_blocks)
759
- or nr < num_attention_blocks[level]
760
- ):
761
- layers.append(
762
- checkpoint_wrapper_fn(
763
- SpatialTransformer(
764
- ch,
765
- num_heads,
766
- dim_head,
767
- depth=transformer_depth[level],
768
- context_dim=context_dim,
769
- disable_self_attn=disabled_sa,
770
- use_linear=use_linear_in_transformer,
771
- attn_type=spatial_transformer_attn_type,
772
- use_checkpoint=use_checkpoint,
773
- # image_cross=False,
774
- image_cross=(id_attention in image_cross_blocks),
775
- rgb_predict=self.rgb_predict,
776
- far=far,
777
- num_samples=num_samples,
778
- add_lora=add_lora,
779
- mode=mode,
780
- average=average,
781
- num_freqs=num_freqs,
782
- use_prev_weights_imp_sample=use_prev_weights_imp_sample,
783
- stratified=stratified,
784
- poscontrol_interval=poscontrol_interval,
785
- imp_sampling_percent=imp_sampling_percent,
786
- near_plane=near_plane,
787
- )
788
- )
789
- )
790
- print("({}) in Encoder".format(id_attention))
791
- id_attention += 1
792
- self.input_blocks.append(TimestepEmbedSequential(*layers))
793
- self._feature_size += ch
794
- input_block_chans.append(ch)
795
- if level != len(channel_mult) - 1:
796
- out_ch = ch
797
- self.input_blocks.append(
798
- TimestepEmbedSequential(
799
- checkpoint_wrapper_fn(
800
- ResBlock(
801
- ch,
802
- time_embed_dim,
803
- dropout,
804
- out_channels=out_ch,
805
- dims=dims,
806
- use_checkpoint=use_checkpoint,
807
- use_scale_shift_norm=use_scale_shift_norm,
808
- down=True,
809
- )
810
- )
811
- if resblock_updown
812
- else Downsample(
813
- ch, conv_resample, dims=dims, out_channels=out_ch
814
- )
815
- )
816
- )
817
- ch = out_ch
818
- input_block_chans.append(ch)
819
- ds *= 2
820
- self._feature_size += ch
821
-
822
- if num_head_channels == -1:
823
- dim_head = ch // num_heads
824
- else:
825
- num_heads = ch // num_head_channels
826
- dim_head = num_head_channels
827
- self.middle_block = TimestepEmbedSequential(
828
- checkpoint_wrapper_fn(
829
- ResBlock(
830
- ch,
831
- time_embed_dim,
832
- dropout,
833
- dims=dims,
834
- use_checkpoint=use_checkpoint,
835
- use_scale_shift_norm=use_scale_shift_norm,
836
- )
837
- ),
838
- checkpoint_wrapper_fn(
839
- SpatialTransformer( # always uses a self-attn
840
- ch,
841
- num_heads,
842
- dim_head,
843
- depth=transformer_depth_middle,
844
- context_dim=context_dim,
845
- disable_self_attn=disable_middle_self_attn,
846
- use_linear=use_linear_in_transformer,
847
- attn_type=spatial_transformer_attn_type,
848
- use_checkpoint=use_checkpoint,
849
- image_cross=(id_attention in image_cross_blocks),
850
- rgb_predict=self.rgb_predict,
851
- far=far,
852
- num_samples=num_samples,
853
- add_lora=add_lora,
854
- mode=mode,
855
- average=average,
856
- num_freqs=num_freqs,
857
- use_prev_weights_imp_sample=use_prev_weights_imp_sample,
858
- stratified=stratified,
859
- poscontrol_interval=poscontrol_interval,
860
- imp_sampling_percent=imp_sampling_percent,
861
- near_plane=near_plane,
862
- )
863
- ),
864
- checkpoint_wrapper_fn(
865
- ResBlock(
866
- ch,
867
- time_embed_dim,
868
- dropout,
869
- dims=dims,
870
- use_checkpoint=use_checkpoint,
871
- use_scale_shift_norm=use_scale_shift_norm,
872
- )
873
- ),
874
- )
875
-
876
- print("({}) in Middle".format(id_attention))
877
- id_attention += 1
878
-
879
- self._feature_size += ch
880
-
881
- self.output_blocks = nn.ModuleList([])
882
- for level, mult in list(enumerate(channel_mult))[::-1]:
883
- for i in range(self.num_res_blocks[level] + 1):
884
- ich = input_block_chans.pop()
885
- layers = [
886
- checkpoint_wrapper_fn(
887
- ResBlock(
888
- ch + ich,
889
- time_embed_dim,
890
- dropout,
891
- out_channels=model_channels * mult,
892
- dims=dims,
893
- use_checkpoint=use_checkpoint,
894
- use_scale_shift_norm=use_scale_shift_norm,
895
- )
896
- )
897
- ]
898
- ch = model_channels * mult
899
- if ds in attention_resolutions:
900
- if num_head_channels == -1:
901
- dim_head = ch // num_heads
902
- else:
903
- num_heads = ch // num_head_channels
904
- dim_head = num_head_channels
905
-
906
- if exists(disable_self_attentions):
907
- disabled_sa = disable_self_attentions[level]
908
- else:
909
- disabled_sa = False
910
-
911
- if (
912
- not exists(num_attention_blocks)
913
- or i < num_attention_blocks[level]
914
- ):
915
- layers.append(
916
- checkpoint_wrapper_fn(
917
- SpatialTransformer(
918
- ch,
919
- num_heads,
920
- dim_head,
921
- depth=transformer_depth[level],
922
- context_dim=context_dim,
923
- disable_self_attn=disabled_sa,
924
- use_linear=use_linear_in_transformer,
925
- attn_type=spatial_transformer_attn_type,
926
- use_checkpoint=use_checkpoint,
927
- image_cross=(id_attention in image_cross_blocks),
928
- rgb_predict=self.rgb_predict,
929
- far=far,
930
- num_samples=num_samples,
931
- add_lora=add_lora,
932
- mode=mode,
933
- average=average,
934
- num_freqs=num_freqs,
935
- use_prev_weights_imp_sample=use_prev_weights_imp_sample,
936
- stratified=stratified,
937
- poscontrol_interval=poscontrol_interval,
938
- imp_sampling_percent=imp_sampling_percent,
939
- near_plane=near_plane,
940
- )
941
- )
942
- )
943
- print("({}) in Decoder".format(id_attention))
944
- id_attention += 1
945
- if level and i == self.num_res_blocks[level]:
946
- out_ch = ch
947
- layers.append(
948
- checkpoint_wrapper_fn(
949
- ResBlock(
950
- ch,
951
- time_embed_dim,
952
- dropout,
953
- out_channels=out_ch,
954
- dims=dims,
955
- use_checkpoint=use_checkpoint,
956
- use_scale_shift_norm=use_scale_shift_norm,
957
- up=True,
958
- )
959
- )
960
- if resblock_updown
961
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
962
- )
963
- ds //= 2
964
- self.output_blocks.append(TimestepEmbedSequential(*layers))
965
- self._feature_size += ch
966
-
967
- self.out = checkpoint_wrapper_fn(
968
- nn.Sequential(
969
- normalization(ch),
970
- nn.SiLU(),
971
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
972
- )
973
- )
974
-
975
- def forward(
976
- self,
977
- x: th.Tensor,
978
- timesteps: Optional[th.Tensor] = None,
979
- context: Optional[th.Tensor] = None,
980
- y: Optional[th.Tensor] = None,
981
- timesteps2: Optional[th.Tensor] = None,
982
- **kwargs,
983
- ):
984
- """
985
- Apply the model to an input batch.
986
- :param x: an [N x C x ...] Tensor of inputs.
987
- :param timesteps: a 1-D batch of timesteps.
988
- :param context: conditioning plugged in via crossattn
989
- :param y: an [N] Tensor of labels, if class-conditional.
990
- :return: an [N x C x ...] Tensor of outputs.
991
- """
992
- with torch.amp.autocast(device_type='cuda', dtype=torch.float32 if self.training else torch.float16):
993
- b = x.size(0)
994
- contextr = None
995
- reference_image = False
996
- pose = None
997
- mask_ref = None
998
- embr = None
999
- fg_mask_list = []
1000
- use_img_cond = True
1001
- alphas_list = []
1002
- predicted_rgb_list = []
1003
-
1004
- if 'pose' in kwargs:
1005
- pose = kwargs['pose']
1006
- if 'mask_ref' in kwargs:
1007
- mask_ref = kwargs['mask_ref']
1008
- if 'input_ref' in kwargs:
1009
- reference_image = True
1010
- contextr = context[b:]
1011
- if y is not None:
1012
- yr = y[b:]
1013
- xr = kwargs['input_ref']
1014
- if xr is not None:
1015
- b, n = xr.shape[:2]
1016
-
1017
- context = context[: b]
1018
-
1019
- if y is not None:
1020
- y = y[:b]
1021
-
1022
- assert (y is not None) == (
1023
- self.num_classes is not None
1024
- ), "must specify y if and only if the model is class-conditional"
1025
-
1026
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
1027
- emb = self.time_embed(t_emb)
1028
-
1029
- if self.num_classes is not None:
1030
- assert y.shape[0] == x.shape[0]
1031
- emb = emb + self.label_emb(y)
1032
-
1033
- h = x
1034
- hr = None
1035
- hs = []
1036
- hrs = []
1037
- use_img_cond = True
1038
-
1039
- if reference_image:
1040
- with torch.no_grad():
1041
- if 'sigmas_ref' in kwargs:
1042
- t_embr = timestep_embedding(kwargs['sigmas_ref'], self.model_channels, repeat_only=False)
1043
- elif timesteps2 is not None:
1044
- t_embr = timestep_embedding(timesteps2, self.model_channels, repeat_only=False)
1045
- else:
1046
- t_embr = timestep_embedding(torch.zeros_like(timesteps), self.model_channels, repeat_only=False)
1047
- embr = (self.time_embed(t_embr)[:, None].expand(-1, xr.size(1), -1)).reshape(b*n, -1)
1048
- if self.num_classes is not None:
1049
- embr = embr + self.label_emb(yr.reshape(b*n, -1))
1050
- hr = rearrange(xr, "b n ... -> (b n) ...", b=b, n=n)
1051
- hr = hr.to(memory_format=torch.channels_last)
1052
-
1053
- for module in self.input_blocks:
1054
- h, hr, fg_mask, weights, alphas, predicted_rgb = module(h, emb, context, hr, embr, contextr, pose, mask_ref=mask_ref, prev_weights=None)
1055
- if fg_mask is not None:
1056
- fg_mask_list += fg_mask
1057
- if alphas is not None:
1058
- alphas_list += alphas
1059
- if predicted_rgb is not None:
1060
- predicted_rgb_list.extend(predicted_rgb)
1061
- hs.append(h)
1062
- hrs.append(hr)
1063
-
1064
- h, hr, fg_mask, weights, alphas, predicted_rgb = self.middle_block(h, emb, context, hr, embr, contextr, pose, mask_ref=mask_ref, prev_weights=None)
1065
-
1066
- if fg_mask is not None:
1067
- fg_mask_list += fg_mask
1068
- if alphas is not None:
1069
- alphas_list += alphas
1070
- if predicted_rgb is not None:
1071
- predicted_rgb_list.extend(predicted_rgb)
1072
-
1073
- for module in self.output_blocks:
1074
- h = th.cat([h, hs.pop()], dim=1)
1075
- if reference_image:
1076
- hr = th.cat([hr, hrs.pop()], dim=1)
1077
- h, hr, fg_mask, weights, alphas, predicted_rgb = module(h, emb, context, hr, embr, contextr, pose, mask_ref=mask_ref, prev_weights=None)
1078
- if fg_mask is not None:
1079
- fg_mask_list += fg_mask
1080
- if alphas is not None:
1081
- alphas_list += alphas
1082
- if predicted_rgb is not None:
1083
- predicted_rgb_list.extend(predicted_rgb)
1084
-
1085
- h = h.type(x.dtype)
1086
- if reference_image:
1087
- hr = hr.type(xr.dtype)
1088
- out = self.out(h)
1089
-
1090
- if use_img_cond:
1091
- return out, fg_mask_list, alphas_list, predicted_rgb_list
1092
- else:
1093
- return out
1094
-
1095
-
1096
- class NoTimeUNetModel(UNetModel):
1097
- def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
1098
- timesteps = th.zeros_like(timesteps)
1099
- return super().forward(x, timesteps, context, y, **kwargs)
1100
-
1101
-
1102
- class EncoderUNetModel(nn.Module):
1103
- """
1104
- The half UNet model with attention and timestep embedding.
1105
- For usage, see UNet.
1106
- """
1107
-
1108
- def __init__(
1109
- self,
1110
- image_size,
1111
- in_channels,
1112
- model_channels,
1113
- out_channels,
1114
- num_res_blocks,
1115
- attention_resolutions,
1116
- dropout=0,
1117
- channel_mult=(1, 2, 4, 8),
1118
- conv_resample=True,
1119
- dims=2,
1120
- use_checkpoint=False,
1121
- use_fp16=False,
1122
- num_heads=1,
1123
- num_head_channels=-1,
1124
- num_heads_upsample=-1,
1125
- use_scale_shift_norm=False,
1126
- resblock_updown=False,
1127
- use_new_attention_order=False,
1128
- pool="adaptive",
1129
- *args,
1130
- **kwargs,
1131
- ):
1132
- super().__init__()
1133
-
1134
- if num_heads_upsample == -1:
1135
- num_heads_upsample = num_heads
1136
-
1137
- self.in_channels = in_channels
1138
- self.model_channels = model_channels
1139
- self.out_channels = out_channels
1140
- self.num_res_blocks = num_res_blocks
1141
- self.attention_resolutions = attention_resolutions
1142
- self.dropout = dropout
1143
- self.channel_mult = channel_mult
1144
- self.conv_resample = conv_resample
1145
- self.use_checkpoint = use_checkpoint
1146
- self.dtype = th.float16 if use_fp16 else th.float32
1147
- self.num_heads = num_heads
1148
- self.num_head_channels = num_head_channels
1149
- self.num_heads_upsample = num_heads_upsample
1150
-
1151
- time_embed_dim = model_channels * 4
1152
- self.time_embed = nn.Sequential(
1153
- linear(model_channels, time_embed_dim),
1154
- nn.SiLU(),
1155
- linear(time_embed_dim, time_embed_dim),
1156
- )
1157
-
1158
- self.input_blocks = nn.ModuleList(
1159
- [
1160
- TimestepEmbedSequential(
1161
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
1162
- )
1163
- ]
1164
- )
1165
- self._feature_size = model_channels
1166
- input_block_chans = [model_channels]
1167
- ch = model_channels
1168
- ds = 1
1169
- for level, mult in enumerate(channel_mult):
1170
- for _ in range(num_res_blocks):
1171
- layers = [
1172
- ResBlock(
1173
- ch,
1174
- time_embed_dim,
1175
- dropout,
1176
- out_channels=mult * model_channels,
1177
- dims=dims,
1178
- use_checkpoint=use_checkpoint,
1179
- use_scale_shift_norm=use_scale_shift_norm,
1180
- )
1181
- ]
1182
- ch = mult * model_channels
1183
- if ds in attention_resolutions:
1184
- layers.append(
1185
- AttentionBlock(
1186
- ch,
1187
- use_checkpoint=use_checkpoint,
1188
- num_heads=num_heads,
1189
- num_head_channels=num_head_channels,
1190
- use_new_attention_order=use_new_attention_order,
1191
- )
1192
- )
1193
- self.input_blocks.append(TimestepEmbedSequential(*layers))
1194
- self._feature_size += ch
1195
- input_block_chans.append(ch)
1196
- if level != len(channel_mult) - 1:
1197
- out_ch = ch
1198
- self.input_blocks.append(
1199
- TimestepEmbedSequential(
1200
- ResBlock(
1201
- ch,
1202
- time_embed_dim,
1203
- dropout,
1204
- out_channels=out_ch,
1205
- dims=dims,
1206
- use_checkpoint=use_checkpoint,
1207
- use_scale_shift_norm=use_scale_shift_norm,
1208
- down=True,
1209
- )
1210
- if resblock_updown
1211
- else Downsample(
1212
- ch, conv_resample, dims=dims, out_channels=out_ch
1213
- )
1214
- )
1215
- )
1216
- ch = out_ch
1217
- input_block_chans.append(ch)
1218
- ds *= 2
1219
- self._feature_size += ch
1220
-
1221
- self.middle_block = TimestepEmbedSequential(
1222
- ResBlock(
1223
- ch,
1224
- time_embed_dim,
1225
- dropout,
1226
- dims=dims,
1227
- use_checkpoint=use_checkpoint,
1228
- use_scale_shift_norm=use_scale_shift_norm,
1229
- ),
1230
- AttentionBlock(
1231
- ch,
1232
- use_checkpoint=use_checkpoint,
1233
- num_heads=num_heads,
1234
- num_head_channels=num_head_channels,
1235
- use_new_attention_order=use_new_attention_order,
1236
- ),
1237
- ResBlock(
1238
- ch,
1239
- time_embed_dim,
1240
- dropout,
1241
- dims=dims,
1242
- use_checkpoint=use_checkpoint,
1243
- use_scale_shift_norm=use_scale_shift_norm,
1244
- ),
1245
- )
1246
- self._feature_size += ch
1247
- self.pool = pool
1248
- if pool == "adaptive":
1249
- self.out = nn.Sequential(
1250
- normalization(ch),
1251
- nn.SiLU(),
1252
- nn.AdaptiveAvgPool2d((1, 1)),
1253
- zero_module(conv_nd(dims, ch, out_channels, 1)),
1254
- nn.Flatten(),
1255
- )
1256
- elif pool == "attention":
1257
- assert num_head_channels != -1
1258
- self.out = nn.Sequential(
1259
- normalization(ch),
1260
- nn.SiLU(),
1261
- AttentionPool2d(
1262
- (image_size // ds), ch, num_head_channels, out_channels
1263
- ),
1264
- )
1265
- elif pool == "spatial":
1266
- self.out = nn.Sequential(
1267
- nn.Linear(self._feature_size, 2048),
1268
- nn.ReLU(),
1269
- nn.Linear(2048, self.out_channels),
1270
- )
1271
- elif pool == "spatial_v2":
1272
- self.out = nn.Sequential(
1273
- nn.Linear(self._feature_size, 2048),
1274
- normalization(2048),
1275
- nn.SiLU(),
1276
- nn.Linear(2048, self.out_channels),
1277
- )
1278
- else:
1279
- raise NotImplementedError(f"Unexpected {pool} pooling")
1280
-
1281
- def convert_to_fp16(self):
1282
- """
1283
- Convert the torso of the model to float16.
1284
- """
1285
- self.input_blocks.apply(convert_module_to_f16)
1286
- self.middle_block.apply(convert_module_to_f16)
1287
-
1288
- def convert_to_fp32(self):
1289
- """
1290
- Convert the torso of the model to float32.
1291
- """
1292
- self.input_blocks.apply(convert_module_to_f32)
1293
- self.middle_block.apply(convert_module_to_f32)
1294
-
1295
- def forward(self, x, timesteps):
1296
- """
1297
- Apply the model to an input batch.
1298
- :param x: an [N x C x ...] Tensor of inputs.
1299
- :param timesteps: a 1-D batch of timesteps.
1300
- :return: an [N x K] Tensor of outputs.
1301
- """
1302
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1303
-
1304
- results = []
1305
- # h = x.type(self.dtype)
1306
- h = x
1307
- for module in self.input_blocks:
1308
- h = module(h, emb)
1309
- if self.pool.startswith("spatial"):
1310
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
1311
- h = self.middle_block(h, emb)
1312
- if self.pool.startswith("spatial"):
1313
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
1314
- h = th.cat(results, axis=-1)
1315
- return self.out(h)
1316
- else:
1317
- h = h.type(x.dtype)
1318
- return self.out(h)
1319
-
1320
-
1321
- if __name__ == "__main__":
1322
-
1323
- class Dummy(nn.Module):
1324
- def __init__(self, in_channels=3, model_channels=64):
1325
- super().__init__()
1326
- self.input_blocks = nn.ModuleList(
1327
- [
1328
- TimestepEmbedSequential(
1329
- conv_nd(2, in_channels, model_channels, 3, padding=1)
1330
- )
1331
- ]
1332
- )
1333
-
1334
- model = UNetModel(
1335
- use_checkpoint=True,
1336
- image_size=64,
1337
- in_channels=4,
1338
- out_channels=4,
1339
- model_channels=128,
1340
- attention_resolutions=[4, 2],
1341
- num_res_blocks=2,
1342
- channel_mult=[1, 2, 4],
1343
- num_head_channels=64,
1344
- use_spatial_transformer=False,
1345
- use_linear_in_transformer=True,
1346
- transformer_depth=1,
1347
- legacy=False,
1348
- ).cuda()
1349
- x = th.randn(11, 4, 64, 64).cuda()
1350
- t = th.randint(low=0, high=10, size=(11,), device="cuda")
1351
- o = model(x, t)
1352
- print("done.")
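Note: compared to the stock SGM UNet, the UNetModel deleted above takes the target and reference conditioning stacked along the batch dimension (rows [:b] for the target, rows [b:] for the b*n reference views) and returns the denoised latent together with the NeRF-style auxiliary outputs. A minimal calling sketch; the shapes, the adm width, and the `cameras` pose object are illustrative assumptions, not values taken from this diff:

import torch as th

b, n = 1, 4                                    # target batch size, reference views per example (assumed)
x = th.randn(b, 4, 64, 64, device="cuda")      # noisy target latents
xr = th.randn(b, n, 4, 64, 64, device="cuda")  # clean reference latents, passed as input_ref
t = th.randint(0, 1000, (b,), device="cuda")
context = th.randn(b + b * n, 77, 2048, device="cuda")  # [:b] conditions the target, [b:] the references
y = th.randn(b + b * n, 2816, device="cuda")             # vector conditioning, split the same way (width assumed)

out, fg_masks, alphas, pred_rgb = model(       # model: an instance of the UNetModel above
    x, timesteps=t, context=context, y=y,
    input_ref=xr, pose=cameras, mask_ref=None,  # cameras: whatever pose structure the image-cross blocks expect
)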
sgm/modules/diffusionmodules/sampling.py DELETED
@@ -1,465 +0,0 @@
1
- """
2
- Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
3
- """
4
-
5
-
6
- from typing import Dict, Union
7
-
8
- import torch
9
- from omegaconf import ListConfig, OmegaConf
10
- from tqdm import tqdm
11
-
12
- from ...modules.diffusionmodules.sampling_utils import (
13
- get_ancestral_step,
14
- linear_multistep_coeff,
15
- to_d,
16
- to_neg_log_sigma,
17
- to_sigma,
18
- )
19
- from ...util import append_dims, default, instantiate_from_config
20
-
21
- DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
22
-
23
-
24
- class BaseDiffusionSampler:
25
- def __init__(
26
- self,
27
- discretization_config: Union[Dict, ListConfig, OmegaConf],
28
- num_steps: Union[int, None] = None,
29
- guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
30
- verbose: bool = False,
31
- device: str = "cuda",
32
- ):
33
- self.num_steps = num_steps
34
- self.discretization = instantiate_from_config(discretization_config)
35
- self.guider = instantiate_from_config(
36
- default(
37
- guider_config,
38
- DEFAULT_GUIDER,
39
- )
40
- )
41
- self.verbose = verbose
42
- self.device = device
43
-
44
- def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
45
- sigmas = self.discretization(
46
- self.num_steps if num_steps is None else num_steps, device=self.device
47
- )
48
- uc = default(uc, cond)
49
-
50
- x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
51
- num_sigmas = len(sigmas)
52
-
53
- s_in = x.new_ones([x.shape[0]])
54
-
55
- return x, s_in, sigmas, num_sigmas, cond, uc
56
-
57
- def denoise(self, x, denoiser, sigma, cond, uc):
58
- denoised, _, _, rgb_list = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
59
- denoised = self.guider(denoised, sigma)
60
- return denoised, rgb_list
61
-
62
- def get_sigma_gen(self, num_sigmas):
63
- sigma_generator = range(num_sigmas - 1)
64
- if self.verbose:
65
- print("#" * 30, " Sampling setting ", "#" * 30)
66
- print(f"Sampler: {self.__class__.__name__}")
67
- print(f"Discretization: {self.discretization.__class__.__name__}")
68
- print(f"Guider: {self.guider.__class__.__name__}")
69
- sigma_generator = tqdm(
70
- sigma_generator,
71
- total=num_sigmas,
72
- desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
73
- )
74
- return sigma_generator
75
-
76
-
77
- class SingleStepDiffusionSampler(BaseDiffusionSampler):
78
- def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
79
- raise NotImplementedError
80
-
81
- def euler_step(self, x, d, dt):
82
- return x + dt * d
83
-
84
-
85
- class EDMSampler(SingleStepDiffusionSampler):
86
- def __init__(
87
- self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
88
- ):
89
- super().__init__(*args, **kwargs)
90
-
91
- self.s_churn = s_churn
92
- self.s_tmin = s_tmin
93
- self.s_tmax = s_tmax
94
- self.s_noise = s_noise
95
-
96
- def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
97
- sigma_hat = sigma * (gamma + 1.0)
98
- if gamma > 0:
99
- eps = torch.randn_like(x) * self.s_noise
100
- x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
101
-
102
- denoised, rgb_list = self.denoise(x, denoiser, sigma_hat, cond, uc)
103
- d = to_d(x, sigma_hat, denoised)
104
- dt = append_dims(next_sigma - sigma_hat, x.ndim)
105
-
106
- euler_step = self.euler_step(x, d, dt)
107
- x = self.possible_correction_step(
108
- euler_step, x, d, dt, next_sigma, denoiser, cond, uc
109
- )
110
- return x, rgb_list
111
-
112
- def __call__(self, denoiser, x, cond, uc=None, num_steps=None, mask=None, init_im=None):
113
- return self.forward(denoiser, x, cond, uc=uc, num_steps=num_steps, mask=mask, init_im=init_im)
114
-
115
- def forward(self, denoiser, x, cond, uc=None, num_steps=None, mask=None, init_im=None):
116
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
117
- x, cond, uc, num_steps
118
- )
119
- for i in self.get_sigma_gen(num_sigmas):
120
- gamma = (
121
- min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
122
- if self.s_tmin <= sigmas[i] <= self.s_tmax
123
- else 0.0
124
- )
125
- x_new, rgb_list = self.sampler_step(
126
- s_in * sigmas[i],
127
- s_in * sigmas[i + 1],
128
- denoiser,
129
- x,
130
- cond,
131
- uc,
132
- gamma,
133
- )
134
- x = x_new
135
-
136
- return x, rgb_list
137
-
138
-
139
- def get_views(panorama_height, panorama_width, window_size=64, stride=48):
140
- # panorama_height /= 8
141
- # panorama_width /= 8
142
- num_blocks_height = (panorama_height - window_size) // stride + 1
143
- num_blocks_width = (panorama_width - window_size) // stride + 1
144
- total_num_blocks = int(num_blocks_height * num_blocks_width)
145
- views = []
146
- for i in range(total_num_blocks):
147
- h_start = int((i // num_blocks_width) * stride)
148
- h_end = h_start + window_size
149
- w_start = int((i % num_blocks_width) * stride)
150
- w_end = w_start + window_size
151
- views.append((h_start, h_end, w_start, w_end))
152
- return views
153
-
154
-
155
- class EDMMultidiffusionSampler(SingleStepDiffusionSampler):
156
- def __init__(
157
- self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
158
- ):
159
- super().__init__(*args, **kwargs)
160
-
161
- self.s_churn = s_churn
162
- self.s_tmin = s_tmin
163
- self.s_tmax = s_tmax
164
- self.s_noise = s_noise
165
-
166
- def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
167
- sigma_hat = sigma * (gamma + 1.0)
168
- if gamma > 0:
169
- eps = torch.randn_like(x) * self.s_noise
170
- x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
171
-
172
- denoised, rgb_list = self.denoise(x, denoiser, sigma_hat, cond, uc)
173
- d = to_d(x, sigma_hat, denoised)
174
- dt = append_dims(next_sigma - sigma_hat, x.ndim)
175
-
176
- euler_step = self.euler_step(x, d, dt)
177
- x = self.possible_correction_step(
178
- euler_step, x, d, dt, next_sigma, denoiser, cond, uc
179
- )
180
- return x, rgb_list
181
-
182
- def __call__(self, denoiser, model, x, cond, uc=None, num_steps=None, multikwargs=None):
183
- return self.forward(denoiser, model, x, cond, uc=uc, num_steps=num_steps, multikwargs=multikwargs)
184
-
185
- def forward(self, denoiser, model, x, cond, uc=None, num_steps=None, multikwargs=None):
186
- views = get_views(x.shape[-2], 48*(len(multikwargs)+1))
187
- shape = x.shape
188
- x = torch.randn(shape[0], shape[1], shape[2], 48*(len(multikwargs)+1)).to(x.device)
189
- count = torch.zeros_like(x, device=x.device)
190
- value = torch.zeros_like(x, device=x.device)
191
-
192
- x, s_in, sigmas, num_sigmas, cond_, uc = self.prepare_sampling_loop(
193
- x, cond[0], uc[0], num_steps
194
- )
195
-
196
- for i in self.get_sigma_gen(num_sigmas):
197
- gamma = (
198
- min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
199
- if self.s_tmin <= sigmas[i] <= self.s_tmax
200
- else 0.0
201
- )
202
- count.zero_()
203
- value.zero_()
204
-
205
- for j, (h_start, h_end, w_start, w_end) in enumerate(views):
206
- # TODO we can support batches, and pass multiple views at once to the unet
207
- latent_view = x[:, :, h_start:h_end, w_start:w_end]
208
- # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
209
- kwargs = {'pose': multikwargs[j]['pose'], 'mask_ref':None, 'drop_im':j}
210
- x_new, rgb_list = self.sampler_step(
211
- s_in * sigmas[i],
212
- s_in * sigmas[i + 1],
213
- lambda input, sigma, c: denoiser(
214
- model, input, sigma, c, **kwargs
215
- ),
216
- latent_view,
217
- cond[j],
218
- uc,
219
- gamma,
220
- )
221
- # compute the denoising step with the reference model
222
- value[:, :, h_start:h_end, w_start:w_end] += x_new
223
- count[:, :, h_start:h_end, w_start:w_end] += 1
224
-
225
- # take the MultiDiffusion step
226
- x = torch.where(count > 0, value / count, value)
227
-
228
- return x, rgb_list
229
-
230
- def possible_correction_step(
231
- self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
232
- ):
233
- return euler_step
234
-
235
-
236
- class AncestralSampler(SingleStepDiffusionSampler):
237
- def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
238
- super().__init__(*args, **kwargs)
239
-
240
- self.eta = eta
241
- self.s_noise = s_noise
242
- self.noise_sampler = lambda x: torch.randn_like(x)
243
-
244
- def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
245
- d = to_d(x, sigma, denoised)
246
- dt = append_dims(sigma_down - sigma, x.ndim)
247
-
248
- return self.euler_step(x, d, dt)
249
-
250
- def ancestral_step(self, x, sigma, next_sigma, sigma_up):
251
- x = torch.where(
252
- append_dims(next_sigma, x.ndim) > 0.0,
253
- x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
254
- x,
255
- )
256
- return x
257
-
258
- def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
259
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
260
- x, cond, uc, num_steps
261
- )
262
-
263
- for i in self.get_sigma_gen(num_sigmas):
264
- x = self.sampler_step(
265
- s_in * sigmas[i],
266
- s_in * sigmas[i + 1],
267
- denoiser,
268
- x,
269
- cond,
270
- uc,
271
- )
272
-
273
- return x
274
-
275
-
276
- class LinearMultistepSampler(BaseDiffusionSampler):
277
- def __init__(
278
- self,
279
- order=4,
280
- *args,
281
- **kwargs,
282
- ):
283
- super().__init__(*args, **kwargs)
284
-
285
- self.order = order
286
-
287
- def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
288
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
289
- x, cond, uc, num_steps
290
- )
291
-
292
- ds = []
293
- sigmas_cpu = sigmas.detach().cpu().numpy()
294
- for i in self.get_sigma_gen(num_sigmas):
295
- sigma = s_in * sigmas[i]
296
- denoised, _ = denoiser(
297
- *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
298
- )
299
- denoised = self.guider(denoised, sigma)
300
- d = to_d(x, sigma, denoised)
301
- ds.append(d)
302
- if len(ds) > self.order:
303
- ds.pop(0)
304
- cur_order = min(i + 1, self.order)
305
- coeffs = [
306
- linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
307
- for j in range(cur_order)
308
- ]
309
- x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
310
-
311
- return x
312
-
313
-
314
- class EulerEDMSampler(EDMSampler):
315
- def possible_correction_step(
316
- self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
317
- ):
318
- return euler_step
319
-
320
-
321
- class HeunEDMSampler(EDMSampler):
322
- def possible_correction_step(
323
- self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
324
- ):
325
- if torch.sum(next_sigma) < 1e-14:
326
- # Save a network evaluation if all noise levels are 0
327
- return euler_step
328
- else:
329
- denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
330
- d_new = to_d(euler_step, next_sigma, denoised)
331
- d_prime = (d + d_new) / 2.0
332
-
333
- # apply correction if noise level is not 0
334
- x = torch.where(
335
- append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
336
- )
337
- return x
338
-
339
-
340
- class EulerAncestralSampler(AncestralSampler):
341
- def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
342
- sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
343
- denoised = self.denoise(x, denoiser, sigma, cond, uc)
344
- x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
345
- x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
346
-
347
- return x
348
-
349
-
350
- class DPMPP2SAncestralSampler(AncestralSampler):
351
- def get_variables(self, sigma, sigma_down):
352
- t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
353
- h = t_next - t
354
- s = t + 0.5 * h
355
- return h, s, t, t_next
356
-
357
- def get_mult(self, h, s, t, t_next):
358
- mult1 = to_sigma(s) / to_sigma(t)
359
- mult2 = (-0.5 * h).expm1()
360
- mult3 = to_sigma(t_next) / to_sigma(t)
361
- mult4 = (-h).expm1()
362
-
363
- return mult1, mult2, mult3, mult4
364
-
365
- def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
366
- sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
367
- denoised = self.denoise(x, denoiser, sigma, cond, uc)
368
- x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
369
-
370
- if torch.sum(sigma_down) < 1e-14:
371
- # Save a network evaluation if all noise levels are 0
372
- x = x_euler
373
- else:
374
- h, s, t, t_next = self.get_variables(sigma, sigma_down)
375
- mult = [
376
- append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
377
- ]
378
-
379
- x2 = mult[0] * x - mult[1] * denoised
380
- denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
381
- x_dpmpp2s = mult[2] * x - mult[3] * denoised2
382
-
383
- # apply correction if noise level is not 0
384
- x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
385
-
386
- x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
387
- return x
388
-
389
-
390
- class DPMPP2MSampler(BaseDiffusionSampler):
391
- def get_variables(self, sigma, next_sigma, previous_sigma=None):
392
- t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
393
- h = t_next - t
394
-
395
- if previous_sigma is not None:
396
- h_last = t - to_neg_log_sigma(previous_sigma)
397
- r = h_last / h
398
- return h, r, t, t_next
399
- else:
400
- return h, None, t, t_next
401
-
402
- def get_mult(self, h, r, t, t_next, previous_sigma):
403
- mult1 = to_sigma(t_next) / to_sigma(t)
404
- mult2 = (-h).expm1()
405
-
406
- if previous_sigma is not None:
407
- mult3 = 1 + 1 / (2 * r)
408
- mult4 = 1 / (2 * r)
409
- return mult1, mult2, mult3, mult4
410
- else:
411
- return mult1, mult2
412
-
413
- def sampler_step(
414
- self,
415
- old_denoised,
416
- previous_sigma,
417
- sigma,
418
- next_sigma,
419
- denoiser,
420
- x,
421
- cond,
422
- uc=None,
423
- ):
424
- denoised = self.denoise(x, denoiser, sigma, cond, uc)
425
-
426
- h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
427
- mult = [
428
- append_dims(mult, x.ndim)
429
- for mult in self.get_mult(h, r, t, t_next, previous_sigma)
430
- ]
431
-
432
- x_standard = mult[0] * x - mult[1] * denoised
433
- if old_denoised is None or torch.sum(next_sigma) < 1e-14:
434
- # Save a network evaluation if all noise levels are 0 or on the first step
435
- return x_standard, denoised
436
- else:
437
- denoised_d = mult[2] * denoised - mult[3] * old_denoised
438
- x_advanced = mult[0] * x - mult[1] * denoised_d
439
-
440
- # apply correction if noise level is not 0 and not first step
441
- x = torch.where(
442
- append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
443
- )
444
-
445
- return x, denoised
446
-
447
- def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
448
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
449
- x, cond, uc, num_steps
450
- )
451
-
452
- old_denoised = None
453
- for i in self.get_sigma_gen(num_sigmas):
454
- x, old_denoised = self.sampler_step(
455
- old_denoised,
456
- None if i == 0 else s_in * sigmas[i - 1],
457
- s_in * sigmas[i],
458
- s_in * sigmas[i + 1],
459
- denoiser,
460
- x,
461
- cond,
462
- uc=uc,
463
- )
464
-
465
- return x
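For context, the samplers in this deleted module return `(x, rgb_list)` instead of just `x`, since `denoise` also unpacks the auxiliary RGB predictions of the customized denoiser. A rough instantiation sketch in the usual config-driven sgm style; the discretization and guider targets and the numeric values are illustrative placeholders, not taken from this repo's configs:

sampler = EulerEDMSampler(
    num_steps=50,
    discretization_config={
        "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
        "params": {"sigma_min": 0.002, "sigma_max": 80.0, "rho": 7.0},
    },
    guider_config={
        "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
        "params": {"scale": 7.5},
    },
    verbose=True,
)
# samples, rgb_list = sampler(denoiser_fn, torch.randn(1, 4, 128, 128), cond, uc=uc)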
sgm/modules/diffusionmodules/sampling_utils.py DELETED
@@ -1,48 +0,0 @@
- import torch
- from scipy import integrate
-
- from ...util import append_dims
-
-
- class NoDynamicThresholding:
-     def __call__(self, uncond, cond, scale):
-         return uncond + scale * (cond - uncond)
-
-
- def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
-     if order - 1 > i:
-         raise ValueError(f"Order {order} too high for step {i}")
-
-     def fn(tau):
-         prod = 1.0
-         for k in range(order):
-             if j == k:
-                 continue
-             prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
-         return prod
-
-     return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
-
-
- def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
-     if not eta:
-         return sigma_to, 0.0
-     sigma_up = torch.minimum(
-         sigma_to,
-         eta
-         * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
-     )
-     sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
-     return sigma_down, sigma_up
-
-
- def to_d(x, sigma, denoised):
-     return (x - denoised) / append_dims(sigma, x.ndim)
-
-
- def to_neg_log_sigma(sigma):
-     return sigma.log().neg()
-
-
- def to_sigma(neg_log_sigma):
-     return neg_log_sigma.neg().exp()
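As a sanity check on the helpers above: `get_ancestral_step` splits the target noise level so that the deterministic and stochastic parts recombine exactly, i.e. sigma_down**2 + sigma_up**2 == sigma_to**2. A quick numerical verification with arbitrary values:

import torch

sigma_from, sigma_to, eta = torch.tensor(2.0), torch.tensor(1.0), 1.0
sigma_up = torch.minimum(
    sigma_to, eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
)
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
assert torch.allclose(sigma_down**2 + sigma_up**2, sigma_to**2)  # the noise budget is preserved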
sgm/modules/diffusionmodules/sigma_sampling.py DELETED
@@ -1,54 +0,0 @@
- import torch
-
- from ...util import default, instantiate_from_config
-
-
- class EDMSampling:
-     def __init__(self, p_mean=-1.2, p_std=1.2):
-         self.p_mean = p_mean
-         self.p_std = p_std
-
-     def __call__(self, n_samples, rand=None):
-         log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
-         return log_sigma.exp()
-
-
- class DiscreteSampling:
-     def __init__(self, discretization_config, num_idx, num_idx_start=0, do_append_zero=False, flip=True):
-         self.num_idx = num_idx
-         self.num_idx_start = num_idx_start
-         self.sigmas = instantiate_from_config(discretization_config)(
-             num_idx, do_append_zero=do_append_zero, flip=flip
-         )
-
-     def idx_to_sigma(self, idx):
-         return self.sigmas[idx]
-
-     def __call__(self, n_samples, rand=None):
-         idx = default(
-             rand,
-             torch.randint(self.num_idx_start, self.num_idx, (n_samples,)),
-         )
-         return self.idx_to_sigma(idx)
-
-
- class CubicSampling:
-     def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):
-         self.num_idx = num_idx
-         self.sigmas = instantiate_from_config(discretization_config)(
-             num_idx, do_append_zero=do_append_zero, flip=flip
-         )
-
-     def idx_to_sigma(self, idx):
-         return self.sigmas[idx]
-
-     def __call__(self, n_samples, rand=None):
-         t = torch.rand((n_samples,))
-         t = (1 - t ** 3) * (self.num_idx-1)
-         t = t.long()
-         idx = default(
-             rand,
-             t,
-         )
-         return self.idx_to_sigma(idx)
-
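These classes draw the per-example noise levels during training: `EDMSampling` samples log-normal sigmas directly, while `DiscreteSampling` and `CubicSampling` index into a fixed discretization (the cubic variant skews the drawn indices toward num_idx - 1). A usage sketch with the usual sgm discretization target; the parameter values are illustrative:

sigma_sampler = DiscreteSampling(
    discretization_config={
        "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        "params": {"num_timesteps": 1000},
    },
    num_idx=1000,
)
sigmas = sigma_sampler(8)  # one sigma per training example, shape [8]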
sgm/modules/diffusionmodules/util.py DELETED
@@ -1,344 +0,0 @@
1
- """
2
- adopted from
3
- https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
4
- and
5
- https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
6
- and
7
- https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
8
-
9
- thanks!
10
- """
11
-
12
- import math
13
-
14
- import torch
15
- import torch.nn as nn
16
- from einops import repeat
17
-
18
-
19
- def make_beta_schedule(
20
- schedule,
21
- n_timestep,
22
- linear_start=1e-4,
23
- linear_end=2e-2,
24
- ):
25
- if schedule == "linear":
26
- betas = (
27
- torch.linspace(
28
- linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
29
- )
30
- ** 2
31
- )
32
- return betas.numpy()
33
-
34
-
35
- def extract_into_tensor(a, t, x_shape):
36
- b, *_ = t.shape
37
- out = a.gather(-1, t)
38
- return out.reshape(b, *((1,) * (len(x_shape) - 1)))
39
-
40
-
41
- def mixed_checkpoint(func, inputs: dict, params, flag):
42
- """
43
- Evaluate a function without caching intermediate activations, allowing for
44
- reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
45
- borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
46
- it also works with non-tensor inputs
47
- :param func: the function to evaluate.
48
- :param inputs: the argument dictionary to pass to `func`.
49
- :param params: a sequence of parameters `func` depends on but does not
50
- explicitly take as arguments.
51
- :param flag: if False, disable gradient checkpointing.
52
- """
53
- if flag:
54
- tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
55
- tensor_inputs = [
56
- inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)
57
- ]
58
- non_tensor_keys = [
59
- key for key in inputs if not isinstance(inputs[key], torch.Tensor)
60
- ]
61
- non_tensor_inputs = [
62
- inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
63
- ]
64
- args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
65
- return MixedCheckpointFunction.apply(
66
- func,
67
- len(tensor_inputs),
68
- len(non_tensor_inputs),
69
- tensor_keys,
70
- non_tensor_keys,
71
- *args,
72
- )
73
- else:
74
- return func(**inputs)
75
-
76
-
77
- class MixedCheckpointFunction(torch.autograd.Function):
78
- @staticmethod
79
- def forward(
80
- ctx,
81
- run_function,
82
- length_tensors,
83
- length_non_tensors,
84
- tensor_keys,
85
- non_tensor_keys,
86
- *args,
87
- ):
88
- ctx.end_tensors = length_tensors
89
- ctx.end_non_tensors = length_tensors + length_non_tensors
90
- ctx.gpu_autocast_kwargs = {
91
- "enabled": torch.is_autocast_enabled(),
92
- "dtype": torch.get_autocast_gpu_dtype(),
93
- "cache_enabled": torch.is_autocast_cache_enabled(),
94
- }
95
- assert (
96
- len(tensor_keys) == length_tensors
97
- and len(non_tensor_keys) == length_non_tensors
98
- )
99
-
100
- ctx.input_tensors = {
101
- key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
102
- }
103
- ctx.input_non_tensors = {
104
- key: val
105
- for (key, val) in zip(
106
- non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
107
- )
108
- }
109
- ctx.run_function = run_function
110
- ctx.input_params = list(args[ctx.end_non_tensors :])
111
-
112
- with torch.no_grad():
113
- output_tensors = ctx.run_function(
114
- **ctx.input_tensors, **ctx.input_non_tensors
115
- )
116
- return output_tensors
117
-
118
- @staticmethod
119
- def backward(ctx, *output_grads):
120
- # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
121
- ctx.input_tensors = {
122
- key: ctx.input_tensors[key].detach().requires_grad_(True)
123
- for key in ctx.input_tensors
124
- }
125
-
126
- with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
127
- # Fixes a bug where the first op in run_function modifies the
128
- # Tensor storage in place, which is not allowed for detach()'d
129
- # Tensors.
130
- shallow_copies = {
131
- key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
132
- for key in ctx.input_tensors
133
- }
134
- # shallow_copies.update(additional_args)
135
- output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
136
- input_grads = torch.autograd.grad(
137
- output_tensors,
138
- list(ctx.input_tensors.values()) + ctx.input_params,
139
- output_grads,
140
- allow_unused=True,
141
- )
142
- del ctx.input_tensors
143
- del ctx.input_params
144
- del output_tensors
145
- return (
146
- (None, None, None, None, None)
147
- + input_grads[: ctx.end_tensors]
148
- + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
149
- + input_grads[ctx.end_tensors :]
150
- )
151
-
152
-
153
- def checkpoint(func, inputs, params, flag):
154
- """
155
- Evaluate a function without caching intermediate activations, allowing for
156
- reduced memory at the expense of extra compute in the backward pass.
157
- :param func: the function to evaluate.
158
- :param inputs: the argument sequence to pass to `func`.
159
- :param params: a sequence of parameters `func` depends on but does not
160
- explicitly take as arguments.
161
- :param flag: if False, disable gradient checkpointing.
162
- """
163
- if flag:
164
- args = tuple(inputs) + tuple(params)
165
- return CheckpointFunction.apply(func, len(inputs), *args)
166
- else:
167
- return func(*inputs)
168
-
169
-
170
- class CheckpointFunction(torch.autograd.Function):
171
- @staticmethod
172
- def forward(ctx, run_function, length, *args):
173
- ctx.run_function = run_function
174
- ctx.input_tensors = list(args[:length])
175
- ctx.input_params = list(args[length:])
176
- ctx.gpu_autocast_kwargs = {
177
- "enabled": torch.is_autocast_enabled(),
178
- "dtype": torch.get_autocast_gpu_dtype(),
179
- "cache_enabled": torch.is_autocast_cache_enabled(),
180
- }
181
- with torch.no_grad():
182
- output_tensors = ctx.run_function(*ctx.input_tensors)
183
- return output_tensors
184
-
185
- @staticmethod
186
- def backward(ctx, *output_grads):
187
- ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
188
- with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
189
- # Fixes a bug where the first op in run_function modifies the
190
- # Tensor storage in place, which is not allowed for detach()'d
191
- # Tensors.
192
- shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
193
- output_tensors = ctx.run_function(*shallow_copies)
194
- input_grads = torch.autograd.grad(
195
- output_tensors,
196
- ctx.input_tensors + ctx.input_params,
197
- output_grads,
198
- allow_unused=True,
199
- )
200
- del ctx.input_tensors
201
- del ctx.input_params
202
- del output_tensors
203
- return (None, None) + input_grads
204
-
205
-
206
- def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
207
- """
208
- Create sinusoidal timestep embeddings.
209
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
210
- These may be fractional.
211
- :param dim: the dimension of the output.
212
- :param max_period: controls the minimum frequency of the embeddings.
213
- :return: an [N x dim] Tensor of positional embeddings.
214
- """
215
- if not repeat_only:
216
- half = dim // 2
217
- freqs = torch.exp(
218
- -math.log(max_period)
219
- * torch.arange(start=0, end=half, dtype=torch.float32)
220
- / half
221
- ).to(device=timesteps.device)
222
- args = timesteps[:, None].float() * freqs[None]
223
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
224
- if dim % 2:
225
- embedding = torch.cat(
226
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
227
- )
228
- else:
229
- embedding = repeat(timesteps, "b -> b d", d=dim)
230
- return embedding
231
-
232
-
233
- def timestep_embedding_pose(timesteps, dim, max_period=10000, repeat_only=False):
234
- """
235
- Create sinusoidal timestep embeddings.
236
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
237
- These may be fractional.
238
- :param dim: the dimension of the output.
239
- :param max_period: controls the minimum frequency of the embeddings.
240
- :return: an [N x dim] Tensor of positional embeddings.
241
- """
242
- if not repeat_only:
243
- half = dim // 2
244
- freqs = torch.exp(
245
- -math.log(max_period)
246
- * torch.arange(start=0, end=half, dtype=torch.float32)
247
- / half
248
- ).to(device=timesteps.device)
249
- args = timesteps[:, None].float() * freqs[None]
250
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
251
- if dim % 2:
252
- embedding = torch.cat(
253
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
254
- )
255
- else:
256
- embedding = repeat(timesteps, "b -> b d", d=dim)
257
- return embedding
258
-
259
-
260
- def zero_module(module):
261
- """
262
- Zero out the parameters of a module and return it.
263
- """
264
- for p in module.parameters():
265
- p.detach().zero_()
266
- return module
267
-
268
-
269
- def ones_module(module):
270
- """
271
- Fill the parameters of a module with ones and return it.
272
- """
273
- for p in module.parameters():
274
- p.detach().data.fill_(1.)
275
- return module
276
-
277
-
278
- def scale_module(module, scale):
279
- """
280
- Scale the parameters of a module and return it.
281
- """
282
- for p in module.parameters():
283
- p.detach().mul_(scale)
284
- return module
285
-
286
-
287
- def mean_flat(tensor):
288
- """
289
- Take the mean over all non-batch dimensions.
290
- """
291
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
292
-
293
-
294
- def normalization(channels):
295
- """
296
- Make a standard normalization layer.
297
- :param channels: number of input channels.
298
- :return: an nn.Module for normalization.
299
- """
300
- return GroupNorm32(32, channels)
301
-
302
-
303
- # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
304
- class SiLU(nn.Module):
305
- def forward(self, x):
306
- return x * torch.sigmoid(x)
307
-
308
-
309
- class GroupNorm32(nn.GroupNorm):
310
- def forward(self, x):
311
- return super().forward(x.float()).type(x.dtype)
312
-
313
-
314
- def conv_nd(dims, *args, **kwargs):
315
- """
316
- Create a 1D, 2D, or 3D convolution module.
317
- """
318
- if dims == 1:
319
- return nn.Conv1d(*args, **kwargs)
320
- elif dims == 2:
321
- return nn.Conv2d(*args, **kwargs)
322
- elif dims == 3:
323
- return nn.Conv3d(*args, **kwargs)
324
- raise ValueError(f"unsupported dimensions: {dims}")
325
-
326
-
327
- def linear(*args, **kwargs):
328
- """
329
- Create a linear module.
330
- """
331
- return nn.Linear(*args, **kwargs)
332
-
333
-
334
- def avg_pool_nd(dims, *args, **kwargs):
335
- """
336
- Create a 1D, 2D, or 3D average pooling module.
337
- """
338
- if dims == 1:
339
- return nn.AvgPool1d(*args, **kwargs)
340
- elif dims == 2:
341
- return nn.AvgPool2d(*args, **kwargs)
342
- elif dims == 3:
343
- return nn.AvgPool3d(*args, **kwargs)
344
- raise ValueError(f"unsupported dimensions: {dims}")
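Most of the helpers above are stock guided-diffusion utilities (note that `timestep_embedding_pose` is an identical copy of `timestep_embedding`). A quick shape check for the sinusoidal embedding; the embedding dimension is an arbitrary example:

import torch

t = torch.tensor([0.0, 250.0, 999.0])
emb = timestep_embedding(t, dim=320)  # first 160 dims are cosines, last 160 are sines
assert emb.shape == (3, 320)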
sgm/modules/diffusionmodules/wrappers.py DELETED
@@ -1,35 +0,0 @@
- import torch
- import torch.nn as nn
- from packaging import version
-
- OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
-
-
- class IdentityWrapper(nn.Module):
-     def __init__(self, diffusion_model, compile_model: bool = False):
-         super().__init__()
-         compile = (
-             torch.compile
-             if (version.parse(torch.__version__) >= version.parse("2.0.0"))
-             and compile_model
-             else lambda x: x
-         )
-         self.diffusion_model = compile(diffusion_model)
-
-     def forward(self, *args, **kwargs):
-         return self.diffusion_model(*args, **kwargs)
-
-
- class OpenAIWrapper(IdentityWrapper):
-     def forward(
-         self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
-     ) -> torch.Tensor:
-         x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
-         return self.diffusion_model(
-             x,
-             timesteps=t,
-             context=c.get("crossattn", None),
-             y=c.get("vector", None),
-             **kwargs
-         )
-
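`OpenAIWrapper` simply routes the conditioner's output dict into the UNet by key; anything else passed as a keyword (pose, input_ref, mask_ref, ...) falls through untouched. A sketch of the expected dict layout, with placeholder shapes:

import torch

c = {
    "crossattn": torch.randn(2, 77, 2048),  # becomes the `context` argument
    "vector": torch.randn(2, 2816),         # becomes the `y` argument
    # "concat": optional latents concatenated to x along the channel dimension
}
# out = wrapper(x, t, c, pose=cameras, input_ref=ref_latents)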
sgm/modules/distributions/__init__.py DELETED
File without changes
sgm/modules/distributions/distributions.py DELETED
@@ -1,102 +0,0 @@
1
- import numpy as np
2
- import torch
3
-
4
-
5
- class AbstractDistribution:
6
- def sample(self):
7
- raise NotImplementedError()
8
-
9
- def mode(self):
10
- raise NotImplementedError()
11
-
12
-
13
- class DiracDistribution(AbstractDistribution):
14
- def __init__(self, value):
15
- self.value = value
16
-
17
- def sample(self):
18
- return self.value
19
-
20
- def mode(self):
21
- return self.value
22
-
23
-
24
- class DiagonalGaussianDistribution(object):
25
- def __init__(self, parameters, deterministic=False):
26
- self.parameters = parameters
27
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
- self.deterministic = deterministic
30
- self.std = torch.exp(0.5 * self.logvar)
31
- self.var = torch.exp(self.logvar)
32
- if self.deterministic:
33
- self.var = self.std = torch.zeros_like(self.mean).to(
34
- device=self.parameters.device
35
- )
36
-
37
- def sample(self):
38
- x = self.mean + self.std * torch.randn(self.mean.shape).to(
39
- device=self.parameters.device
40
- )
41
- return x
42
-
43
- def kl(self, other=None):
44
- if self.deterministic:
45
- return torch.Tensor([0.0])
46
- else:
47
- if other is None:
48
- return 0.5 * torch.sum(
49
- torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50
- dim=[1, 2, 3],
51
- )
52
- else:
53
- return 0.5 * torch.sum(
54
- torch.pow(self.mean - other.mean, 2) / other.var
55
- + self.var / other.var
56
- - 1.0
57
- - self.logvar
58
- + other.logvar,
59
- dim=[1, 2, 3],
60
- )
61
-
62
- def nll(self, sample, dims=[1, 2, 3]):
63
- if self.deterministic:
64
- return torch.Tensor([0.0])
65
- logtwopi = np.log(2.0 * np.pi)
66
- return 0.5 * torch.sum(
67
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68
- dim=dims,
69
- )
70
-
71
- def mode(self):
72
- return self.mean
73
-
74
-
75
- def normal_kl(mean1, logvar1, mean2, logvar2):
76
- """
77
- source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78
- Compute the KL divergence between two gaussians.
79
- Shapes are automatically broadcasted, so batches can be compared to
80
- scalars, among other use cases.
81
- """
82
- tensor = None
83
- for obj in (mean1, logvar1, mean2, logvar2):
84
- if isinstance(obj, torch.Tensor):
85
- tensor = obj
86
- break
87
- assert tensor is not None, "at least one argument must be a Tensor"
88
-
89
- # Force variances to be Tensors. Broadcasting helps convert scalars to
90
- # Tensors, but it does not work for torch.exp().
91
- logvar1, logvar2 = [
92
- x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93
- for x in (logvar1, logvar2)
94
- ]
95
-
96
- return 0.5 * (
97
- -1.0
98
- + logvar2
99
- - logvar1
100
- + torch.exp(logvar1 - logvar2)
101
- + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102
- )
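A small worked example of the diagonal-Gaussian math removed above (the formulas are re-stated inline since the module is deleted; the "encoder output" tensor is a dummy):

# An encoder emits (mean, logvar) stacked along channels; sampling uses the
# reparameterisation trick, and the KL against N(0, I) reduces to
# 0.5 * sum(mean^2 + var - 1 - logvar) per sample.
import torch

params = torch.randn(2, 8, 4, 4)          # pretend encoder output: 2*4 channels
mean, logvar = torch.chunk(params, 2, dim=1)
logvar = torch.clamp(logvar, -30.0, 20.0)
std, var = torch.exp(0.5 * logvar), torch.exp(logvar)

z = mean + std * torch.randn_like(mean)   # sample()
kl = 0.5 * torch.sum(mean.pow(2) + var - 1.0 - logvar, dim=[1, 2, 3])  # kl()

print(z.shape, kl.shape)                  # torch.Size([2, 4, 4, 4]) torch.Size([2])
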
sgm/modules/distributions/distributions1.py DELETED
@@ -1,102 +0,0 @@
1
- import torch
2
- import numpy as np
3
-
4
-
5
- class AbstractDistribution:
6
- def sample(self):
7
- raise NotImplementedError()
8
-
9
- def mode(self):
10
- raise NotImplementedError()
11
-
12
-
13
- class DiracDistribution(AbstractDistribution):
14
- def __init__(self, value):
15
- self.value = value
16
-
17
- def sample(self):
18
- return self.value
19
-
20
- def mode(self):
21
- return self.value
22
-
23
-
24
- class DiagonalGaussianDistribution(object):
25
- def __init__(self, parameters, deterministic=False):
26
- self.parameters = parameters
27
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
- self.deterministic = deterministic
30
- self.std = torch.exp(0.5 * self.logvar)
31
- self.var = torch.exp(self.logvar)
32
- if self.deterministic:
33
- self.var = self.std = torch.zeros_like(self.mean).to(
34
- device=self.parameters.device
35
- )
36
-
37
- def sample(self):
38
- x = self.mean + self.std * torch.randn(self.mean.shape).to(
39
- device=self.parameters.device
40
- )
41
- return x
42
-
43
- def kl(self, other=None):
44
- if self.deterministic:
45
- return torch.Tensor([0.0])
46
- else:
47
- if other is None:
48
- return 0.5 * torch.sum(
49
- torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50
- dim=[1, 2, 3],
51
- )
52
- else:
53
- return 0.5 * torch.sum(
54
- torch.pow(self.mean - other.mean, 2) / other.var
55
- + self.var / other.var
56
- - 1.0
57
- - self.logvar
58
- + other.logvar,
59
- dim=[1, 2, 3],
60
- )
61
-
62
- def nll(self, sample, dims=[1, 2, 3]):
63
- if self.deterministic:
64
- return torch.Tensor([0.0])
65
- logtwopi = np.log(2.0 * np.pi)
66
- return 0.5 * torch.sum(
67
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68
- dim=dims,
69
- )
70
-
71
- def mode(self):
72
- return self.mean
73
-
74
-
75
- def normal_kl(mean1, logvar1, mean2, logvar2):
76
- """
77
- source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78
- Compute the KL divergence between two gaussians.
79
- Shapes are automatically broadcasted, so batches can be compared to
80
- scalars, among other use cases.
81
- """
82
- tensor = None
83
- for obj in (mean1, logvar1, mean2, logvar2):
84
- if isinstance(obj, torch.Tensor):
85
- tensor = obj
86
- break
87
- assert tensor is not None, "at least one argument must be a Tensor"
88
-
89
- # Force variances to be Tensors. Broadcasting helps convert scalars to
90
- # Tensors, but it does not work for torch.exp().
91
- logvar1, logvar2 = [
92
- x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93
- for x in (logvar1, logvar2)
94
- ]
95
-
96
- return 0.5 * (
97
- -1.0
98
- + logvar2
99
- - logvar1
100
- + torch.exp(logvar1 - logvar2)
101
- + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102
- )
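This file is an otherwise identical duplicate of distributions.py (only the import order differs), so the same normal_kl helper disappears with it. A quick standalone check of that closed-form Gaussian KL against torch.distributions (the formula is re-stated here rather than imported from the removed module):

# Closed-form KL of N(mean1, exp(logvar1)) from N(mean2, exp(logvar2)),
# verified against torch.distributions.
import torch
from torch.distributions import Normal, kl_divergence

mean1, logvar1 = torch.tensor(0.3), torch.tensor(-0.5)
mean2, logvar2 = torch.tensor(-0.1), torch.tensor(0.2)

closed_form = 0.5 * (
    -1.0 + logvar2 - logvar1
    + torch.exp(logvar1 - logvar2)
    + (mean1 - mean2) ** 2 * torch.exp(-logvar2)
)
reference = kl_divergence(
    Normal(mean1, torch.exp(0.5 * logvar1)),
    Normal(mean2, torch.exp(0.5 * logvar2)),
)
print(torch.allclose(closed_form, reference))   # True
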
sgm/modules/ema.py DELETED
@@ -1,86 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
-
5
- class LitEma(nn.Module):
6
- def __init__(self, model, decay=0.9999, use_num_upates=True):
7
- super().__init__()
8
- if decay < 0.0 or decay > 1.0:
9
- raise ValueError("Decay must be between 0 and 1")
10
-
11
- self.m_name2s_name = {}
12
- self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13
- self.register_buffer(
14
- "num_updates",
15
- torch.tensor(0, dtype=torch.int)
16
- if use_num_upates
17
- else torch.tensor(-1, dtype=torch.int),
18
- )
19
-
20
- for name, p in model.named_parameters():
21
- if p.requires_grad:
22
- # remove as '.'-character is not allowed in buffers
23
- s_name = name.replace(".", "")
24
- self.m_name2s_name.update({name: s_name})
25
- self.register_buffer(s_name, p.clone().detach().data)
26
-
27
- self.collected_params = []
28
-
29
- def reset_num_updates(self):
30
- del self.num_updates
31
- self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
32
-
33
- def forward(self, model):
34
- decay = self.decay
35
-
36
- if self.num_updates >= 0:
37
- self.num_updates += 1
38
- decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
39
-
40
- one_minus_decay = 1.0 - decay
41
-
42
- with torch.no_grad():
43
- m_param = dict(model.named_parameters())
44
- shadow_params = dict(self.named_buffers())
45
-
46
- for key in m_param:
47
- if m_param[key].requires_grad:
48
- sname = self.m_name2s_name[key]
49
- shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
50
- shadow_params[sname].sub_(
51
- one_minus_decay * (shadow_params[sname] - m_param[key])
52
- )
53
- else:
54
- assert not key in self.m_name2s_name
55
-
56
- def copy_to(self, model):
57
- m_param = dict(model.named_parameters())
58
- shadow_params = dict(self.named_buffers())
59
- for key in m_param:
60
- if m_param[key].requires_grad:
61
- m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
62
- else:
63
- assert not key in self.m_name2s_name
64
-
65
- def store(self, parameters):
66
- """
67
- Save the current parameters for restoring later.
68
- Args:
69
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
70
- temporarily stored.
71
- """
72
- self.collected_params = [param.clone() for param in parameters]
73
-
74
- def restore(self, parameters):
75
- """
76
- Restore the parameters stored with the `store` method.
77
- Useful to validate the model with EMA parameters without affecting the
78
- original optimization process. Store the parameters before the
79
- `copy_to` method. After validation (or model saving), use this to
80
- restore the former parameters.
81
- Args:
82
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
83
- updated with the stored parameters.
84
- """
85
- for c_param, param in zip(self.collected_params, parameters):
86
- param.data.copy_(c_param.data)
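The EMA logic deleted above boils down to a warm-up decay min(decay, (1+n)/(10+n)) followed by shadow <- shadow - (1-decay)*(shadow - param), plus store/copy_to/restore for validating with EMA weights. A compact re-statement on a plain nn.Linear (illustrative; does not import the removed LitEma class):

# Minimal sketch of the LitEma update rule and the store / copy_to / restore
# dance, using a throwaway model.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
shadow = {n: p.detach().clone() for n, p in model.named_parameters()}
num_updates, base_decay = 0, 0.9999

def ema_step():
    global num_updates
    num_updates += 1
    decay = min(base_decay, (1 + num_updates) / (10 + num_updates))  # warm-up
    with torch.no_grad():
        for n, p in model.named_parameters():
            shadow[n].sub_((1.0 - decay) * (shadow[n] - p))

# after each optimiser step:
ema_step()

# to evaluate with EMA weights: stash current params, copy the shadow in,
# run validation, then restore the originals.
backup = {n: p.detach().clone() for n, p in model.named_parameters()}
with torch.no_grad():
    for n, p in model.named_parameters():
        p.copy_(shadow[n])
# ... validate here ...
with torch.no_grad():
    for n, p in model.named_parameters():
        p.copy_(backup[n])
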
sgm/modules/encoders/__init__.py DELETED
File without changes
sgm/modules/encoders/modules.py DELETED
@@ -1,1154 +0,0 @@
1
- from contextlib import nullcontext
2
- from functools import partial
3
- from typing import Dict, List, Optional, Tuple, Union
4
- from packaging import version
5
-
6
- import kornia
7
- import numpy as np
8
- import open_clip
9
- from open_clip.tokenizer import SimpleTokenizer
10
- import torch
11
- import torch.nn as nn
12
- from einops import rearrange, repeat
13
- from omegaconf import ListConfig
14
- from torch.utils.checkpoint import checkpoint
15
- import transformers
16
- from transformers import (ByT5Tokenizer, CLIPTextModel, CLIPTokenizer,
17
- T5EncoderModel, T5Tokenizer)
18
-
19
- from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
20
- from ...modules.diffusionmodules.model import Encoder
21
- from ...modules.diffusionmodules.openaimodel import Timestep
22
- from ...modules.diffusionmodules.util import (extract_into_tensor,
23
- make_beta_schedule)
24
- from ...modules.distributions.distributions import DiagonalGaussianDistribution
25
- from ...util import (append_dims, autocast, count_params, default,
26
- disabled_train, expand_dims_like, instantiate_from_config)
27
-
28
-
29
- class AbstractEmbModel(nn.Module):
30
- def __init__(self):
31
- super().__init__()
32
- self._is_trainable = None
33
- self._ucg_rate = None
34
- self._input_key = None
35
-
36
- @property
37
- def is_trainable(self) -> bool:
38
- return self._is_trainable
39
-
40
- @property
41
- def ucg_rate(self) -> Union[float, torch.Tensor]:
42
- return self._ucg_rate
43
-
44
- @property
45
- def input_key(self) -> str:
46
- return self._input_key
47
-
48
- @is_trainable.setter
49
- def is_trainable(self, value: bool):
50
- self._is_trainable = value
51
-
52
- @ucg_rate.setter
53
- def ucg_rate(self, value: Union[float, torch.Tensor]):
54
- self._ucg_rate = value
55
-
56
- @input_key.setter
57
- def input_key(self, value: str):
58
- self._input_key = value
59
-
60
- @is_trainable.deleter
61
- def is_trainable(self):
62
- del self._is_trainable
63
-
64
- @ucg_rate.deleter
65
- def ucg_rate(self):
66
- del self._ucg_rate
67
-
68
- @input_key.deleter
69
- def input_key(self):
70
- del self._input_key
71
-
72
-
73
- class GeneralConditioner(nn.Module):
74
- OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
75
- KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
76
-
77
- def __init__(self, emb_models: Union[List, ListConfig]):
78
- super().__init__()
79
- embedders = []
80
- for n, embconfig in enumerate(emb_models):
81
- embedder = instantiate_from_config(embconfig)
82
- assert isinstance(
83
- embedder, AbstractEmbModel
84
- ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
85
- embedder.is_trainable = embconfig.get("is_trainable", False)
86
- embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
87
- if not embedder.is_trainable:
88
- embedder.train = disabled_train
89
- for param in embedder.parameters():
90
- param.requires_grad = False
91
- embedder.eval()
92
- print(
93
- f"Initialized embedder #{n}: {embedder.__class__.__name__} "
94
- f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
95
- )
96
-
97
- if "input_key" in embconfig:
98
- embedder.input_key = embconfig["input_key"]
99
- elif "input_keys" in embconfig:
100
- embedder.input_keys = embconfig["input_keys"].split(',')
101
- else:
102
- raise KeyError(
103
- f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
104
- )
105
-
106
- embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
107
- if embedder.legacy_ucg_val is not None:
108
- embedder.ucg_prng = np.random.RandomState()
109
-
110
- embedders.append(embedder)
111
- self.embedders = nn.ModuleList(embedders)
112
-
113
- def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
114
- assert embedder.legacy_ucg_val is not None
115
- p = embedder.ucg_rate
116
- val = embedder.legacy_ucg_val
117
- for i in range(len(batch[embedder.input_key])):
118
- if embedder.ucg_prng.choice(2, p=[1 - p, p]):
119
- batch[embedder.input_key][i] = val
120
- return batch
121
-
122
- def forward(
123
- self, batch: Dict, force_zero_embeddings: Optional[List] = None, force_ref_zero_embeddings: bool = False
124
- ) -> Dict:
125
- output = dict()
126
- if force_zero_embeddings is None:
127
- force_zero_embeddings = []
128
- for embedder in self.embedders:
129
- embedding_context = nullcontext if (embedder.is_trainable or embedder.modifier_token is not None) else torch.no_grad
130
- with embedding_context():
131
- if hasattr(embedder, "input_key") and (embedder.input_key is not None):
132
- if embedder.legacy_ucg_val is not None:
133
- batch = self.possibly_get_ucg_val(embedder, batch)
134
- emb_out = embedder(batch[embedder.input_key])
135
- elif hasattr(embedder, "input_keys"):
136
- if force_ref_zero_embeddings:
137
- emb_out = embedder(batch[embedder.input_keys[0]])
138
- else:
139
- emb_out = [embedder(batch[k]) for k in embedder.input_keys]
140
- if isinstance(emb_out[0], tuple):
141
- emb_out = [torch.cat([x[0] for x in emb_out]), torch.cat([x[1] for x in emb_out])]
142
- else:
143
- emb_out = torch.cat(emb_out)
144
-
145
- assert isinstance(
146
- emb_out, (torch.Tensor, list, tuple)
147
- ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
148
- if not isinstance(emb_out, (list, tuple)):
149
- emb_out = [emb_out]
150
- for emb in emb_out:
151
- out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
152
- if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
153
- emb = (
154
- expand_dims_like(
155
- torch.bernoulli(
156
- (1.0 - embedder.ucg_rate)
157
- * torch.ones(emb.shape[0], device=emb.device)
158
- ),
159
- emb,
160
- )
161
- * emb
162
- )
163
- if (
164
- hasattr(embedder, "input_key")
165
- and embedder.input_key in force_zero_embeddings
166
- ):
167
- emb = torch.zeros_like(emb)
168
- if (
169
- hasattr(embedder, "input_keys")
170
- and embedder.input_keys in force_zero_embeddings
171
- ):
172
- emb = torch.zeros_like(emb)
173
- if out_key in output:
174
- if hasattr(embedder, "input_keys"):
175
- catdim = 1 if ('pose' in embedder.input_keys) else self.KEY2CATDIM[out_key]
176
- if not force_ref_zero_embeddings:
177
- c, c1 = emb.chunk(2)
178
- output[out_key] = torch.cat(
179
- (output[out_key], c), catdim
180
- )
181
- output[out_key+'_ref'] = torch.cat(
182
- (output[out_key+'_ref'], c1), catdim
183
- )
184
- else:
185
- # print(output[out_key].size(), emb.size(), "$")
186
- output[out_key] = torch.cat(
187
- (output[out_key], emb), catdim
188
- )
189
- else:
190
- catdim = 1 if ('pose' in embedder.input_key and emb.size(1) != 77) else self.KEY2CATDIM[out_key]
191
- output[out_key] = torch.cat(
192
- (output[out_key], emb), catdim
193
- )
194
- else:
195
- if hasattr(embedder, "input_keys"):
196
- if not force_ref_zero_embeddings:
197
- c, c1 = emb.chunk(2)
198
- output[out_key] = c
199
- output[out_key+'_ref'] = c1
200
- else:
201
- output[out_key] = emb
202
- else:
203
- output[out_key] = emb
204
-
205
- for out_key in self.OUTPUT_DIM2KEYS.values():
206
- if out_key+'_ref' in output and not force_ref_zero_embeddings:
207
- output[out_key] = torch.cat([output[out_key], output[out_key+'_ref']], 0)
208
- del output[out_key+'_ref']
209
-
210
- return output
211
-
212
- def get_unconditional_conditioning(
213
- self,
214
- batch_c: Dict,
215
- batch_uc: Optional[Dict] = None,
216
- force_uc_zero_embeddings: Optional[List[str]] = None,
217
- force_ref_zero_embeddings: Optional[List[str]] = None,
218
- ):
219
- if force_uc_zero_embeddings is None:
220
- force_uc_zero_embeddings = []
221
- ucg_rates = list()
222
- for embedder in self.embedders:
223
- ucg_rates.append(embedder.ucg_rate)
224
- embedder.ucg_rate = 0.0
225
- c = self(batch_c, force_ref_zero_embeddings=force_ref_zero_embeddings)
226
- uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings, force_ref_zero_embeddings)
227
-
228
- for embedder, rate in zip(self.embedders, ucg_rates):
229
- embedder.ucg_rate = rate
230
- return c, uc
231
-
232
-
233
- class InceptionV3(nn.Module):
234
- """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
235
- port with an additional squeeze at the end"""
236
-
237
- def __init__(self, normalize_input=False, **kwargs):
238
- super().__init__()
239
- from pytorch_fid import inception
240
-
241
- kwargs["resize_input"] = True
242
- self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)
243
-
244
- def forward(self, inp):
245
- outp = self.model(inp)
246
-
247
- if len(outp) == 1:
248
- return outp[0].squeeze()
249
-
250
- return outp
251
-
252
-
253
- class IdentityEncoder(AbstractEmbModel):
254
- def encode(self, x):
255
- return x
256
-
257
- def forward(self, x):
258
- return x
259
-
260
-
261
- class ClassEmbedder(AbstractEmbModel):
262
- def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
263
- super().__init__()
264
- self.embedding = nn.Embedding(n_classes, embed_dim)
265
- self.n_classes = n_classes
266
- self.add_sequence_dim = add_sequence_dim
267
-
268
- def forward(self, c):
269
- c = self.embedding(c)
270
- if self.add_sequence_dim:
271
- c = c[:, None, :]
272
- return c
273
-
274
- def get_unconditional_conditioning(self, bs, device="cuda"):
275
- uc_class = (
276
- self.n_classes - 1
277
- ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
278
- uc = torch.ones((bs,), device=device) * uc_class
279
- uc = {self.key: uc.long()}
280
- return uc
281
-
282
-
283
- class ClassEmbedderForMultiCond(ClassEmbedder):
284
- def forward(self, batch, key=None, disable_dropout=False):
285
- out = batch
286
- key = default(key, self.key)
287
- islist = isinstance(batch[key], list)
288
- if islist:
289
- batch[key] = batch[key][0]
290
- c_out = super().forward(batch, key, disable_dropout)
291
- out[key] = [c_out] if islist else c_out
292
- return out
293
-
294
-
295
- class FrozenT5Embedder(AbstractEmbModel):
296
- """Uses the T5 transformer encoder for text"""
297
-
298
- def __init__(
299
- self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
300
- ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
301
- super().__init__()
302
- self.tokenizer = T5Tokenizer.from_pretrained(version)
303
- self.transformer = T5EncoderModel.from_pretrained(version)
304
- self.device = device
305
- self.max_length = max_length
306
- if freeze:
307
- self.freeze()
308
-
309
- def freeze(self):
310
- self.transformer = self.transformer.eval()
311
-
312
- for param in self.parameters():
313
- param.requires_grad = False
314
-
315
- def forward(self, text):
316
- batch_encoding = self.tokenizer(
317
- text,
318
- truncation=True,
319
- max_length=self.max_length,
320
- return_length=True,
321
- return_overflowing_tokens=False,
322
- padding="max_length",
323
- return_tensors="pt",
324
- )
325
- tokens = batch_encoding["input_ids"].to(self.device)
326
- with torch.autocast("cuda", enabled=False):
327
- outputs = self.transformer(input_ids=tokens)
328
- z = outputs.last_hidden_state
329
- return z
330
-
331
- def encode(self, text):
332
- return self(text)
333
-
334
-
335
- class FrozenByT5Embedder(AbstractEmbModel):
336
- """
337
- Uses the ByT5 transformer encoder for text. Is character-aware.
338
- """
339
-
340
- def __init__(
341
- self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
342
- ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
343
- super().__init__()
344
- self.tokenizer = ByT5Tokenizer.from_pretrained(version)
345
- self.transformer = T5EncoderModel.from_pretrained(version)
346
- self.device = device
347
- self.max_length = max_length
348
- if freeze:
349
- self.freeze()
350
-
351
- def freeze(self):
352
- self.transformer = self.transformer.eval()
353
-
354
- for param in self.parameters():
355
- param.requires_grad = False
356
-
357
- def forward(self, text):
358
- batch_encoding = self.tokenizer(
359
- text,
360
- truncation=True,
361
- max_length=self.max_length,
362
- return_length=True,
363
- return_overflowing_tokens=False,
364
- padding="max_length",
365
- return_tensors="pt",
366
- )
367
- tokens = batch_encoding["input_ids"].to(self.device)
368
- with torch.autocast("cuda", enabled=False):
369
- outputs = self.transformer(input_ids=tokens)
370
- z = outputs.last_hidden_state
371
- return z
372
-
373
- def encode(self, text):
374
- return self(text)
375
-
376
-
377
- class FrozenCLIPEmbedder(AbstractEmbModel):
378
- """Uses the CLIP transformer encoder for text (from huggingface)"""
379
-
380
- LAYERS = ["last", "pooled", "hidden"]
381
-
382
- def __init__(
383
- self,
384
- modifier_token=None,
385
- version="openai/clip-vit-large-patch14",
386
- device="cuda",
387
- max_length=77,
388
- freeze=True,
389
- layer="last",
390
- layer_idx=None,
391
- always_return_pooled=False,
392
- ): # clip-vit-base-patch32
393
- super().__init__()
394
- assert layer in self.LAYERS
395
- self.tokenizer = CLIPTokenizer.from_pretrained(version)
396
- self.transformer = CLIPTextModel.from_pretrained(version)
397
- self.device = device
398
- self.max_length = max_length
399
- self.modifier_token = modifier_token
400
- if self.modifier_token is not None:
401
- if '+' in self.modifier_token:
402
- self.modifier_token = self.modifier_token.split('+')
403
- else:
404
- self.modifier_token = [self.modifier_token]
405
-
406
- self.add_token()
407
-
408
- if freeze:
409
- self.freeze()
410
- self.layer = layer
411
- self.layer_idx = layer_idx
412
- self.return_pooled = always_return_pooled
413
- if layer == "hidden":
414
- assert layer_idx is not None
415
- assert 0 <= abs(layer_idx) <= 12
416
-
417
- def add_token(self):
418
- self.modifier_token_id = []
419
- for each_modifier_token in self.modifier_token:
420
- print(each_modifier_token, "adding new token")
421
- _ = self.tokenizer.add_tokens(each_modifier_token)
422
- modifier_token_id = self.tokenizer.convert_tokens_to_ids(each_modifier_token)
423
- self.modifier_token_id.append(modifier_token_id)
424
-
425
- self.transformer.resize_token_embeddings(len(self.tokenizer))
426
- token_embeds = self.transformer.get_input_embeddings().weight.data
427
- token_embeds[self.modifier_token_id[-1]] = torch.nn.Parameter(token_embeds[42170], requires_grad=True)
428
- if len(self.modifier_token) == 2:
429
- token_embeds[self.modifier_token_id[-2]] = torch.nn.Parameter(token_embeds[47629], requires_grad=True)
430
- if len(self.modifier_token) == 3:
431
- token_embeds[self.modifier_token_id[-3]] = torch.nn.Parameter(token_embeds[43514], requires_grad=True)
432
-
433
- def freeze(self):
434
- if self.modifier_token is not None:
435
- self.transformer = self.transformer.eval()
436
- for param in self.transformer.text_model.encoder.parameters():
437
- param.requires_grad = False
438
- for param in self.transformer.text_model.final_layer_norm.parameters():
439
- param.requires_grad = False
440
- for param in self.transformer.text_model.embeddings.parameters():
441
- param.requires_grad = False
442
- for param in self.transformer.get_input_embeddings().parameters():
443
- param.requires_grad = True
444
- print("making grad true")
445
- else:
446
- self.transformer = self.transformer.eval()
447
-
448
- for param in self.parameters():
449
- param.requires_grad = False
450
-
451
- def _build_causal_attention_mask(self, bsz, seq_len, dtype):
452
- # lazily create causal attention mask, with full attention between the vision tokens
453
- # pytorch uses additive attention mask; fill with -inf
454
- mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
455
- mask.fill_(torch.tensor(torch.finfo(dtype).min))
456
- mask.triu_(1) # zero out the lower diagonal
457
- mask = mask.unsqueeze(1) # expand mask
458
- return mask
459
-
460
- @autocast
461
- def custom_forward(self, hidden_states, input_ids):
462
- r"""
463
- Returns:
464
- """
465
- input_shape = hidden_states.size()
466
- bsz, seq_len = input_shape[:2]
467
- if version.parse(transformers.__version__) >= version.parse('4.21'):
468
- causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
469
- hidden_states.device
470
- )
471
- else:
472
- causal_attention_mask = self.transformer.text_model._build_causal_attention_mask(bsz, seq_len).to(
473
- hidden_states.device
474
- )
475
-
476
- encoder_outputs = self.transformer.text_model.encoder(
477
- inputs_embeds=hidden_states,
478
- causal_attention_mask=causal_attention_mask,
479
- )
480
-
481
- last_hidden_state = encoder_outputs[0]
482
- last_hidden_state = self.transformer.text_model.final_layer_norm(last_hidden_state)
483
-
484
- return last_hidden_state
485
-
486
- @autocast
487
- def forward(self, text):
488
- batch_encoding = self.tokenizer(
489
- text,
490
- truncation=True,
491
- max_length=self.max_length,
492
- return_length=True,
493
- return_overflowing_tokens=False,
494
- padding="max_length",
495
- return_tensors="pt"
496
- )
497
- tokens = batch_encoding["input_ids"].to(self.device)
498
-
499
- if self.modifier_token is not None:
500
- indices = tokens == self.modifier_token_id[-1]
501
- for token_id in self.modifier_token_id:
502
- indices |= tokens == token_id
503
-
504
- indices = (indices*1).unsqueeze(-1)
505
-
506
- input_shape = tokens.size()
507
- tokens = tokens.view(-1, input_shape[-1])
508
-
509
- hidden_states = self.transformer.text_model.embeddings(input_ids=tokens)
510
- if self.modifier_token is not None:
511
- hidden_states = (1-indices)*hidden_states.detach() + indices*hidden_states
512
- z = self.custom_forward(hidden_states, tokens)
513
- return z
514
-
515
- def encode(self, text):
516
- return self(text)
517
-
518
-
519
- class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
520
- """
521
- Uses the OpenCLIP transformer encoder for text
522
- """
523
-
524
- LAYERS = ["pooled", "last", "penultimate"]
525
-
526
- def __init__(
527
- self,
528
- arch="ViT-H-14",
529
- version="laion2b_s32b_b79k",
530
- device="cuda",
531
- max_length=77,
532
- freeze=True,
533
- layer="last",
534
- always_return_pooled=False,
535
- legacy=True,
536
- ):
537
- super().__init__()
538
- assert layer in self.LAYERS
539
- model, _, _ = open_clip.create_model_and_transforms(
540
- arch,
541
- device=torch.device("cpu"),
542
- pretrained=version,
543
- )
544
- del model.visual
545
- self.model = model
546
- self.modifier_token = None
547
-
548
- self.device = device
549
- self.max_length = max_length
550
- self.return_pooled = always_return_pooled
551
- if freeze:
552
- self.freeze()
553
- self.layer = layer
554
- if self.layer == "last":
555
- self.layer_idx = 0
556
- elif self.layer == "penultimate":
557
- self.layer_idx = 1
558
- else:
559
- raise NotImplementedError()
560
- self.legacy = legacy
561
-
562
- def freeze(self):
563
- self.model = self.model.eval()
564
- for param in self.parameters():
565
- param.requires_grad = False
566
-
567
- @autocast
568
- def forward(self, text):
569
- tokens = open_clip.tokenize(text)
570
- z = self.encode_with_transformer(tokens.to(self.device))
571
- if not self.return_pooled and self.legacy:
572
- return z
573
- if self.return_pooled:
574
- assert not self.legacy
575
- return z[self.layer], z["pooled"]
576
- return z[self.layer]
577
-
578
- def encode_with_transformer(self, text):
579
- x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
580
- x = x + self.model.positional_embedding
581
- x = x.permute(1, 0, 2) # NLD -> LND
582
- x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
583
- if self.legacy:
584
- x = x[self.layer]
585
- x = self.model.ln_final(x)
586
- return x
587
- else:
588
- # x is a dict and will stay a dict
589
- o = x["last"]
590
- o = self.model.ln_final(o)
591
- pooled = self.pool(o, text)
592
- x["pooled"] = pooled
593
- return x
594
-
595
- def pool(self, x, text):
596
- # take features from the eot embedding (eot_token is the highest number in each sequence)
597
- x = (
598
- x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
599
- @ self.model.text_projection
600
- )
601
- return x
602
-
603
- def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
604
- outputs = {}
605
- for i, r in enumerate(self.model.transformer.resblocks):
606
- if i == len(self.model.transformer.resblocks) - 1:
607
- outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD
608
- if (
609
- self.model.transformer.grad_checkpointing
610
- and not torch.jit.is_scripting()
611
- ):
612
- x = checkpoint(r, x, attn_mask)
613
- else:
614
- x = r(x, attn_mask=attn_mask)
615
- outputs["last"] = x.permute(1, 0, 2) # LND -> NLD
616
- return outputs
617
-
618
- def encode(self, text):
619
- return self(text)
620
-
621
-
622
- class FrozenOpenCLIPEmbedder(AbstractEmbModel):
623
- LAYERS = [
624
- # "pooled",
625
- "last",
626
- "penultimate",
627
- ]
628
-
629
- def __init__(
630
- self,
631
- modifier_token=None,
632
- arch="ViT-H-14",
633
- version="laion2b_s32b_b79k",
634
- device="cuda",
635
- max_length=77,
636
- freeze=True,
637
- layer="last",
638
- always_return_pooled=False,
639
- legacy=True,
640
- ):
641
- super().__init__()
642
- assert layer in self.LAYERS
643
- model, _, _ = open_clip.create_model_and_transforms(
644
- arch, device=torch.device("cpu"), pretrained=version
645
- )
646
- del model.visual
647
- self.model = model
648
-
649
- self.device = device
650
- self.max_length = max_length
651
- self.modifier_token = modifier_token
652
- self.return_pooled = always_return_pooled
653
- if self.modifier_token is not None:
654
- if '+' in self.modifier_token:
655
- self.modifier_token = self.modifier_token.split('+')
656
- else:
657
- self.modifier_token = [self.modifier_token]
658
- self.tokenizer = SimpleTokenizer(additional_special_tokens=self.modifier_token)
659
-
660
- self.add_token()
661
- else:
662
- self.tokenizer = SimpleTokenizer()
663
-
664
- if freeze:
665
- self.freeze()
666
- self.layer = layer
667
- if self.layer == "last":
668
- self.layer_idx = 0
669
- elif self.layer == "penultimate":
670
- self.layer_idx = 1
671
- else:
672
- raise NotImplementedError()
673
- self.legacy = legacy
674
-
675
- def tokenize(self, texts, context_length=77):
676
- return self.tokenizer(texts, context_length=context_length)
677
-
678
- def add_token(self):
679
- self.modifier_token_id = []
680
-
681
- token_embeds1 = self.model.token_embedding.weight.data
682
- for each_modifier_token in self.modifier_token:
683
- modifier_token_id = self.tokenizer.encoder[each_modifier_token]
684
- self.modifier_token_id.append(modifier_token_id)
685
-
686
- self.model.token_embedding = nn.Embedding(token_embeds1.shape[0] + len(self.modifier_token), token_embeds1.shape[1])
687
- self.model.token_embedding.weight.data[:token_embeds1.shape[0]] = token_embeds1
688
-
689
- self.model.token_embedding.weight.data[self.modifier_token_id[-1]] = token_embeds1[42170]
690
- if len(self.modifier_token) == 2:
691
- self.model.token_embedding.weight.data[self.modifier_token_id[-2]] = token_embeds1[47629]
692
-
693
- def freeze(self):
694
- if self.modifier_token is not None:
695
- self.model = self.model.eval()
696
- for param in self.model.transformer.parameters():
697
- param.requires_grad = False
698
- for param in self.model.ln_final.parameters():
699
- param.requires_grad = False
700
- self.model.text_projection.requires_grad = False
701
- self.model.positional_embedding.requires_grad = False
702
- for param in self.model.token_embedding.parameters():
703
- param.requires_grad = True
704
- print("making grad true")
705
- else:
706
- self.model = self.model.eval()
707
- for param in self.parameters():
708
- param.requires_grad = False
709
-
710
- @autocast
711
- def forward(self, text):
712
- tokens = self.tokenize(text)
713
- z = self.encode_with_transformer(tokens.to(self.device))
714
- if not self.return_pooled and self.legacy:
715
- return z
716
- if self.return_pooled:
717
- assert not self.legacy
718
- return z[self.layer], z["pooled"]
719
- return z[self.layer]
720
-
721
- def encode_with_transformer(self, text):
722
- x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
723
-
724
- if self.modifier_token is not None:
725
- indices = text == self.modifier_token_id[-1]
726
- for token_id in self.modifier_token_id:
727
- indices |= text == token_id
728
-
729
- indices = (indices*1).unsqueeze(-1)
730
- x = ((1-indices)*x.detach() + indices*x) + self.model.positional_embedding
731
- else:
732
- x = x + self.model.positional_embedding
733
- x = x.permute(1, 0, 2) # NLD -> LND
734
- x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
735
- if self.legacy:
736
- x = x[self.layer]
737
- x = self.model.ln_final(x)
738
- return x
739
- else:
740
- # x is a dict and will stay a dict
741
- o = x["last"]
742
- o = self.model.ln_final(o)
743
- pooled = self.pool(o, text)
744
- x["pooled"] = pooled
745
- return x
746
-
747
- def pool(self, x, text):
748
- # take features from the eot embedding (eot_token is the highest number in each sequence)
749
- x = (
750
- x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
751
- @ self.model.text_projection
752
- )
753
- return x
754
-
755
- def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
756
- outputs = {}
757
- for i, r in enumerate(self.model.transformer.resblocks):
758
- if i == len(self.model.transformer.resblocks) - 1:
759
- outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD
760
- if (
761
- self.model.transformer.grad_checkpointing
762
- and not torch.jit.is_scripting()
763
- ):
764
- x = checkpoint(r, x, attn_mask)
765
- else:
766
- x = r(x, attn_mask=attn_mask)
767
- outputs["last"] = x.permute(1, 0, 2) # LND -> NLD
768
- return outputs
769
-
770
- def encode(self, text):
771
- return self(text)
772
-
773
-
774
- class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
775
- """
776
- Uses the OpenCLIP vision transformer encoder for images
777
- """
778
-
779
- def __init__(
780
- self,
781
- arch="ViT-H-14",
782
- version="laion2b_s32b_b79k",
783
- device="cuda",
784
- max_length=77,
785
- freeze=True,
786
- antialias=True,
787
- ucg_rate=0.0,
788
- unsqueeze_dim=False,
789
- repeat_to_max_len=False,
790
- num_image_crops=0,
791
- output_tokens=False,
792
- init_device=None,
793
- ):
794
- super().__init__()
795
- model, _, _ = open_clip.create_model_and_transforms(
796
- arch,
797
- device=torch.device(default(init_device, "cpu")),
798
- pretrained=version,
799
- )
800
- del model.transformer
801
- self.model = model
802
- self.max_crops = num_image_crops
803
- self.pad_to_max_len = self.max_crops > 0
804
- self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
805
- self.device = device
806
- self.max_length = max_length
807
- if freeze:
808
- self.freeze()
809
-
810
- self.antialias = antialias
811
-
812
- self.register_buffer(
813
- "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
814
- )
815
- self.register_buffer(
816
- "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
817
- )
818
- self.ucg_rate = ucg_rate
819
- self.unsqueeze_dim = unsqueeze_dim
820
- self.stored_batch = None
821
- self.model.visual.output_tokens = output_tokens
822
- self.output_tokens = output_tokens
823
-
824
- def preprocess(self, x):
825
- # normalize to [0,1]
826
- x = kornia.geometry.resize(
827
- x,
828
- (224, 224),
829
- interpolation="bicubic",
830
- align_corners=True,
831
- antialias=self.antialias,
832
- )
833
- x = (x + 1.0) / 2.0
834
- # renormalize according to clip
835
- x = kornia.enhance.normalize(x, self.mean, self.std)
836
- return x
837
-
838
- def freeze(self):
839
- self.model = self.model.eval()
840
- for param in self.parameters():
841
- param.requires_grad = False
842
-
843
- @autocast
844
- def forward(self, image, no_dropout=False):
845
- z = self.encode_with_vision_transformer(image)
846
- tokens = None
847
- if self.output_tokens:
848
- z, tokens = z[0], z[1]
849
- z = z.to(image.dtype)
850
- if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
851
- z = (
852
- torch.bernoulli(
853
- (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
854
- )[:, None]
855
- * z
856
- )
857
- if tokens is not None:
858
- tokens = (
859
- expand_dims_like(
860
- torch.bernoulli(
861
- (1.0 - self.ucg_rate)
862
- * torch.ones(tokens.shape[0], device=tokens.device)
863
- ),
864
- tokens,
865
- )
866
- * tokens
867
- )
868
- if self.unsqueeze_dim:
869
- z = z[:, None, :]
870
- if self.output_tokens:
871
- assert not self.repeat_to_max_len
872
- assert not self.pad_to_max_len
873
- return tokens, z
874
- if self.repeat_to_max_len:
875
- if z.dim() == 2:
876
- z_ = z[:, None, :]
877
- else:
878
- z_ = z
879
- return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
880
- elif self.pad_to_max_len:
881
- assert z.dim() == 3
882
- z_pad = torch.cat(
883
- (
884
- z,
885
- torch.zeros(
886
- z.shape[0],
887
- self.max_length - z.shape[1],
888
- z.shape[2],
889
- device=z.device,
890
- ),
891
- ),
892
- 1,
893
- )
894
- return z_pad, z_pad[:, 0, ...]
895
- return z
896
-
897
- def encode_with_vision_transformer(self, img):
898
- # if self.max_crops > 0:
899
- # img = self.preprocess_by_cropping(img)
900
- if img.dim() == 5:
901
- assert self.max_crops == img.shape[1]
902
- img = rearrange(img, "b n c h w -> (b n) c h w")
903
- img = self.preprocess(img)
904
- if not self.output_tokens:
905
- assert not self.model.visual.output_tokens
906
- x = self.model.visual(img)
907
- tokens = None
908
- else:
909
- assert self.model.visual.output_tokens
910
- x, tokens = self.model.visual(img)
911
- if self.max_crops > 0:
912
- x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
913
- # drop out between 0 and all along the sequence axis
914
- x = (
915
- torch.bernoulli(
916
- (1.0 - self.ucg_rate)
917
- * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
918
- )
919
- * x
920
- )
921
- if tokens is not None:
922
- tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
923
- print(
924
- f"You are running very experimental token-concat in {self.__class__.__name__}. "
925
- f"Check what you are doing, and then remove this message."
926
- )
927
- if self.output_tokens:
928
- return x, tokens
929
- return x
930
-
931
- def encode(self, text):
932
- return self(text)
933
-
934
-
935
- class FrozenCLIPT5Encoder(AbstractEmbModel):
936
- def __init__(
937
- self,
938
- clip_version="openai/clip-vit-large-patch14",
939
- t5_version="google/t5-v1_1-xl",
940
- device="cuda",
941
- clip_max_length=77,
942
- t5_max_length=77,
943
- ):
944
- super().__init__()
945
- self.clip_encoder = FrozenCLIPEmbedder(
946
- clip_version, device, max_length=clip_max_length
947
- )
948
- self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
949
- print(
950
- f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
951
- f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
952
- )
953
-
954
- def encode(self, text):
955
- return self(text)
956
-
957
- def forward(self, text):
958
- clip_z = self.clip_encoder.encode(text)
959
- t5_z = self.t5_encoder.encode(text)
960
- return [clip_z, t5_z]
961
-
962
-
963
- class SpatialRescaler(nn.Module):
964
- def __init__(
965
- self,
966
- n_stages=1,
967
- method="bilinear",
968
- multiplier=0.5,
969
- in_channels=3,
970
- out_channels=None,
971
- bias=False,
972
- wrap_video=False,
973
- kernel_size=1,
974
- remap_output=False,
975
- ):
976
- super().__init__()
977
- self.n_stages = n_stages
978
- assert self.n_stages >= 0
979
- assert method in [
980
- "nearest",
981
- "linear",
982
- "bilinear",
983
- "trilinear",
984
- "bicubic",
985
- "area",
986
- ]
987
- self.multiplier = multiplier
988
- self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
989
- self.remap_output = out_channels is not None or remap_output
990
- if self.remap_output:
991
- print(
992
- f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
993
- )
994
- self.channel_mapper = nn.Conv2d(
995
- in_channels,
996
- out_channels,
997
- kernel_size=kernel_size,
998
- bias=bias,
999
- padding=kernel_size // 2,
1000
- )
1001
- self.wrap_video = wrap_video
1002
-
1003
- def forward(self, x):
1004
- if self.wrap_video and x.ndim == 5:
1005
- B, C, T, H, W = x.shape
1006
- x = rearrange(x, "b c t h w -> b t c h w")
1007
- x = rearrange(x, "b t c h w -> (b t) c h w")
1008
-
1009
- for stage in range(self.n_stages):
1010
- x = self.interpolator(x, scale_factor=self.multiplier)
1011
-
1012
- if self.wrap_video:
1013
- x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C)
1014
- x = rearrange(x, "b t c h w -> b c t h w")
1015
- if self.remap_output:
1016
- x = self.channel_mapper(x)
1017
- return x
1018
-
1019
- def encode(self, x):
1020
- return self(x)
1021
-
1022
-
1023
- class LowScaleEncoder(nn.Module):
1024
- def __init__(
1025
- self,
1026
- model_config,
1027
- linear_start,
1028
- linear_end,
1029
- timesteps=1000,
1030
- max_noise_level=250,
1031
- output_size=64,
1032
- scale_factor=1.0,
1033
- ):
1034
- super().__init__()
1035
- self.max_noise_level = max_noise_level
1036
- self.model = instantiate_from_config(model_config)
1037
- self.augmentation_schedule = self.register_schedule(
1038
- timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
1039
- )
1040
- self.out_size = output_size
1041
- self.scale_factor = scale_factor
1042
-
1043
- def register_schedule(
1044
- self,
1045
- beta_schedule="linear",
1046
- timesteps=1000,
1047
- linear_start=1e-4,
1048
- linear_end=2e-2,
1049
- cosine_s=8e-3,
1050
- ):
1051
- betas = make_beta_schedule(
1052
- beta_schedule,
1053
- timesteps,
1054
- linear_start=linear_start,
1055
- linear_end=linear_end,
1056
- cosine_s=cosine_s,
1057
- )
1058
- alphas = 1.0 - betas
1059
- alphas_cumprod = np.cumprod(alphas, axis=0)
1060
- alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
1061
-
1062
- (timesteps,) = betas.shape
1063
- self.num_timesteps = int(timesteps)
1064
- self.linear_start = linear_start
1065
- self.linear_end = linear_end
1066
- assert (
1067
- alphas_cumprod.shape[0] == self.num_timesteps
1068
- ), "alphas have to be defined for each timestep"
1069
-
1070
- to_torch = partial(torch.tensor, dtype=torch.float32)
1071
-
1072
- self.register_buffer("betas", to_torch(betas))
1073
- self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
1074
- self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
1075
-
1076
- # calculations for diffusion q(x_t | x_{t-1}) and others
1077
- self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
1078
- self.register_buffer(
1079
- "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
1080
- )
1081
- self.register_buffer(
1082
- "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
1083
- )
1084
- self.register_buffer(
1085
- "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
1086
- )
1087
- self.register_buffer(
1088
- "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
1089
- )
1090
-
1091
- def q_sample(self, x_start, t, noise=None):
1092
- noise = default(noise, lambda: torch.randn_like(x_start))
1093
- return (
1094
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
1095
- + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
1096
- * noise
1097
- )
1098
-
1099
- def forward(self, x):
1100
- z = self.model.encode(x)
1101
- if isinstance(z, DiagonalGaussianDistribution):
1102
- z = z.sample()
1103
- z = z * self.scale_factor
1104
- noise_level = torch.randint(
1105
- 0, self.max_noise_level, (x.shape[0],), device=x.device
1106
- ).long()
1107
- z = self.q_sample(z, noise_level)
1108
- if self.out_size is not None:
1109
- z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
1110
- return z, noise_level
1111
-
1112
- def decode(self, z):
1113
- z = z / self.scale_factor
1114
- return self.model.decode(z)
1115
-
1116
-
1117
- class ConcatTimestepEmbedderND(AbstractEmbModel):
1118
- """embeds each dimension independently and concatenates them"""
1119
-
1120
- def __init__(self, outdim):
1121
- super().__init__()
1122
- self.timestep = Timestep(outdim)
1123
- self.outdim = outdim
1124
- self.modifier_token = None
1125
-
1126
- def forward(self, x):
1127
- if x.ndim == 1:
1128
- x = x[:, None]
1129
- assert len(x.shape) == 2
1130
- b, dims = x.shape[0], x.shape[1]
1131
- x = rearrange(x, "b d -> (b d)")
1132
- emb = self.timestep(x)
1133
- emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
1134
- return emb
1135
-
1136
-
1137
- class GaussianEncoder(Encoder, AbstractEmbModel):
1138
- def __init__(
1139
- self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
1140
- ):
1141
- super().__init__(*args, **kwargs)
1142
- self.posterior = DiagonalGaussianRegularizer()
1143
- self.weight = weight
1144
- self.flatten_output = flatten_output
1145
-
1146
- def forward(self, x) -> Tuple[Dict, torch.Tensor]:
1147
- z = super().forward(x)
1148
- z, log = self.posterior(z)
1149
- log["loss"] = log["kl_loss"]
1150
- log["weight"] = self.weight
1151
- if self.flatten_output:
1152
- z = rearrange(z, "b c h w -> b (h w ) c")
1153
- return log, z
1154
-
sgm/modules/nerfsd_pytorch3d.py DELETED
@@ -1,468 +0,0 @@
1
- import math
2
- import sys
3
- import itertools
4
-
5
- import numpy as np
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- import torch
9
- from einops import rearrange
10
- from ..modules.utils_cameraray import (
11
- get_patch_rays,
12
- get_plucker_parameterization,
13
- positional_encoding,
14
- convert_to_view_space,
15
- convert_to_view_space_points,
16
- convert_to_target_space,
17
- )
18
-
19
-
20
- from pytorch3d.renderer import ray_bundle_to_ray_points
21
- from pytorch3d.renderer.implicit.raysampling import RayBundle as RayBundle
22
- from pytorch3d import _C
23
-
24
- from ..modules.diffusionmodules.util import zero_module
25
-
26
-
27
- class FeatureNeRFEncoding(nn.Module):
28
- def __init__(
29
- self,
30
- in_channels,
31
- out_channels,
32
- far_plane: float = 2.0,
33
- rgb_predict=False,
34
- average=False,
35
- num_freqs=16,
36
- ) -> None:
37
- super().__init__()
38
-
39
- self.far_plane = far_plane
40
- self.rgb_predict = rgb_predict
41
- self.average = average
42
- self.num_freqs = num_freqs
43
- dim = 3
44
- self.plane_coefs = nn.Sequential(
45
- nn.Linear(in_channels + self.num_freqs * dim * 4 + 2 * dim, out_channels),
46
- nn.SiLU(),
47
- nn.Linear(out_channels, out_channels),
48
- )
49
- if not self.average:
50
- self.nviews = nn.Linear(
51
- in_channels + self.num_freqs * dim * 4 + 2 * dim, 1
52
- )
53
- self.decoder = zero_module(
54
- nn.Linear(out_channels, 1 + (3 if rgb_predict else 0), bias=False)
55
- )
56
-
57
- def forward(self, pose, xref, ray_points, rays, mask_ref):
58
- # xref : [b, n, hw, c]
59
- # ray_points: [b, n+1, hw, d, 3]
60
- # rays: [b, n+1, hw, 6]
61
-
62
- b, n, hw, c = xref.shape
63
- d = ray_points.shape[3]
64
- res = int(math.sqrt(hw))
65
- if mask_ref is not None:
66
- mask_ref = torch.nn.functional.interpolate(
67
- rearrange(
68
- mask_ref,
69
- "b n ... -> (b n) ...",
70
- ),
71
- size=[res, res],
72
- mode="nearest",
73
- ).reshape(b, n, -1, 1)
74
- xref = xref * mask_ref
75
-
76
- volume = []
77
- for i, cam in enumerate(pose):
78
- volume.append(
79
- cam.transform_points_ndc(ray_points[i, 0].reshape(-1, 3)).reshape(n + 1, hw, d, 3)[..., :2]
80
- )
81
- volume = torch.stack(volume)
82
-
83
- plane_features = F.grid_sample(
84
- rearrange(
85
- xref,
86
- "b n (h w) c -> (b n) c h w",
87
- b=b,
88
- h=int(math.sqrt(hw)),
89
- w=int(math.sqrt(hw)),
90
- c=c,
91
- n=n,
92
- ),
93
- torch.clip(
94
- torch.nan_to_num(
95
- rearrange(-1 * volume[:, 1:].detach(), "b n ... -> (b n) ...")
96
- ),
97
- -1.2,
98
- 1.2,
99
- ),
100
- align_corners=True,
101
- padding_mode="zeros",
102
- ) # [bn, c, hw, d]
103
-
104
- plane_features = rearrange(plane_features, "(b n) ... -> b n ...", b=b, n=n)
105
-
106
- xyz_grid_features_inviewframe = convert_to_view_space_points(pose, ray_points[:, 0])
107
- xyz_grid_features_inviewframe_encoding = positional_encoding(xyz_grid_features_inviewframe, self.num_freqs)
108
- camera_features_inviewframe = (
109
- convert_to_view_space(pose, rays[:, 0])[:, 1:]
110
- .reshape(b, n, hw, 1, -1)
111
- .expand(-1, -1, -1, d, -1)
112
- )
113
- camera_features_inviewframe_encoding = positional_encoding(
114
- get_plucker_parameterization(camera_features_inviewframe),
115
- self.num_freqs // 2,
116
- )
117
- xyz_grid_features = xyz_grid_features_inviewframe_encoding[:, :1].expand(
118
- -1, n, -1, -1, -1
119
- )
120
- camera_features = (
121
- (convert_to_target_space(pose, rays[:, 1:])[..., :3])
122
- .reshape(b, n, hw, 1, -1)
123
- .expand(-1, -1, -1, d, -1)
124
- )
125
- camera_features_encoding = positional_encoding(
126
- camera_features, self.num_freqs
127
- )
128
- plane_features_final = self.plane_coefs(
129
- torch.cat(
130
- [
131
- plane_features.permute(0, 1, 3, 4, 2),
132
- xyz_grid_features_inviewframe_encoding[:, 1:],
133
- xyz_grid_features_inviewframe[:, 1:],
134
- camera_features_inviewframe_encoding,
135
- camera_features_inviewframe[..., 3:],
136
- ],
137
- dim=-1,
138
- )
139
- ) # b, n, hw, d, c
140
-
141
- # plane_features = torch.cat([plane_features, xyz_grid_features, camera_features], dim=1)
142
- if not self.average:
143
- plane_features_attn = nn.functional.softmax(
144
- self.nviews(
145
- torch.cat(
146
- [
147
- plane_features.permute(0, 1, 3, 4, 2),
148
- xyz_grid_features,
149
- xyz_grid_features_inviewframe[:, :1].expand(-1, n, -1, -1, -1),
150
- camera_features,
151
- camera_features_encoding,
152
- ],
153
- dim=-1,
154
- )
155
- ),
156
- dim=1,
157
- ) # b, n, hw, d, 1
158
-
159
- plane_features_final = (plane_features_final * plane_features_attn).sum(1)
160
- else:
161
- plane_features_final = plane_features_final.mean(1)
162
- plane_features_attn = None
163
-
164
- out = self.decoder(plane_features_final)
165
- return torch.cat([plane_features_final, out], dim=-1), plane_features_attn
166
-
167
-
168
- class VolRender(nn.Module):
169
- def __init__(
170
- self,
171
- ):
172
- super().__init__()
173
-
174
- def get_weights(self, densities, deltas):
175
- """Return weights based on predicted densities
176
-
177
- Args:
178
- densities: Predicted densities for samples along ray
179
-
180
- Returns:
181
- Weights for each sample
182
- """
183
- delta_density = deltas * densities # [b, hw, "num_samples", 1]
184
- alphas = 1 - torch.exp(-delta_density)
185
- transmittance = torch.cumsum(delta_density[..., :-1, :], dim=-2)
186
- transmittance = torch.cat(
187
- [
188
- torch.zeros((*transmittance.shape[:2], 1, 1), device=densities.device),
189
- transmittance,
190
- ],
191
- dim=-2,
192
- )
193
- transmittance = torch.exp(-transmittance) # [b, hw, "num_samples", 1]
194
-
195
- weights = alphas * transmittance # [b, hw, "num_samples", 1]
196
- weights = torch.nan_to_num(weights)
197
- # opacities = 1.0 - torch.prod(1.0 - alphas, dim=-2, keepdim=True)
198
- return weights, alphas, transmittance
199
-
200
- def forward(
201
- self,
202
- features,
203
- densities,
204
- dists=None,
205
- return_weight=False,
206
- densities_uniform=None,
207
- dists_uniform=None,
208
- return_weights_uniform=False,
209
- rgb=None
210
- ):
211
- alphas = None
212
- fg_mask = None
213
- if dists is not None:
214
- weights, alphas, transmittance = self.get_weights(densities, dists)
215
- fg_mask = torch.sum(weights, -2)
216
- else:
217
- weights = densities # used when we have a pretraind nerf with direct weights as output
218
-
219
- rendered_feats = torch.sum(weights * features, dim=-2) + torch.sum(
220
- (1 - weights) * torch.zeros_like(features), dim=-2
221
- )
222
- if rgb is not None:
223
- rgb = torch.sum(weights * rgb, dim=-2) + torch.sum(
224
- (1 - weights) * torch.zeros_like(rgb), dim=-2
225
- )
226
- # print("RENDER", fg_mask.shape, weights.shape)
227
- weights_uniform = None
228
- if return_weight:
229
- return rendered_feats, fg_mask, alphas, weights, rgb
230
- elif return_weights_uniform:
231
- if densities_uniform is not None:
232
- weights_uniform, _, transmittance = self.get_weights(densities_uniform, dists_uniform)
233
- return rendered_feats, fg_mask, alphas, weights_uniform, rgb
234
- else:
235
- return rendered_feats, fg_mask, alphas, None, rgb
236
-
237
-
238
- class Raymarcher(nn.Module):
239
- def __init__(
240
- self,
241
- num_samples=32,
242
- far_plane=2.0,
243
- stratified=False,
244
- training=True,
245
- imp_sampling_percent=0.9,
246
- near_plane=0.,
247
- ):
248
- super().__init__()
249
- self.num_samples = num_samples
250
- self.far_plane = far_plane
251
- self.near_plane = near_plane
252
- u_max = 1. / (self.num_samples)
253
- u = torch.linspace(0, 1 - u_max, self.num_samples, device="cuda")
254
- self.register_buffer("u", u)
255
- lengths = torch.linspace(self.near_plane, self.near_plane+self.far_plane, self.num_samples+1, device="cuda")
256
- # u = (u[..., :-1] + u[..., 1:]) / 2.0
257
- lengths_center = (lengths[..., 1:] + lengths[..., :-1]) / 2.0
258
- lengths_upper = torch.cat([lengths_center, lengths[..., -1:]], -1)
259
- lengths_lower = torch.cat([lengths[..., :1], lengths_center], -1)
260
- self.register_buffer("lengths", lengths)
261
- self.register_buffer("lengths_center", lengths_center)
262
- self.register_buffer("lengths_upper", lengths_upper)
263
- self.register_buffer("lengths_lower", lengths_lower)
264
- self.stratified = stratified
265
- self.training = training
266
- self.imp_sampling_percent = imp_sampling_percent
267
-
268
-     @torch.no_grad()
-     def importance_sampling(self, cdf, num_rays, num_samples, device):
-         # sample target rays for each reference view
-         cdf = cdf[..., 0] + 0.01
-         if cdf.shape[1] != num_rays:
-             size = int(math.sqrt(num_rays))
-             size_ = int(math.sqrt(cdf.size(1)))
-             cdf = rearrange(
-                 torch.nn.functional.interpolate(
-                     rearrange(
-                         cdf.permute(0, 2, 1), "... (h w) -> ... h w", h=size_, w=size_
-                     ),
-                     size=[size, size],
-                     antialias=True,
-                     mode="bilinear",
-                 ),
-                 "... h w -> ... (h w)",
-                 h=size,
-                 w=size,
-             ).permute(0, 2, 1)
-
-         lengths = self.lengths[None, None, :].expand(cdf.shape[0], num_rays, -1)
-
-         cdf_sum = torch.sum(cdf, dim=-1, keepdim=True)
-         padding = torch.relu(1e-5 - cdf_sum)
-         cdf = cdf + padding / cdf.shape[-1]
-         cdf_sum += padding
-
-         pdf = cdf / cdf_sum
-
-         # sample_pdf function
-         u_max = 1. / (num_samples)
-         u = self.u[None, None, :].expand(cdf.shape[0], num_rays, -1)
-         if self.stratified and self.training:
-             u += torch.rand((cdf.shape[0], num_rays, num_samples), dtype=cdf.dtype, device=cdf.device) * u_max
-
-         _C.sample_pdf(
-             lengths.reshape(-1, num_samples + 1),
-             pdf.reshape(-1, num_samples),
-             u.reshape(-1, num_samples),
-             1e-5,
-         )
-         return u, torch.cat([u[..., 1:] - u[..., :-1], lengths[..., -1:] - u[..., -1:]], -1)
-
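# Illustrative sketch, not part of the diff: importance_sampling above fills `u` in place with
# pytorch3d's _C.sample_pdf CUDA kernel. A pure-PyTorch sketch of the same inverse-CDF idea
# (function and variable names here are hypothetical, not the repo's API):
import torch

def sample_pdf_reference(bins, weights, u, eps=1e-5):
    # bins: [N, S+1] depth-bin edges; weights: [N, S] per-bin mass; u: [N, S] uniforms in [0, 1)
    pdf = (weights + eps) / (weights + eps).sum(dim=-1, keepdim=True)
    cdf = torch.cat([torch.zeros_like(pdf[..., :1]), torch.cumsum(pdf, dim=-1)], dim=-1)
    idx = torch.searchsorted(cdf, u, right=True).clamp(1, bins.shape[-1] - 1)
    cdf_lo, cdf_hi = cdf.gather(-1, idx - 1), cdf.gather(-1, idx)
    bin_lo, bin_hi = bins.gather(-1, idx - 1), bins.gather(-1, idx)
    t = (u - cdf_lo) / (cdf_hi - cdf_lo).clamp_min(eps)
    return bin_lo + t * (bin_hi - bin_lo)   # depths concentrated where weights are large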
312
-     @torch.no_grad()
-     def stratified_sampling(self, num_rays, device, uniform=False):
-         lengths_uniform = self.lengths[None, None, :].expand(-1, num_rays, -1)
-
-         if uniform:
-             return (
-                 (lengths_uniform[..., 1:] + lengths_uniform[..., :-1]) / 2.0,
-                 lengths_uniform[..., 1:] - lengths_uniform[..., :-1],
-             )
-         if self.stratified and self.training:
-             t_rand = torch.rand(
-                 (num_rays, self.num_samples + 1),
-                 dtype=lengths_uniform.dtype,
-                 device=lengths_uniform.device,
-             )
-             jittered = self.lengths_lower[None, None, :].expand(-1, num_rays, -1) + \
-                 (self.lengths_upper[None, None, :].expand(-1, num_rays, -1) - self.lengths_lower[None, None, :].expand(-1, num_rays, -1)) * t_rand
-             return ((jittered[..., :-1] + jittered[..., 1:]) / 2., jittered[..., 1:] - jittered[..., :-1])
-         else:
-             return (
-                 (lengths_uniform[..., 1:] + lengths_uniform[..., :-1]) / 2.0,
-                 lengths_uniform[..., 1:] - lengths_uniform[..., :-1],
-             )
-
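# Illustrative sketch, not part of the diff: stratified sampling draws one jittered depth per
# bin so samples cover [near, far) without clumping. A simplified per-bin variant (the deleted
# code instead jitters the precomputed lengths_lower/lengths_upper buffers and takes midpoints):
import torch

near, far, n_rays, n = 0.0, 2.0, 16, 32
edges = torch.linspace(near, far, n + 1)
lower, upper = edges[:-1], edges[1:]
t = torch.rand(n_rays, n)
lengths = lower + (upper - lower) * t       # [n_rays, n], one depth sample per bin
dists = (upper - lower).expand(n_rays, n)   # bin widths, later used by the compositor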
336
-     @torch.no_grad()
-     def forward(self, pose, resolution, weights, imp_sample_next_step=False, device='cuda', pytorch3d=True):
-         input_patch_rays, xys = get_patch_rays(
-             pose,
-             num_patches_x=resolution,
-             num_patches_y=resolution,
-             device=device,
-             return_xys=True,
-             stratified=self.stratified and self.training,
-         )  # (b, n, h*w, 6)
-
-         num_rays = resolution**2
-         # sample target rays for each reference view
-         if weights is not None:
-             if self.imp_sampling_percent <= 0:
-                 lengths, dists = self.stratified_sampling(num_rays, device)
-             elif (torch.rand(1) < (1. - self.imp_sampling_percent)) and self.training:
-                 lengths, dists = self.stratified_sampling(num_rays, device)
-             else:
-                 lengths, dists = self.importance_sampling(
-                     weights, num_rays, self.num_samples, device=device
-                 )
-         else:
-             lengths, dists = self.stratified_sampling(num_rays, device)
-
-         dists_uniform = None
-         ray_points_uniform = None
-         if imp_sample_next_step:
-             lengths_uniform, dists_uniform = self.stratified_sampling(
-                 num_rays, device, uniform=True
-             )
-
-             target_patch_raybundle_uniform = RayBundle(
-                 origins=input_patch_rays[:, :1, :, :3],
-                 directions=input_patch_rays[:, :1, :, 3:],
-                 lengths=lengths_uniform,
-                 xys=xys.to(device),
-             )
-             ray_points_uniform = ray_bundle_to_ray_points(target_patch_raybundle_uniform).detach()
-             dists_uniform = dists_uniform.detach()
-
-         # print(
-         #     "SAMPLING",
-         #     lengths.shape,
-         #     lengths_uniform.shape,
-         #     dists.shape,
-         #     dists_uniform.shape,
-         #     input_patch_rays.shape,
-         # )
-         target_patch_raybundle = RayBundle(
-             origins=input_patch_rays[:, :1, :, :3],
-             directions=input_patch_rays[:, :1, :, 3:],
-             lengths=lengths.to(device),
-             xys=xys.to(device),
-         )
-         ray_points = ray_bundle_to_ray_points(target_patch_raybundle)
-         return (
-             input_patch_rays.detach(),
-             ray_points.detach(),
-             dists.detach(),
-             ray_points_uniform,
-             dists_uniform,
-         )
-
-
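# Illustrative sketch, not part of the diff: forward() above turns per-ray depth samples into 3D
# points via pytorch3d's RayBundle utilities, i.e. points = origins + lengths * directions.
# Toy example, assuming the same pytorch3d helpers the deleted code relies on:
import torch
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points

origins = torch.zeros(1, 4, 3)                                        # 1 batch, 4 rays
directions = torch.nn.functional.normalize(torch.rand(1, 4, 3), dim=-1)
lengths = torch.linspace(0.1, 2.0, 8).expand(1, 4, 8)                 # 8 depth samples per ray
bundle = RayBundle(origins=origins, directions=directions, lengths=lengths, xys=torch.zeros(1, 4, 2))
points = ray_bundle_to_ray_points(bundle)                             # [1, 4, 8, 3]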
401
- class NerfSDModule(nn.Module):
-     def __init__(
-         self,
-         mode="feature-nerf",
-         out_channels=None,
-         far_plane=2.0,
-         num_samples=32,
-         rgb_predict=False,
-         average=False,
-         num_freqs=16,
-         stratified=False,
-         imp_sampling_percent=0.9,
-         near_plane=0.
-     ):
-         MODES = {
-             "feature-nerf": FeatureNeRFEncoding,  # ampere
-         }
-         super().__init__()
-         self.rgb_predict = rgb_predict
-
-         self.raymarcher = Raymarcher(
-             num_samples=num_samples,
-             far_plane=near_plane + far_plane,
-             stratified=stratified,
-             imp_sampling_percent=imp_sampling_percent,
-             near_plane=near_plane,
-         )
-         model_class = MODES[mode]
-         self.model = model_class(
-             out_channels,
-             out_channels,
-             far_plane=near_plane + far_plane,
-             rgb_predict=rgb_predict,
-             average=average,
-             num_freqs=num_freqs,
-         )
-
438
-     def forward(self, pose, xref=None, mask_ref=None, prev_weights=None, imp_sample_next_step=False):
-         # xref: b n h w c or b n hw c
-         # pose: a list of pytorch3d cameras
-         # mask_ref: mask for the black regions introduced when padding non-square images
-         rgb = None
-         dists_uniform = None
-         weights_uniform = None
-         resolution = (int(math.sqrt(xref.size(2))) if len(xref.shape) == 4 else xref.size(3))
-         input_patch_rays, ray_points, dists, ray_points_uniform, dists_uniform = (self.raymarcher(pose, resolution, weights=prev_weights, device=xref.device))
-         output, plane_features_attn = self.model(pose, xref, ray_points, input_patch_rays, mask_ref)
-         weights = output[..., -1:]
-         features = output[..., :-1]
-         if self.rgb_predict:
-             rgb = features[..., -3:]
-             features = features[..., :-3]
-         dists = dists.unsqueeze(-1)
-         with torch.no_grad():
-             if ray_points_uniform is not None:
-                 output_uniform, _ = self.model(pose, xref, ray_points_uniform, input_patch_rays, mask_ref)
-                 weights_uniform = output_uniform[..., -1:]
-                 dists_uniform = dists_uniform.unsqueeze(-1)
-
-         return (
-             features,
-             weights,
-             dists,
-             plane_features_attn,
-             rgb,
-             weights_uniform,
-             dists_uniform,
-         )
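# Illustrative sketch, not part of the diff: how forward() above splits the model output's last
# channel into a density ("weights") head and, when rgb_predict is set, a 3-channel RGB head.
# Channel counts below are hypothetical.
import torch

B, R, S, C = 1, 16, 32, 68                  # 64 feature channels + 3 rgb + 1 density
output = torch.rand(B, R, S, C)
density = output[..., -1:]                  # [B, R, S, 1]
features = output[..., :-1]
rgb = features[..., -3:]                    # only split off when rgb_predict=True
features = features[..., :-3]               # [B, R, S, 64]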