DveloperY0115 commited on
Commit
801501a
·
1 Parent(s): 7d3169e
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. README.md +5 -4
  3. app.py +116 -0
  4. checkpoints/lang_phase1/hparams.yaml +48 -0
  5. checkpoints/lang_phase1/state_only.ckpt +3 -0
  6. checkpoints/lang_phase2/hparams.yaml +47 -0
  7. checkpoints/lang_phase2/state_only.ckpt +3 -0
  8. checkpoints/phase1/hparams.yaml +39 -0
  9. checkpoints/phase1/state_only.ckpt +3 -0
  10. checkpoints/phase2/hparams.yaml +41 -0
  11. checkpoints/phase2/state_only.ckpt +3 -0
  12. custom_wheels/salad-0.1-py3-none-any.whl +0 -0
  13. data/autosdf_spaghetti_intersec_game_data.csv +0 -0
  14. data/spaghetti_airplane_latents.hdf5 +3 -0
  15. data/spaghetti_airplane_latents_mean_std.hdf5 +3 -0
  16. data/spaghetti_chair_latents.hdf5 +3 -0
  17. data/spaghetti_chair_latents_mean_std.hdf5 +3 -0
  18. data/spaghetti_table_latents.hdf5 +3 -0
  19. data/spaghetti_table_latents_mean_std.hdf5 +3 -0
  20. requirements.txt +1 -0
  21. salad.egg-info/PKG-INFO +5 -0
  22. salad.egg-info/SOURCES.txt +7 -0
  23. salad.egg-info/dependency_links.txt +1 -0
  24. salad.egg-info/not-zip-safe +1 -0
  25. salad.egg-info/top_level.txt +1 -0
  26. salad/data/__pycache__/dataset.cpython-39.pyc +0 -0
  27. salad/data/dataset.py +149 -0
  28. salad/model_components/__pycache__/lstm.cpython-39.pyc +0 -0
  29. salad/model_components/__pycache__/network.cpython-39.pyc +0 -0
  30. salad/model_components/__pycache__/simple_module.cpython-39.pyc +0 -0
  31. salad/model_components/__pycache__/transformer.cpython-39.pyc +0 -0
  32. salad/model_components/__pycache__/variance_schedule.cpython-39.pyc +0 -0
  33. salad/model_components/lstm.py +56 -0
  34. salad/model_components/network.py +229 -0
  35. salad/model_components/simple_module.py +125 -0
  36. salad/model_components/transformer.py +308 -0
  37. salad/model_components/variance_schedule.py +57 -0
  38. salad/models/__init__.py +0 -0
  39. salad/models/__pycache__/__init__.cpython-39.pyc +0 -0
  40. salad/models/__pycache__/base_model.cpython-39.pyc +0 -0
  41. salad/models/__pycache__/language_phase1.cpython-39.pyc +0 -0
  42. salad/models/__pycache__/language_phase2.cpython-39.pyc +0 -0
  43. salad/models/__pycache__/phase1.cpython-39.pyc +0 -0
  44. salad/models/__pycache__/phase2.cpython-39.pyc +0 -0
  45. salad/models/base_model.py +147 -0
  46. salad/models/language_phase1.py +340 -0
  47. salad/models/language_phase2.py +201 -0
  48. salad/models/phase1.py +65 -0
  49. salad/models/phase2.py +183 -0
  50. salad/spaghetti/.gitignore +9 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.hdf5 filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: Salad Demo
3
- emoji: 🏆
4
  colorFrom: blue
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.38.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Test
3
+ emoji: 🦀
4
  colorFrom: blue
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.36.1
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py
3
+
4
+ An interactive demo of text-guided shape generation.
5
+ """
6
+
7
+ from pathlib import Path
8
+ from typing import Literal
9
+
10
+ import gradio as gr
11
+ import plotly.graph_objects as go
12
+
13
+ from salad.utils.spaghetti_util import (
14
+ get_mesh_from_spaghetti,
15
+ generate_zc_from_sj_gaus,
16
+ load_mesher,
17
+ load_spaghetti,
18
+ )
19
+ import hydra
20
+ from omegaconf import OmegaConf
21
+ import torch
22
+ from pytorch_lightning import seed_everything
23
+
24
+
25
+ def load_model(
26
+ model_class: Literal["phase1", "phase2", "lang_phase1", "lang_phase2"],
27
+ device,
28
+ ):
29
+ checkpoint_dir = Path(__file__).parent / "checkpoints"
30
+ c = OmegaConf.load(checkpoint_dir / f"{model_class}/hparams.yaml")
31
+ model = hydra.utils.instantiate(c)
32
+ ckpt = torch.load(checkpoint_dir / f"{model_class}/state_only.ckpt")
33
+ model.load_state_dict(ckpt)
34
+ model.eval()
35
+ for p in model.parameters(): p.requires_grad_(False)
36
+ model = model.to(device)
37
+ return model
38
+
39
+
40
+ def run_inference(prompt: str):
41
+ """The entry point of the demo."""
42
+
43
+ device: torch.device = torch.device("cuda")
44
+ """Device to run the demo on."""
45
+ seed: int = 63
46
+ """Random seed for reproducibility."""
47
+
48
+ # set random seed
49
+ seed_everything(seed)
50
+
51
+ # load SPAGHETTI and mesher
52
+ spaghetti = load_spaghetti(device)
53
+ mesher = load_mesher(device)
54
+
55
+ # load SALAD
56
+ lang_phase1_model = load_model("lang_phase1", device)
57
+ lang_phase2_model = load_model("phase2", device)
58
+ lang_phase1_model._build_dataset("val")
59
+
60
+ # run phase 1
61
+ extrinsics = lang_phase1_model.sampling_gaussians([prompt])
62
+
63
+ # run phase 2
64
+ intrinsics = lang_phase2_model.sample(extrinsics)
65
+
66
+ # generate mesh
67
+ zcs = generate_zc_from_sj_gaus(spaghetti, intrinsics, extrinsics)
68
+ vertices, faces = get_mesh_from_spaghetti(
69
+ spaghetti,
70
+ mesher,
71
+ zcs[0],
72
+ res=256,
73
+ )
74
+
75
+ # plot
76
+ figure = go.Figure(
77
+ data=[
78
+ go.Mesh3d(
79
+ x=vertices[:, 0], # flip front-back
80
+ y=-vertices[:, 2],
81
+ z=vertices[:, 1],
82
+ i=faces[:, 0],
83
+ j=faces[:, 1],
84
+ k=faces[:, 2],
85
+ color="gray",
86
+ )
87
+ ],
88
+ layout=dict(
89
+ scene=dict(
90
+ xaxis=dict(visible=False),
91
+ yaxis=dict(visible=False),
92
+ zaxis=dict(visible=False),
93
+ )
94
+ ),
95
+ )
96
+
97
+ return figure
98
+
99
+ if __name__ == "__main__":
100
+
101
+ # create UI
102
+ demo = gr.Interface(
103
+ fn=run_inference,
104
+ inputs="text",
105
+ outputs=gr.Plot(),
106
+ title="SALAD: Text-Guided Shape Generation",
107
+ description="Describe a chair",
108
+ examples=[
109
+ "an office chair",
110
+ "a chair with armrests",
111
+ "a chair without armrests",
112
+ ]
113
+ )
114
+ # initiate
115
+ demo.queue(max_size=30)
116
+ demo.launch()
checkpoints/lang_phase1/hparams.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: salad.models.language_phase1.LangPhase1Model
2
+
3
+ network:
4
+ _target_: salad.model_components.network.CondDiffNetwork
5
+ input_dim: 16
6
+ residual: true
7
+ context_dim: 768
8
+ context_embedding_dim: 1024
9
+ embedding_dim: 512
10
+ encoder_use_time: false
11
+ encoder_type: pointwise
12
+ decoder_type: transformer_encoder
13
+ enc_num_layers: 2
14
+ dec_num_layers: 6
15
+ use_timestep_embedder: true
16
+ timestep_embedder_dim: 128
17
+
18
+ variance_schedule:
19
+ _target_: salad.model_components.variance_schedule.VarianceSchedule
20
+ num_steps: &time_steps 1000
21
+ beta_1: 1e-4
22
+ beta_T: 0.05
23
+ mode: linear
24
+
25
+ # optimizer
26
+ lr: 1e-4
27
+ batch_size: 64
28
+
29
+ # dataset
30
+ dataset_kwargs:
31
+ data_path: spaghetti_chair_latents.hdf5
32
+ repeat: 1
33
+ data_keys: ["g_js_affine"]
34
+ only_easy_context: false
35
+ global_normalization: &normalization partial
36
+
37
+ global_normalization: *normalization
38
+ num_timesteps: *time_steps
39
+ faster: true
40
+ validation_step: 10
41
+ no_run_validation: false
42
+ spaghetti_tag: "chairs_large" # or airplanes, tables
43
+
44
+ text_encoder_freeze: false
45
+ use_lstm: true
46
+ classifier_free_guidance: true
47
+ conditioning_dropout_prob: 0.2
48
+
checkpoints/lang_phase1/state_only.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf46454eaaabbb7f3008c51beaae5b16794b189f3cae48f79db70fcdf5413380
3
+ size 318782397
checkpoints/lang_phase2/hparams.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: salad.models.language_phase2.LangPhase2Model
2
+ network:
3
+ _target_: salad.model_components.network.CondDiffNetwork
4
+ input_dim: 512
5
+ residual: true
6
+ context_dim: 784 # concat of 768 lang feat and gaussian.
7
+ context_embedding_dim: 1024
8
+ embedding_dim: 512
9
+ encoder_use_time: false
10
+ encoder_type: transformer
11
+ decoder_type: transformer_encoder
12
+ enc_num_layers: 6
13
+ dec_num_layers: 6
14
+ use_timestep_embedder: true
15
+ timestep_embedder_dim: 128
16
+
17
+ variance_schedule:
18
+ _target_: salad.model_components.variance_schedule.VarianceSchedule
19
+ num_steps: &time_steps 1000
20
+ beta_1: 1e-4
21
+ beta_T: 0.05
22
+ mode: linear
23
+
24
+ # optimizer
25
+ lr: 1e-4
26
+ batch_size: 64
27
+
28
+ # dataset
29
+ dataset_kwargs:
30
+ data_path: spaghetti_chair_latents.hdf5
31
+ repeat: 1
32
+ data_keys: ["s_j_affine", "g_js_affine"]
33
+ only_easy_context: false
34
+ global_normalization: &normalization false
35
+
36
+ global_normalization: *normalization
37
+ num_timesteps: *time_steps
38
+ faster: true
39
+ validation_step: 10
40
+ no_run_validation: false
41
+ spaghetti_tag: "chairs_large" # or airplanes, tables
42
+
43
+ text_encoder_freeze: false
44
+ use_lstm: true
45
+ classifier_free_guidance: true
46
+ conditioning_dropout_prob: 0.2
47
+
checkpoints/lang_phase2/state_only.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4105dd24201fa8aad3fc4db2f74376f98f4df53b38ae749d944bfdb6552ea40f
3
+ size 455307461
checkpoints/phase1/hparams.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: salad.models.phase1.Phase1Model
2
+
3
+ network:
4
+ _target_: salad.model_components.network.UnCondDiffNetwork
5
+ input_dim: 16
6
+ embedding_dim: 512
7
+ num_heads: 4
8
+ use_timestep_embedder: true
9
+ timestep_embedder_dim: 128
10
+ enc_num_layers: 6
11
+ residual: true
12
+ encoder_type: transformer
13
+ attn_dropout: 0.0
14
+
15
+ variance_schedule:
16
+ _target_: salad.model_components.variance_schedule.VarianceSchedule
17
+ num_steps: &time_steps 1000
18
+ beta_1: 1e-4
19
+ beta_T: 0.05
20
+ mode: linear
21
+
22
+ # optimizer
23
+ lr: 1e-4
24
+ batch_size: 64
25
+
26
+ # dataset
27
+ dataset_kwargs:
28
+ data_path: spaghetti_chair_latents.hdf5
29
+ repeat: 3
30
+ data_keys: ["g_js_affine"]
31
+ global_normalization: &normalization partial
32
+
33
+ global_normalization: *normalization # normalize pi, eigenvalues.
34
+ num_timesteps: *time_steps
35
+ faster: true
36
+ validation_step: 10
37
+ no_run_validation: false
38
+ spaghetti_tag: "chairs_large" # or airplanes, tables
39
+
checkpoints/phase1/state_only.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f616fa657723de4855e8571f3ef828ff25221b86cb516a755aaa93538b0c7de
3
+ size 60275831
checkpoints/phase2/hparams.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: salad.models.phase2.Phase2Model
2
+
3
+ network:
4
+ _target_: salad.model_components.network.CondDiffNetwork
5
+ input_dim: 512
6
+ residual: true
7
+ context_dim: 16 # gaussian condition dim.
8
+ context_embedding_dim: 512
9
+ embedding_dim: 512
10
+ encoder_use_time: false
11
+ encoder_type: transformer
12
+ decoder_type: transformer_encoder # we don't use cross attention.
13
+ enc_num_layers: 6
14
+ dec_num_layers: 6
15
+ use_timestep_embedder: true
16
+ timestep_embedder_dim: 128
17
+
18
+ variance_schedule:
19
+ _target_: salad.model_components.variance_schedule.VarianceSchedule
20
+ num_steps: &time_steps 1000
21
+ beta_1: 1e-4
22
+ beta_T: 0.05
23
+ mode: linear
24
+
25
+ # optimizer
26
+ lr: 1e-4
27
+ batch_size: 64
28
+
29
+ # dataset
30
+ dataset_kwargs:
31
+ data_path: spaghetti_chair_latents.hdf5
32
+ repeat: 3
33
+ data_keys: ["s_j_affine", "g_js_affine"]
34
+ global_normalization: &normalization null
35
+
36
+ global_normalization: *normalization # normalize pi, eigenvalues.
37
+ num_timesteps: *time_steps
38
+ faster: true
39
+ validation_step: 10
40
+ no_run_validation: false
41
+ spaghetti_tag: "chairs_large" # or airplanes, tables
checkpoints/phase2/state_only.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aed08103f6eebbd84fac523affaab2cd493f8f2a1d5e81e9d298cc0a7a807ed2
3
+ size 150592331
custom_wheels/salad-0.1-py3-none-any.whl ADDED
Binary file (994 Bytes). View file
 
data/autosdf_spaghetti_intersec_game_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/spaghetti_airplane_latents.hdf5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c242271687d13159b0df44a3179a0d460c9e87c577851d8d0282f0369a529f46
3
+ size 222017536
data/spaghetti_airplane_latents_mean_std.hdf5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c32e24c7786593ffdd918e05fdd1148634c3c707a4948e0a3bf6a6c002b540e1
3
+ size 12544
data/spaghetti_chair_latents.hdf5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7bfa1533a0366e9271f6bf96d4f7a135f8763ba66ce26b5cc952e6af14e5bfe4
3
+ size 1255457792
data/spaghetti_chair_latents_mean_std.hdf5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6be55ae235fe77aa821146122f0e911e9593e78a001b8fa63dea041c49095fa
3
+ size 8320
data/spaghetti_table_latents.hdf5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20ef1da19e47e2c23782defa2c5d2172d7322c5476c92f7ed3fee271e3893f91
3
+ size 1127843840
data/spaghetti_table_latents_mean_std.hdf5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:718825073ede0c52ccd5c277b1114425558a2a45258dd952a4707fecb3dc5d57
3
+ size 8320
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ./custom_wheels/salad-0.1-py3-none-any.whl
salad.egg-info/PKG-INFO ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: salad
3
+ Version: 0.1
4
+ Summary: SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation
5
+ Home-page: https://github.com/63days/SALAD
salad.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ README.md
2
+ setup.py
3
+ salad.egg-info/PKG-INFO
4
+ salad.egg-info/SOURCES.txt
5
+ salad.egg-info/dependency_links.txt
6
+ salad.egg-info/not-zip-safe
7
+ salad.egg-info/top_level.txt
salad.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
salad.egg-info/not-zip-safe ADDED
@@ -0,0 +1 @@
 
 
1
+
salad.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ salad
salad/data/__pycache__/dataset.cpython-39.pyc ADDED
Binary file (4.61 kB). View file
 
salad/data/dataset.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import h5py
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ from dotmap import DotMap
6
+
7
+ from salad.utils.paths import DATA_DIR
8
+ from salad.utils import thutil
9
+
10
+
11
+ class SALADDataset(torch.utils.data.Dataset):
12
+ def __init__(self, data_path, repeat=None, **kwargs):
13
+ super().__init__()
14
+ self.data_path = str(DATA_DIR / data_path)
15
+ self.repeat = repeat
16
+ self.__dict__.update(kwargs)
17
+ self.hparams = DotMap(self.__dict__)
18
+
19
+ """
20
+ Global Data statistics.
21
+ """
22
+ if self.hparams.get("global_normalization"):
23
+ with h5py.File(self.data_path.replace(".hdf5", "_mean_std.hdf5")) as f:
24
+ self.global_mean = f["mean"][:].astype(np.float32)
25
+ self.global_std = f["std"][:].astype(np.float32)
26
+
27
+ self.data = dict()
28
+ with h5py.File(self.data_path) as f:
29
+ for k in self.hparams.data_keys:
30
+ self.data[k] = f[k][:].astype(np.float32)
31
+
32
+ """
33
+ global_normalization arg is for gaussians only.
34
+ """
35
+ if k == "g_js_affine":
36
+ if self.hparams.get("global_normalization") == "partial":
37
+ assert k == "g_js_affine"
38
+ if self.hparams.get("verbose"):
39
+ print("[*] Normalize data only for pi and eigenvalues.")
40
+ # 3: mu, 9: eigvec, 1: pi, 3: eigval
41
+ self.data[k] = self.normalize_global_static(
42
+ self.data[k], slice(12, None)
43
+ )
44
+ elif self.hparams.get("global_normalization") == "all":
45
+ assert k == "g_js_affine"
46
+ if self.hparams.get("verbose"):
47
+ print("[*] Normalize data for all elements.")
48
+ self.data[k] = self.normalize_global_static(
49
+ self.data[k], slice(None)
50
+ )
51
+
52
+ def __getitem__(self, idx):
53
+ if self.repeat is not None and self.repeat > 1:
54
+ idx = int(idx / self.repeat)
55
+
56
+ items = []
57
+ for k in self.hparams.data_keys:
58
+ data = torch.from_numpy(self.data[k][idx])
59
+ items.append(data)
60
+
61
+ if self.hparams.get("concat_data"):
62
+ return torch.cat(items, -1) # [16,528]
63
+ if len(items) == 1:
64
+ return items[0]
65
+ return items
66
+
67
+ def __len__(self):
68
+ k = self.hparams.data_keys[0]
69
+ if self.repeat is not None and self.repeat > 1:
70
+ return len(self.data[k]) * self.repeat
71
+ return len(self.data[k])
72
+
73
+ def get_other_latents(self, key):
74
+ with h5py.File(self.data_path) as f:
75
+ return f[key][:].astype(np.float32)
76
+
77
+ def normalize_global_static(self, data: np.ndarray, normalize_indices=slice(None)):
78
+ """
79
+ Input:
80
+ np.ndarray or torch.Tensor. [16,16] or [B,16,16]
81
+ slice(None) -> full
82
+ slice(12, None) -> partial
83
+ Output:
84
+ [16,16] or [B,16,16]
85
+ """
86
+ assert normalize_indices == slice(None) or normalize_indices == slice(
87
+ 12, None
88
+ ), print(f"{normalize_indices} is wrong.")
89
+ data = thutil.th2np(data).copy()
90
+ data[..., normalize_indices] = (
91
+ data[..., normalize_indices] - self.global_mean[normalize_indices]
92
+ ) / self.global_std[normalize_indices]
93
+ return data
94
+
95
+ def unnormalize_global_static(
96
+ self, data: np.ndarray, unnormalize_indices=slice(None)
97
+ ):
98
+ """
99
+ Input:
100
+ np.ndarray or torch.Tensor. [16,16] or [B,16,16]
101
+ slice(None) -> full
102
+ slice(12, None) -> partial
103
+ Output:
104
+ [16,16] or [B,16,16]
105
+ """
106
+ assert unnormalize_indices == slice(None) or unnormalize_indices == slice(
107
+ 12, None
108
+ ), print(f"{unnormalize_indices} is wrong.")
109
+ data = thutil.th2np(data).copy()
110
+ data[..., unnormalize_indices] = (
111
+ data[..., unnormalize_indices]
112
+ ) * self.global_std[unnormalize_indices] + self.global_mean[unnormalize_indices]
113
+ return data
114
+
115
+
116
+ class LangSALADDataset(SALADDataset):
117
+ def __init__(self, data_path, repeat=None, **kwargs):
118
+ super().__init__(data_path, repeat, **kwargs)
119
+
120
+ # self.game_data = pd.read_csv(self.hparams.lang_data_path)
121
+ self.game_data = pd.read_csv(DATA_DIR / "autosdf_spaghetti_intersec_game_data.csv")
122
+ self.shapenet_ids = np.array(self.game_data["sn"])
123
+ self.spaghetti_indices = np.array(self.game_data["spaghetti_idx"]) # for 5401
124
+ self.texts = np.array(self.game_data["text"])
125
+
126
+ assert len(self.shapenet_ids) == len(self.spaghetti_indices) == len(self.texts)
127
+
128
+ def __getitem__(self, idx):
129
+ if self.repeat is not None and self.repeat > 1:
130
+ idx = int(idx / self.repeat)
131
+
132
+ spa_idx = self.spaghetti_indices[idx]
133
+ text = self.texts[idx]
134
+ latents = []
135
+ for k in self.hparams.data_keys:
136
+ data = torch.from_numpy(self.data[k][spa_idx])
137
+ latents.append(data)
138
+
139
+ item = latents + [text]
140
+ if self.hparams.get("concat_data"):
141
+ latents = torch.cat(latents, -1)
142
+ return latents, text
143
+
144
+ return item
145
+
146
+ def __len__(self):
147
+ if self.repeat is not None and self.repeat > 1:
148
+ return len(self.texts) * self.repeat
149
+ return len(self.texts)
salad/model_components/__pycache__/lstm.cpython-39.pyc ADDED
Binary file (2.38 kB). View file
 
salad/model_components/__pycache__/network.cpython-39.pyc ADDED
Binary file (4.73 kB). View file
 
salad/model_components/__pycache__/simple_module.cpython-39.pyc ADDED
Binary file (3.9 kB). View file
 
salad/model_components/__pycache__/transformer.cpython-39.pyc ADDED
Binary file (8.63 kB). View file
 
salad/model_components/__pycache__/variance_schedule.cpython-39.pyc ADDED
Binary file (1.99 kB). View file
 
salad/model_components/lstm.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
5
+
6
+
7
+ class LSTM(nn.Module):
8
+ def __init__(self, text_dim, embedding_dim, vocab_size, padding_idx=0):
9
+ super().__init__()
10
+ self.padding_idx = padding_idx
11
+ self.word_embedding = nn.Embedding(
12
+ vocab_size, embedding_dim, padding_idx=padding_idx
13
+ )
14
+ self.rnn = nn.LSTM(embedding_dim, text_dim, batch_first=True)
15
+ self.w_attn = nn.Parameter(torch.Tensor(1, text_dim))
16
+ nn.init.xavier_uniform_(self.w_attn)
17
+
18
+ def forward(self, padded_tokens, dropout=0.5):
19
+ w_emb = self.word_embedding(padded_tokens)
20
+ w_emb = F.dropout(w_emb, dropout, self.training)
21
+ len_seq = (padded_tokens != self.padding_idx).sum(dim=1).cpu()
22
+ x_packed = pack_padded_sequence(
23
+ w_emb, len_seq, enforce_sorted=False, batch_first=True
24
+ )
25
+ B = padded_tokens.shape[0]
26
+ rnn_out, _ = self.rnn(x_packed)
27
+ rnn_out, dummy = pad_packed_sequence(rnn_out, batch_first=True)
28
+ h = rnn_out[torch.arange(B), len_seq - 1]
29
+ final_feat, attn = self.word_attention(rnn_out, h, len_seq)
30
+ return final_feat, attn
31
+
32
+ def word_attention(self, R, h, len_seq):
33
+ """
34
+ Input:
35
+ R: hidden states of the entire words
36
+ h: the final hidden state after processing the entire words
37
+ len_seq: the length of the sequence
38
+ Output:
39
+ final_feat: the final feature after the bilinear attention
40
+ attn: word attention weights
41
+ """
42
+ B, N, D = R.shape
43
+ device = R.device
44
+ len_seq = len_seq.to(device)
45
+
46
+ W_attn = (self.w_attn * torch.eye(D).to(device))[None].repeat(B, 1, 1)
47
+ score = torch.bmm(torch.bmm(R, W_attn), h.unsqueeze(-1))
48
+
49
+ mask = torch.arange(N).reshape(1, N, 1).repeat(B, 1, 1).to(device)
50
+ mask = mask < len_seq.reshape(B, 1, 1)
51
+
52
+ score = score.masked_fill(mask == 0, -1e9)
53
+ attn = F.softmax(score, 1)
54
+ final_feat = torch.bmm(R.transpose(1, 2), attn).squeeze(-1)
55
+
56
+ return final_feat, attn.squeeze(-1)
salad/model_components/network.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from dotmap import DotMap
5
+ from salad.model_components.simple_module import TimePointWiseEncoder, TimestepEmbedder
6
+
7
+
8
+ from salad.model_components.transformer import (
9
+ PositionalEncoding,
10
+ TimeTransformerDecoder,
11
+ TimeTransformerEncoder,
12
+ )
13
+
14
+ class UnCondDiffNetwork(nn.Module):
15
+ def __init__(self, input_dim, residual, **kwargs):
16
+ """
17
+ Transformer Encoder.
18
+ """
19
+ super().__init__()
20
+ self.input_dim = input_dim
21
+ self.residual = residual
22
+ self.__dict__.update(kwargs)
23
+ self.hparams = DotMap(self.__dict__)
24
+
25
+ self._build_model()
26
+
27
+ def _build_model(self):
28
+ self.act = F.leaky_relu
29
+ if self.hparams.get("use_timestep_embedder"):
30
+ self.time_embedder = TimestepEmbedder(self.hparams.timestep_embedder_dim)
31
+ dim_ctx = self.hparams.timestep_embedder_dim
32
+ else:
33
+ dim_ctx = 3
34
+
35
+ """
36
+ Encoder part
37
+ """
38
+ enc_dim = self.hparams.embedding_dim
39
+ self.embedding = nn.Linear(self.hparams.input_dim, enc_dim)
40
+ if not self.hparams.get("encoder_type"):
41
+ self.encoder = TimeTransformerEncoder(
42
+ enc_dim,
43
+ dim_ctx=dim_ctx,
44
+ num_heads=self.hparams.num_heads
45
+ if self.hparams.get("num_heads")
46
+ else 4,
47
+ use_time=True,
48
+ num_layers=self.hparams.enc_num_layers,
49
+ last_fc=True,
50
+ last_fc_dim_out=self.hparams.input_dim,
51
+ )
52
+ else:
53
+ if self.hparams.encoder_type == "transformer":
54
+ self.encoder = TimeTransformerEncoder(
55
+ enc_dim,
56
+ dim_ctx=dim_ctx,
57
+ num_heads=self.hparams.num_heads
58
+ if self.hparams.get("num_heads")
59
+ else 4,
60
+ use_time=True,
61
+ num_layers=self.hparams.enc_num_layers,
62
+ last_fc=True,
63
+ last_fc_dim_out=self.hparams.input_dim,
64
+ dropout=self.hparams.get("attn_dropout", 0.0)
65
+ )
66
+ else:
67
+ raise ValueError
68
+
69
+ def forward(self, x, beta):
70
+ """
71
+ Input:
72
+ x: [B,G,D] latent
73
+ beta: B
74
+ Output:
75
+ eta: [B,G,D]
76
+ """
77
+ B, G = x.shape[:2]
78
+ if self.hparams.get("use_timestep_embedder"):
79
+ time_emb = self.time_embedder(beta).unsqueeze(1)
80
+ else:
81
+ beta = beta.view(B, 1, 1)
82
+ time_emb = torch.cat(
83
+ [beta, torch.sin(beta), torch.cos(beta)], dim=-1
84
+ ) # [B,1,3]
85
+
86
+ ctx = time_emb
87
+ x_emb = self.embedding(x)
88
+
89
+ out = self.encoder(x_emb, ctx=ctx)
90
+
91
+ if self.hparams.residual:
92
+ out = out + x
93
+ return out
94
+
95
+
96
+ class CondDiffNetwork(nn.Module):
97
+ def __init__(self, input_dim, residual, **kwargs):
98
+ """
99
+ Transformer Encoder + Decoder.
100
+ """
101
+ super().__init__()
102
+ self.input_dim = input_dim
103
+ self.residual = residual
104
+ self.__dict__.update(kwargs)
105
+ self.hparams = DotMap(self.__dict__)
106
+
107
+ self._build_model()
108
+
109
+ def _build_model(self):
110
+ self.act = F.leaky_relu
111
+ if self.hparams.get("use_timestep_embedder"):
112
+ self.time_embedder = TimestepEmbedder(self.hparams.timestep_embedder_dim)
113
+ dim_ctx = self.hparams.timestep_embedder_dim
114
+ else:
115
+ dim_ctx = 3
116
+ """
117
+ Encoder part
118
+ """
119
+ enc_dim = self.hparams.context_embedding_dim
120
+ self.context_embedding = nn.Linear(self.hparams.context_dim, enc_dim)
121
+ if self.hparams.encoder_type == "transformer":
122
+ self.encoder = TimeTransformerEncoder(
123
+ enc_dim,
124
+ 3,
125
+ num_heads=4,
126
+ use_time=self.hparams.encoder_use_time,
127
+ num_layers=self.hparams.enc_num_layers
128
+ if self.hparams.get("enc_num_layers")
129
+ else 3,
130
+ last_fc=False,
131
+ )
132
+
133
+ elif self.hparams.encoder_type == "pointwise":
134
+ self.encoder = TimePointWiseEncoder(
135
+ enc_dim,
136
+ dim_ctx=None,
137
+ use_time=self.hparams.encoder_use_time,
138
+ num_layers=self.hparams.enc_num_layers,
139
+ )
140
+ else:
141
+ raise ValueError
142
+
143
+ """
144
+ Decoder part
145
+ """
146
+ dec_dim = self.hparams.embedding_dim
147
+ input_dim = self.hparams.input_dim
148
+ self.query_embedding = nn.Linear(self.hparams.input_dim, dec_dim)
149
+ if self.hparams.decoder_type == "transformer_decoder":
150
+ self.decoder = TimeTransformerDecoder(
151
+ dec_dim,
152
+ enc_dim,
153
+ dim_ctx=dim_ctx,
154
+ num_heads=4,
155
+ last_fc=True,
156
+ last_fc_dim_out=input_dim,
157
+ num_layers=self.hparams.dec_num_layers
158
+ if self.hparams.get("dec_num_layers")
159
+ else 3,
160
+ )
161
+ elif self.hparams.decoder_type == "transformer_encoder":
162
+ self.decoder = TimeTransformerEncoder(
163
+ dec_dim,
164
+ dim_ctx=enc_dim + dim_ctx,
165
+ num_heads=4,
166
+ last_fc=True,
167
+ last_fc_dim_out=input_dim,
168
+ num_layers=self.hparams.dec_num_layers
169
+ if self.hparams.get("dec_num_layers")
170
+ else 3,
171
+ )
172
+ else:
173
+ raise ValueError
174
+
175
+ def forward(self, x, beta, context):
176
+ """
177
+ Input:
178
+ x: [B,G,D] intrinsic
179
+ beta: B
180
+ context: [B,G,D2] or [B, D2] condition
181
+ Output:
182
+ eta: [B,G,D]
183
+ """
184
+ # print(f"x: {x.shape} context: {context.shape} beta: {beta.shape}")
185
+ B, G = x.shape[:2]
186
+
187
+ if self.hparams.get("use_timestep_embedder"):
188
+ time_emb = self.time_embedder(beta).unsqueeze(1)
189
+ else:
190
+ beta = beta.view(B, 1, 1)
191
+ time_emb = torch.cat(
192
+ [beta, torch.sin(beta), torch.cos(beta)], dim=-1
193
+ ) # [B,1,3]
194
+ ctx = time_emb
195
+ """
196
+ Encoding
197
+ """
198
+ cout = self.context_embedding(context)
199
+ cout = self.encoder(cout, ctx=ctx if self.hparams.encoder_use_time else None)
200
+
201
+ if cout.ndim == 2:
202
+ cout = cout.unsqueeze(1).expand(-1, G, -1)
203
+
204
+ """
205
+ Decoding
206
+ """
207
+ out = self.query_embedding(x)
208
+ if self.hparams.get("use_pos_encoding"):
209
+ out = self.pos_encoding(out)
210
+
211
+ if self.hparams.decoder_type == "transformer_encoder":
212
+ try:
213
+ ctx = ctx.expand(-1, G, -1)
214
+ if cout.ndim == 2:
215
+ cout = cout.unsqueeze(1)
216
+ cout = cout.expand(-1, G, -1)
217
+ ctx = torch.cat([ctx, cout], -1)
218
+ except Exception as e:
219
+ print(e, G, ctx.shape, cout.shape)
220
+ out = self.decoder(out, ctx=ctx)
221
+ else:
222
+ out = self.decoder(out, cout, ctx=ctx)
223
+
224
+ # if hasattr(self, "last_fc"):
225
+ # out = self.last_fc(out)
226
+
227
+ if self.hparams.residual:
228
+ out = out + x
229
+ return out
salad/model_components/simple_module.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+ from salad.model_components.transformer import TimeMLP
7
+
8
+
9
+ class TimePointwiseLayer(nn.Module):
10
+ def __init__(
11
+ self,
12
+ dim_in,
13
+ dim_ctx,
14
+ mlp_ratio=2,
15
+ act=F.leaky_relu,
16
+ dropout=0.0,
17
+ use_time=False,
18
+ ):
19
+ super().__init__()
20
+ self.use_time = use_time
21
+ self.act = act
22
+ self.mlp1 = TimeMLP(
23
+ dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time
24
+ )
25
+ self.norm1 = nn.LayerNorm(dim_in)
26
+
27
+ self.mlp2 = TimeMLP(
28
+ dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time
29
+ )
30
+ self.norm2 = nn.LayerNorm(dim_in)
31
+ self.dropout = nn.Dropout(dropout)
32
+
33
+ def forward(self, x, ctx=None):
34
+ res = x
35
+ x = self.mlp1(x, ctx=ctx)
36
+ x = self.norm1(x + res)
37
+
38
+ res = x
39
+ x = self.mlp2(x, ctx=ctx)
40
+ x = self.norm2(x + res)
41
+ return x
42
+
43
+
44
+ class TimePointWiseEncoder(nn.Module):
45
+ def __init__(
46
+ self,
47
+ dim_in,
48
+ dim_ctx=None,
49
+ mlp_ratio=2,
50
+ act=F.leaky_relu,
51
+ dropout=0.0,
52
+ use_time=True,
53
+ num_layers=6,
54
+ last_fc=False,
55
+ last_fc_dim_out=None,
56
+ ):
57
+ super().__init__()
58
+ self.last_fc = last_fc
59
+ if last_fc:
60
+ self.fc = nn.Linear(dim_in, last_fc_dim_out)
61
+ self.layers = nn.ModuleList(
62
+ [
63
+ TimePointwiseLayer(
64
+ dim_in,
65
+ dim_ctx=dim_ctx,
66
+ mlp_ratio=mlp_ratio,
67
+ act=act,
68
+ dropout=dropout,
69
+ use_time=use_time,
70
+ )
71
+ for _ in range(num_layers)
72
+ ]
73
+ )
74
+
75
+ def forward(self, x, ctx=None):
76
+ for i, layer in enumerate(self.layers):
77
+ x = layer(x, ctx=ctx)
78
+ if self.last_fc:
79
+ x = self.fc(x)
80
+ return x
81
+
82
+
83
+ class TimestepEmbedder(nn.Module):
84
+ """
85
+ Embeds scalar timesteps into vector representations.
86
+ """
87
+
88
+ def __init__(self, hidden_size, frequency_embedding_size=256):
89
+ super().__init__()
90
+ self.mlp = nn.Sequential(
91
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
92
+ nn.SiLU(),
93
+ nn.Linear(hidden_size, hidden_size, bias=True),
94
+ )
95
+ self.frequency_embedding_size = frequency_embedding_size
96
+
97
+ @staticmethod
98
+ def timestep_embedding(t, dim, max_period=10000):
99
+ """
100
+ Create sinusoidal timestep embeddings.
101
+ :param t: a 1-D Tensor of N indices, one per batch element.
102
+ These may be fractional.
103
+ :param dim: the dimension of the output.
104
+ :param max_period: controls the minimum frequency of the embeddings.
105
+ :return: an (N, D) Tensor of positional embeddings.
106
+ """
107
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
108
+ half = dim // 2
109
+ freqs = torch.exp(
110
+ -math.log(max_period)
111
+ * torch.arange(start=0, end=half, dtype=torch.float32)
112
+ / half
113
+ ).to(device=t.device)
114
+ args = t[:, None].float() * freqs[None]
115
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
116
+ if dim % 2:
117
+ embedding = torch.cat(
118
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
119
+ )
120
+ return embedding
121
+
122
+ def forward(self, t):
123
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
124
+ t_emb = self.mlp(t_freq)
125
+ return t_emb
salad/model_components/transformer.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of time conditioned Transformer.
3
+ """
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class PositionalEncoding(nn.Module):
11
+ def __init__(self, d_hid, n_position=200):
12
+ super(PositionalEncoding, self).__init__()
13
+
14
+ # Not a parameter
15
+ self.register_buffer(
16
+ "pos_table", self._get_sinusoid_encoding_table(n_position, d_hid)
17
+ )
18
+
19
+ def _get_sinusoid_encoding_table(self, n_position, d_hid):
20
+ """Sinusoid position encoding table"""
21
+ # TODO: make it with torch instead of numpy
22
+
23
+ def get_position_angle_vec(position):
24
+ return [
25
+ position / np.power(10000, 2 * (hid_j // 2) / d_hid)
26
+ for hid_j in range(d_hid)
27
+ ]
28
+
29
+ sinusoid_table = np.array(
30
+ [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
31
+ )
32
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
33
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
34
+
35
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
36
+
37
+ def forward(self, x):
38
+ """
39
+ Input:
40
+ x: [B,N,D]
41
+ """
42
+ return x + self.pos_table[:, : x.size(1)].clone().detach()
43
+
44
+
45
+ class ConcatSquashLinear(nn.Module):
46
+ def __init__(self, dim_in, dim_out, dim_ctx):
47
+ super(ConcatSquashLinear, self).__init__()
48
+ self._layer = nn.Linear(dim_in, dim_out)
49
+ self._hyper_bias = nn.Linear(dim_ctx, dim_out, bias=False)
50
+ self._hyper_gate = nn.Linear(dim_ctx, dim_out)
51
+
52
+ def forward(self, ctx, x):
53
+ assert ctx.dim() == x.dim()
54
+ gate = torch.sigmoid(self._hyper_gate(ctx))
55
+ bias = self._hyper_bias(ctx)
56
+ ret = self._layer(x) * gate + bias
57
+ return ret
58
+
59
+
60
+ class TimeMLP(nn.Module):
61
+ def __init__(
62
+ self,
63
+ dim_in,
64
+ dim_h,
65
+ dim_out,
66
+ dim_ctx=None,
67
+ act=F.relu,
68
+ dropout=0.0,
69
+ use_time=False,
70
+ ):
71
+ super().__init__()
72
+ self.act = act
73
+ self.use_time = use_time
74
+
75
+ dim_h = int(dim_h)
76
+ if use_time:
77
+ self.fc1 = ConcatSquashLinear(dim_in, dim_h, dim_ctx)
78
+ self.fc2 = ConcatSquashLinear(dim_h, dim_out, dim_ctx)
79
+ else:
80
+ self.fc1 = nn.Linear(dim_in, dim_h)
81
+ self.fc2 = nn.Linear(dim_h, dim_out)
82
+ self.dropout = nn.Dropout(dropout)
83
+
84
+ def forward(self, x, ctx=None):
85
+ if self.use_time:
86
+ x = self.fc1(x=x, ctx=ctx)
87
+ else:
88
+ x = self.fc1(x)
89
+
90
+ x = self.act(x)
91
+ x = self.dropout(x)
92
+ if self.use_time:
93
+ x = self.fc2(x=x, ctx=ctx)
94
+ else:
95
+ x = self.fc2(x)
96
+
97
+ x = self.dropout(x)
98
+ return x
99
+
100
+
101
+ class MultiHeadAttention(nn.Module):
102
+ def __init__(self, dim_self, dim_ref, num_heads, dropout=0.0):
103
+ super().__init__()
104
+ self.num_heads = num_heads
105
+ head_dim = dim_self // num_heads
106
+ self.scale = head_dim**-0.5
107
+ self.to_queries = nn.Linear(dim_self, dim_self)
108
+ self.to_keys_values = nn.Linear(dim_ref, dim_self * 2)
109
+ self.project = nn.Linear(dim_self, dim_self)
110
+ self.dropout = nn.Dropout(dropout)
111
+
112
+ def forward(
113
+ self,
114
+ x,
115
+ y=None,
116
+ mask=None,
117
+ alpha=None,
118
+ ):
119
+ y = y if y is not None else x
120
+ b_a, n, c = x.shape
121
+ b, m, d = y.shape
122
+ # b n h dh
123
+ queries = self.to_queries(x).reshape(
124
+ b_a, n, self.num_heads, c // self.num_heads
125
+ )
126
+ # b m 2 h dh
127
+ keys_values = self.to_keys_values(y).reshape(
128
+ b, m, 2, self.num_heads, c // self.num_heads
129
+ )
130
+ keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
131
+ if alpha is not None:
132
+ out, attention = self.forward_interpolation(
133
+ queries, keys, values, alpha, mask
134
+ )
135
+ else:
136
+ attention = torch.einsum("bnhd,bmhd->bnmh", queries, keys) * self.scale
137
+ if mask is not None:
138
+ if mask.dim() == 2:
139
+ mask = mask.unsqueeze(1)
140
+ attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
141
+ attention = attention.softmax(dim=2)
142
+ attention = self.dropout(attention)
143
+ out = torch.einsum("bnmh,bmhd->bnhd", attention, values).reshape(b, n, c)
144
+ out = self.project(out)
145
+ return out, attention
146
+
147
+
148
+ class TimeTransformerEncoderLayer(nn.Module):
149
+ def __init__(
150
+ self,
151
+ dim_self,
152
+ dim_ctx=None,
153
+ num_heads=1,
154
+ mlp_ratio=2.0,
155
+ act=F.leaky_relu,
156
+ dropout=0.0,
157
+ use_time=True,
158
+ ):
159
+ super().__init__()
160
+ self.use_time = use_time
161
+ self.act = act
162
+ self.attn = MultiHeadAttention(dim_self, dim_self, num_heads, dropout)
163
+ self.attn_norm = nn.LayerNorm(dim_self)
164
+
165
+ mlp_ratio = int(mlp_ratio)
166
+ self.mlp = TimeMLP(
167
+ dim_self, dim_self * mlp_ratio, dim_self, dim_ctx, use_time=use_time
168
+ )
169
+ self.norm = nn.LayerNorm(dim_self)
170
+ self.dropout = nn.Dropout(dropout)
171
+
172
+ def forward(self, x, ctx=None):
173
+ res = x
174
+ x, attn = self.attn(x)
175
+ x = self.attn_norm(x + res)
176
+
177
+ res = x
178
+ x = self.mlp(x, ctx=ctx)
179
+ x = self.norm(x + res)
180
+
181
+ return x, attn
182
+
183
+
184
+ class TimeTransformerDecoderLayer(TimeTransformerEncoderLayer):
185
+ def __init__(
186
+ self,
187
+ dim_self,
188
+ dim_ref,
189
+ dim_ctx=None,
190
+ num_heads=1,
191
+ mlp_ratio=2,
192
+ act=F.leaky_relu,
193
+ dropout=0.0,
194
+ use_time=True,
195
+ ):
196
+ super().__init__(
197
+ dim_self=dim_self,
198
+ dim_ctx=dim_ctx,
199
+ num_heads=num_heads,
200
+ mlp_ratio=mlp_ratio,
201
+ act=act,
202
+ dropout=dropout,
203
+ use_time=use_time,
204
+ )
205
+ self.cross_attn = MultiHeadAttention(dim_self, dim_ref, num_heads, dropout)
206
+ self.cross_attn_norm = nn.LayerNorm(dim_self)
207
+
208
+ def forward(self, x, y, ctx=None):
209
+ res = x
210
+ x, attn = self.attn(x)
211
+ x = self.attn_norm(x + res)
212
+
213
+ res = x
214
+ x, attn = self.cross_attn(x, y)
215
+ x = self.cross_attn_norm(x + res)
216
+
217
+ res = x
218
+ x = self.mlp(x, ctx=ctx)
219
+ x = self.norm(x + res)
220
+
221
+ return x, attn
222
+
223
+
224
+ class TimeTransformerEncoder(nn.Module):
225
+ def __init__(
226
+ self,
227
+ dim_self,
228
+ dim_ctx=None,
229
+ num_heads=1,
230
+ mlp_ratio=2.0,
231
+ act=F.leaky_relu,
232
+ dropout=0.0,
233
+ use_time=True,
234
+ num_layers=3,
235
+ last_fc=False,
236
+ last_fc_dim_out=None,
237
+ ):
238
+ super().__init__()
239
+ self.last_fc = last_fc
240
+ if last_fc:
241
+ self.fc = nn.Linear(dim_self, last_fc_dim_out)
242
+ self.layers = nn.ModuleList(
243
+ [
244
+ TimeTransformerEncoderLayer(
245
+ dim_self,
246
+ dim_ctx=dim_ctx,
247
+ num_heads=num_heads,
248
+ mlp_ratio=mlp_ratio,
249
+ act=act,
250
+ dropout=dropout,
251
+ use_time=use_time,
252
+ )
253
+ for _ in range(num_layers)
254
+ ]
255
+ )
256
+
257
+ def forward(self, x, ctx=None):
258
+ for i, layer in enumerate(self.layers):
259
+ x, attn = layer(x, ctx=ctx)
260
+
261
+ if self.last_fc:
262
+ x = self.fc(x)
263
+ return x
264
+
265
+
266
+ class TimeTransformerDecoder(nn.Module):
267
+ def __init__(
268
+ self,
269
+ dim_self,
270
+ dim_ref,
271
+ dim_ctx=None,
272
+ num_heads=1,
273
+ mlp_ratio=2.0,
274
+ act=F.leaky_relu,
275
+ dropout=0.0,
276
+ use_time=True,
277
+ num_layers=3,
278
+ last_fc=True,
279
+ last_fc_dim_out=None,
280
+ ):
281
+ super().__init__()
282
+ self.last_fc = last_fc
283
+ if last_fc:
284
+ self.fc = nn.Linear(dim_self, last_fc_dim_out)
285
+
286
+ self.layers = nn.ModuleList(
287
+ [
288
+ TimeTransformerDecoderLayer(
289
+ dim_self,
290
+ dim_ref,
291
+ dim_ctx,
292
+ num_heads,
293
+ mlp_ratio,
294
+ act,
295
+ dropout,
296
+ use_time,
297
+ )
298
+ for _ in range(num_layers)
299
+ ]
300
+ )
301
+
302
+ def forward(self, x, y, ctx=None):
303
+ for i, layer in enumerate(self.layers):
304
+ x, attn = layer(x, y=y, ctx=ctx)
305
+ if self.last_fc:
306
+ x = self.fc(x)
307
+
308
+ return x
salad/model_components/variance_schedule.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from torch.nn import Linear, Module
4
+
5
+ class VarianceSchedule(Module):
6
+ def __init__(self, num_steps, beta_1, beta_T, mode="linear"):
7
+ super().__init__()
8
+ # assert mode in ("linear",)
9
+ self.num_steps = num_steps
10
+ self.beta_1 = beta_1
11
+ self.beta_T = beta_T
12
+ self.mode = mode
13
+
14
+ if mode == "linear":
15
+ betas = torch.linspace(beta_1, beta_T, steps=num_steps)
16
+ elif mode == "quad":
17
+ betas = torch.linspace(beta_1 ** 0.5, beta_T ** 0.5, num_steps) ** 2
18
+ elif mode == "cosine":
19
+ cosine_s = 8e-3
20
+ timesteps = torch.arange(num_steps + 1) / num_steps + cosine_s
21
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
22
+ alphas = torch.cos(alphas).pow(2)
23
+ betas = 1 - alphas[1:] / alphas[:-1]
24
+ betas = betas.clamp(max=0.999)
25
+
26
+ betas = torch.cat([torch.zeros([1]), betas], dim=0) # Padding
27
+
28
+ alphas = 1 - betas
29
+ log_alphas = torch.log(alphas)
30
+ for i in range(1, log_alphas.size(0)): # 1 to T
31
+ log_alphas[i] += log_alphas[i - 1]
32
+ alpha_bars = log_alphas.exp()
33
+
34
+ sigmas_flex = torch.sqrt(betas)
35
+ sigmas_inflex = torch.zeros_like(sigmas_flex)
36
+ for i in range(1, sigmas_flex.size(0)):
37
+ sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[
38
+ i
39
+ ]
40
+ sigmas_inflex = torch.sqrt(sigmas_inflex)
41
+
42
+ self.register_buffer("betas", betas)
43
+ self.register_buffer("alphas", alphas)
44
+ self.register_buffer("alpha_bars", alpha_bars)
45
+ self.register_buffer("sigmas_flex", sigmas_flex)
46
+ self.register_buffer("sigmas_inflex", sigmas_inflex)
47
+
48
+ def uniform_sample_t(self, batch_size):
49
+ ts = np.random.choice(np.arange(1, self.num_steps + 1), batch_size)
50
+ return ts.tolist()
51
+
52
+ def get_sigmas(self, t, flexibility):
53
+ assert 0 <= flexibility and flexibility <= 1
54
+ sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (
55
+ 1 - flexibility
56
+ )
57
+ return sigmas
salad/models/__init__.py ADDED
File without changes
salad/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (163 Bytes). View file
 
salad/models/__pycache__/base_model.cpython-39.pyc ADDED
Binary file (4.6 kB). View file
 
salad/models/__pycache__/language_phase1.cpython-39.pyc ADDED
Binary file (8.83 kB). View file
 
salad/models/__pycache__/language_phase2.cpython-39.pyc ADDED
Binary file (6.12 kB). View file
 
salad/models/__pycache__/phase1.cpython-39.pyc ADDED
Binary file (2.12 kB). View file
 
salad/models/__pycache__/phase2.cpython-39.pyc ADDED
Binary file (5.37 kB). View file
 
salad/models/base_model.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from salad.data.dataset import SALADDataset
5
+ from salad.utils.train_util import PolyDecayScheduler
6
+
7
+
8
+ class BaseModel(pl.LightningModule):
9
+ def __init__(
10
+ self,
11
+ network,
12
+ variance_schedule,
13
+ **kwargs,
14
+ ):
15
+ super().__init__()
16
+ self.save_hyperparameters(logger=False)
17
+ self.net = network
18
+ self.var_sched = variance_schedule
19
+
20
+ def forward(self, x):
21
+ return self.get_loss(x)
22
+
23
+ def step(self, x, stage: str):
24
+ loss = self(x)
25
+ self.log(
26
+ f"{stage}/loss",
27
+ loss,
28
+ on_step=stage == "train",
29
+ prog_bar=True,
30
+ )
31
+ return loss
32
+
33
+ def training_step(self, batch, batch_idx):
34
+ x = batch
35
+ return self.step(x, "train")
36
+
37
+ def add_noise(self, x, t):
38
+ """
39
+ Input:
40
+ x: [B,D] or [B,G,D]
41
+ t: list of size B
42
+ Output:
43
+ x_noisy: [B,D]
44
+ beta: [B]
45
+ e_rand: [B,D]
46
+ """
47
+ alpha_bar = self.var_sched.alpha_bars[t]
48
+ beta = self.var_sched.betas[t]
49
+
50
+ c0 = torch.sqrt(alpha_bar).view(-1, 1) # [B,1]
51
+ c1 = torch.sqrt(1 - alpha_bar).view(-1, 1)
52
+
53
+ e_rand = torch.randn_like(x)
54
+ if e_rand.dim() == 3:
55
+ c0 = c0.unsqueeze(1)
56
+ c1 = c1.unsqueeze(1)
57
+
58
+ x_noisy = c0 * x + c1 * e_rand
59
+
60
+ return x_noisy, beta, e_rand
61
+
62
+ def get_loss(
63
+ self,
64
+ x0,
65
+ t=None,
66
+ noisy_in=False,
67
+ beta_in=None,
68
+ e_rand_in=None,
69
+ ):
70
+ if x0.dim() == 2:
71
+ B, D = x0.shape
72
+ else:
73
+ B, G, D = x0.shape
74
+ if not noisy_in:
75
+ if t is None:
76
+ t = self.var_sched.uniform_sample_t(B)
77
+ x_noisy, beta, e_rand = self.add_noise(x0, t)
78
+ else:
79
+ x_noisy = x0
80
+ beta = beta_in
81
+ e_rand = e_rand_in
82
+
83
+ e_theta = self.net(x_noisy, beta=beta)
84
+ loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
85
+ return loss
86
+
87
+ @torch.no_grad()
88
+ def sample(
89
+ self,
90
+ batch_size=0,
91
+ return_traj=False,
92
+ ):
93
+ raise NotImplementedError
94
+
95
+ def validation_epoch_end(self, outputs):
96
+ if self.hparams.no_run_validation:
97
+ return
98
+ if not self.trainer.sanity_checking:
99
+ if (self.current_epoch) % self.hparams.validation_step == 0:
100
+ self.validation()
101
+
102
+ def _build_dataset(self, stage):
103
+ if hasattr(self, f"data_{stage}"):
104
+ return getattr(self, f"data_{stage}")
105
+ if stage == "train":
106
+ ds = SALADDataset(**self.hparams.dataset_kwargs)
107
+ else:
108
+ dataset_kwargs = self.hparams.dataset_kwargs.copy()
109
+ dataset_kwargs["repeat"] = 1
110
+ ds = SALADDataset(**dataset_kwargs)
111
+ setattr(self, f"data_{stage}", ds)
112
+ return ds
113
+
114
+ def _build_dataloader(self, stage):
115
+ try:
116
+ ds = getattr(self, f"data_{stage}")
117
+ except:
118
+ ds = self._build_dataset(stage)
119
+
120
+ return torch.utils.data.DataLoader(
121
+ ds,
122
+ batch_size=self.hparams.batch_size,
123
+ shuffle=stage == "train",
124
+ drop_last=stage == "train",
125
+ num_workers=4,
126
+ )
127
+
128
+ def train_dataloader(self):
129
+ return self._build_dataloader("train")
130
+
131
+ def val_dataloader(self):
132
+ return self._build_dataloader("val")
133
+
134
+ def test_dataloader(self):
135
+ return self._build_dataloader("test")
136
+
137
+ def configure_optimizers(self):
138
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
139
+ scheduler = PolyDecayScheduler(optimizer, self.hparams.lr, power=0.999)
140
+ return [optimizer], [scheduler]
141
+
142
+ #TODO move get_wandb_logger to logutil.py
143
+ def get_wandb_logger(self):
144
+ for logger in self.logger:
145
+ if isinstance(logger, pl.loggers.wandb.WandbLogger):
146
+ return logger
147
+ return None
salad/models/language_phase1.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from transformers import BertModel, BertTokenizer
5
+
6
+ from salad.model_components.lstm import LSTM
7
+ from salad.models.phase1 import Phase1Model
8
+ from salad.utils import imageutil, nputil, visutil
9
+ from salad.utils.spaghetti_util import (clip_eigenvalues,
10
+ generate_zc_from_sj_gaus,
11
+ get_mesh_from_spaghetti, load_mesher,
12
+ load_spaghetti, project_eigenvectors)
13
+ from salad.utils.train_util import get_dropout_mask
14
+ from salad.data.dataset import LangSALADDataset
15
+
16
+
17
+ class LangPhase1Model(Phase1Model):
18
+ def __init__(self, network, variance_schedule, **kwargs):
19
+ super().__init__(network, variance_schedule, **kwargs)
20
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
21
+ if self.hparams.get("use_lstm"):
22
+ self.bertmodel = LSTM(
23
+ text_dim=768, embedding_dim=768, vocab_size=30522, padding_idx=0
24
+ )
25
+ else:
26
+ self.bertmodel = BertModel.from_pretrained("bert-base-uncased")
27
+ if self.hparams.get("text_encoder_freeze"):
28
+ for p in self.bertmodel.parameters():
29
+ p.requires_grad_(False)
30
+
31
+ def forward(self, x, text):
32
+ """
33
+ Input:
34
+ x: [B,G,16]
35
+ text: list of length [B]
36
+ """
37
+ B, G = x.shape[:2]
38
+ text = self.random_mask_text(text)
39
+ lang_emb = self.text_to_embedding(text)
40
+ return self.get_loss(x, lang_emb)
41
+
42
+ def tokenizing(self, text):
43
+ tokenized = self.tokenizer(
44
+ text, return_tensors="pt", padding=True, truncation=True
45
+ ).to(self.device)
46
+ return tokenized
47
+
48
+ def text_to_embedding(self, text):
49
+ """
50
+ text: list of length [B]
51
+ return [B,768]
52
+ """
53
+ tokenized = self.tokenizing(text)
54
+ if self.hparams.get("use_lstm"):
55
+ lang_emb, _ = self.bertmodel(tokenized.input_ids)
56
+ else:
57
+ if self.hparams.get("text_encoder_return_seq"):
58
+ lang_emb = self.bertmodel(**tokenized).last_hidden_state
59
+ else:
60
+ lang_emb = self.bertmodel(**tokenized).pooler_output
61
+ if lang_emb.ndim == 2:
62
+ lang_emb = lang_emb.unsqueeze(1)
63
+ return lang_emb
64
+
65
+ def random_mask_text(self, text):
66
+ text = list(text)
67
+ B = len(text)
68
+ if self.hparams.get("classifier_free_guidance"):
69
+ random_dp_mask = get_dropout_mask(
70
+ B, self.hparams.conditioning_dropout_prob, self.device
71
+ )
72
+ for i in range(B):
73
+ if random_dp_mask[i] == 0:
74
+ text[i] = ""
75
+ return text
76
+
77
+ def get_loss(self, x0, cond, t=None, noisy_in=False, beta_in=None, e_rand_in=None):
78
+ B, G, D = x0.shape
79
+
80
+ if not noisy_in:
81
+ if t is None:
82
+ t = self.var_sched.uniform_sample_t(B)
83
+ x_noisy, beta, e_rand = self.add_noise(x0, t)
84
+ else:
85
+ x_noisy = x0
86
+ beta = beta_in
87
+ e_rand = e_rand_in
88
+
89
+ e_theta = self.net(x_noisy, beta, cond)
90
+ loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
91
+ return loss
92
+
93
+ def step(self, batch, stage: str):
94
+ x, text = batch
95
+ loss = self(x, text)
96
+ self.log(f"{stage}/loss", loss, on_step=stage == "train", prog_bar=True)
97
+ return loss
98
+
99
+ @torch.no_grad()
100
+ def sample(
101
+ self,
102
+ num_samples_or_text,
103
+ return_traj=False,
104
+ return_cond=False,
105
+ classifier_free_guidance=True,
106
+ free_guidance_weight=2.0,
107
+ ):
108
+ if isinstance(num_samples_or_text, str):
109
+ num_samples_or_text = [num_samples_or_text]
110
+ if isinstance(num_samples_or_text, int):
111
+ batch_size = num_samples_or_text
112
+ ds = self._build_dataset("val")
113
+ texts = [ds[i][1] for i in range(batch_size)]
114
+ elif isinstance(num_samples_or_text, list):
115
+ texts = num_samples_or_text
116
+ batch_size = len(num_samples_or_text)
117
+ if self.hparams.get("use_zc"):
118
+ x_T = torch.randn([batch_size, 16, 512]).to(self.device)
119
+ else:
120
+ x_T = torch.randn([batch_size, 16, 16]).to(self.device)
121
+ G = x_T.shape[1]
122
+ lang_emb = self.text_to_embedding(texts)
123
+
124
+ if classifier_free_guidance:
125
+ null_texts = ["" for _ in range(batch_size)]
126
+ null_lang_emb = self.text_to_embedding(null_texts)
127
+
128
+ traj = {self.var_sched.num_steps: x_T}
129
+ for t in range(self.var_sched.num_steps, 0, -1):
130
+ z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
131
+ alpha = self.var_sched.alphas[t]
132
+ alpha_bar = self.var_sched.alpha_bars[t]
133
+ sigma = self.var_sched.get_sigmas(t, flexibility=0)
134
+
135
+ c0 = 1.0 / torch.sqrt(alpha)
136
+ c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
137
+
138
+ x_t = traj[t]
139
+
140
+ beta = self.var_sched.betas[[t] * batch_size]
141
+ e_theta = self.net(x_t, beta=beta, context=lang_emb)
142
+
143
+ if classifier_free_guidance:
144
+ null_e_theta = self.net(x_t, beta=beta, context=null_lang_emb)
145
+ w = free_guidance_weight
146
+ e_theta = (1 + w) * e_theta - w * null_e_theta
147
+
148
+ x_next = c0 * (x_t - c1 * e_theta) + sigma * z
149
+ traj[t - 1] = x_next.detach()
150
+
151
+ traj[t] = traj[t].cpu()
152
+
153
+ if not return_traj:
154
+ del traj[t]
155
+
156
+ if return_traj:
157
+ if return_cond:
158
+ return traj, lang_emb
159
+ return traj
160
+ else:
161
+ if return_cond:
162
+ return traj[0], lang_emb
163
+ return traj[0]
164
+
165
+ def sampling_gaussians(
166
+ self,
167
+ num_samples_or_text,
168
+ classifier_free_guidance=True,
169
+ free_guidance_weight=2.0,
170
+ return_cond=False,
171
+ ):
172
+ gaus = self.sample(
173
+ num_samples_or_text,
174
+ classifier_free_guidance=classifier_free_guidance,
175
+ free_guidance_weight=free_guidance_weight,
176
+ return_cond=return_cond,
177
+ )
178
+ if isinstance(gaus, tuple):
179
+ text = gaus[1]
180
+ gaus = gaus[0]
181
+ # gaus = reflect_and_concat_gmms(raw_gaus)
182
+ if self.hparams.get("global_normalization"):
183
+ if not hasattr(self, "data_val"):
184
+ self._build_dataset("val")
185
+ if self.hparams.get("global_normalization") == "partial":
186
+ gaus = self.data_val.unnormalize_global_static(gaus, slice(12, None))
187
+ elif self.hparams.get("global_normalization") == "all":
188
+ gaus = self.data_val.unnormalize_global_static(gaus, slice(None))
189
+
190
+ gaus = project_eigenvectors(clip_eigenvalues(gaus))
191
+ if return_cond:
192
+ return gaus, text
193
+ return gaus
194
+
195
+ def _build_dataset(self, stage):
196
+ if hasattr(self, f"data_{stage}"):
197
+ return getattr(self, f"data_{stage}")
198
+
199
+ ds_class = (
200
+ LangSALADDataset
201
+ )
202
+ if stage == "train":
203
+ ds = ds_class(**self.hparams.dataset_kwargs)
204
+ else:
205
+ dataset_kwargs = self.hparams.dataset_kwargs.copy()
206
+ dataset_kwargs["repeat"] = 1
207
+ ds = ds_class(**dataset_kwargs)
208
+ setattr(self, f"data_{stage}", ds)
209
+ return ds
210
+
211
+ def validation_zc(self):
212
+ vis_num_shapes = 4
213
+ vis_zcs = []
214
+ vis_texts = []
215
+ ds = self._build_dataset("val")
216
+ for i in [0, 1, 2, 3]:
217
+ zcs, text = ds[i]
218
+ vis_zcs.append(zcs)
219
+ vis_texts.append(text)
220
+ vis_zcs = torch.stack(vis_zcs, 0)
221
+ ldm_zcs = self.sample(vis_texts)
222
+
223
+ if not hasattr(self, "spaghetti"):
224
+ self.spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag)
225
+ spaghetti = self.spaghetti
226
+
227
+ if not hasattr(self, "mesher"):
228
+ self.mesher = load_mesher(self.device)
229
+ mesher = self.mesher
230
+
231
+ wandb_logger = self.get_wandb_logger()
232
+ images = []
233
+ for i in range(vis_num_shapes):
234
+ try:
235
+ v, f = get_mesh_from_spaghetti(spaghetti, mesher, vis_zcs[i], res=128)
236
+ gt_img = visutil.render_mesh(v, f, resolution=(256, 256))
237
+ except:
238
+ pass
239
+ try:
240
+ v, f = get_mesh_from_spaghetti(spaghetti, mesher, ldm_zcs[i], res=128)
241
+ pred_img = visutil.render_mesh(v, f, resolution=(256, 256))
242
+ except:
243
+ pass
244
+
245
+ img = imageutil.merge_images([gt_img, pred_img])
246
+ img = imageutil.draw_text(
247
+ img,
248
+ f"Left: GT | Right: Pred \n{vis_texts[i]}",
249
+ font_size=14,
250
+ max_seq_length=50,
251
+ )
252
+ images.append([img])
253
+
254
+ images = imageutil.merge_images(images)
255
+ wandb_logger.log_image("vis", [images])
256
+
257
+ def validation(self):
258
+ if self.hparams.get("use_zc"):
259
+ self.validation_zc()
260
+ return
261
+
262
+ vis_num_shapes = 4
263
+ vis_gaus = []
264
+ vis_texts = []
265
+ ds = self._build_dataset("val")
266
+ vis_indices = [18453, 13036, 13204, 48244]
267
+ for i in vis_indices:
268
+ gaus, text = ds[i]
269
+ vis_gaus.append(gaus)
270
+ vis_texts.append(text)
271
+
272
+ vis_gaus = torch.stack(vis_gaus, 0)
273
+ if self.hparams.get("global_normalization"):
274
+ if self.hparams.get("global_normalization") == "partial":
275
+ vis_gaus = self.data_val.unnormalize_global_static(
276
+ vis_gaus, slice(12, None)
277
+ )
278
+ elif self.hparams.get("global_normalization") == "all":
279
+ vis_gaus = self.dataval.unnormalize_global_static(vis_gaus, slice(None))
280
+
281
+ # vis_gaus = reflect_and_concat_gmms(vis_gaus)
282
+ pred_gaus = self.sampling_gaussians(vis_texts)
283
+
284
+ if not hasattr(self, "spaghetti"):
285
+ self.spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag)
286
+ spaghetti = self.spaghetti
287
+
288
+ if not hasattr(self, "mesher"):
289
+ self.mesher = load_mesher(self.device)
290
+ mesher = self.mesher
291
+
292
+ """ get intrinsics """
293
+ # TODO change the ckpt path.
294
+ if not hasattr(self, "phase2_model"):
295
+ phase2_ckpt = "/home/juil/pvddir/results/phase2/augment_final_0214/0214_202607/checkpoints/epoch=4999-val_loss=0.0000.ckpt"
296
+ self.phase2_model = SpaghettiConditionSALDM.load_from_checkpoint(
297
+ phase2_ckpt, strict=False
298
+ ).to(self.device)
299
+ self.phase2_model.eval()
300
+ for p in self.phase2_model.parameters():
301
+ p.requires_grad_(False)
302
+
303
+ phase2_model = self.phase2_model
304
+
305
+ gt_sj = phase2_model.sample(vis_gaus)
306
+ pred_sj = phase2_model.sample(pred_gaus)
307
+
308
+ gt_zcs = generate_zc_from_sj_gaus(spaghetti, gt_sj, vis_gaus)
309
+ pred_zcs = generate_zc_from_sj_gaus(spaghetti, pred_sj, pred_gaus)
310
+
311
+ wandb_logger = self.get_wandb_logger()
312
+ images = []
313
+ for i in range(vis_num_shapes):
314
+ gt_img = visutil.render_gaussians(vis_gaus[i], resolution=(256, 256))
315
+ try:
316
+ v, f = get_mesh_from_spaghetti(spaghetti, mesher, gt_zcs[i], res=128)
317
+ gt_mesh_img = visutil.render_mesh(v, f, resolution=(256, 256))
318
+ gt_img = imageutil.merge_images([gt_img, gt_mesh_img])
319
+ except:
320
+ pass
321
+
322
+ pred_img = visutil.render_gaussians(pred_gaus[i], resolution=(256, 256))
323
+ try:
324
+ v, f = get_mesh_from_spaghetti(spaghetti, mesher, pred_zcs[i], res=128)
325
+ pred_mesh_img = visutil.render_mesh(v, f, resolution=(256, 256))
326
+ pred_img = imageutil.merge_images([pred_img, pred_mesh_img])
327
+ except:
328
+ pass
329
+
330
+ img = imageutil.merge_images([gt_img, pred_img])
331
+ img = imageutil.draw_text(
332
+ img,
333
+ f"Left: GT | Right: Pred \n{vis_texts[i]}",
334
+ font_size=14,
335
+ max_seq_length=50,
336
+ )
337
+ images.append([img])
338
+
339
+ images = imageutil.merge_images(images)
340
+ wandb_logger.log_image("vis", [images])
salad/models/language_phase2.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from transformers import BertModel, BertTokenizer
5
+
6
+ from salad.model_components.lstm import LSTM
7
+ from salad.models.language_phase1 import LangPhase1Model
8
+ from salad.utils import imageutil, nputil, visutil
9
+ from salad.utils.spaghetti_util import (generate_zc_from_sj_gaus,
10
+ get_mesh_from_spaghetti, load_mesher,
11
+ load_spaghetti)
12
+ from salad.utils.train_util import get_dropout_mask
13
+
14
+
15
+ class LangPhase2Model(LangPhase1Model):
16
+ def __init__(self, network, variance_schedule, **kwargs):
17
+ super().__init__(network, variance_schedule, **kwargs)
18
+
19
+ def random_mask_gaus_text(self, gaus, text):
20
+ if self.hparams.get("classifier_free_guidance"):
21
+ text = list(text)
22
+ B = gaus.shape[0]
23
+ random_dp_mask = get_dropout_mask(
24
+ B, self.hparams.conditioning_dropout_prob, self.device
25
+ )
26
+ gaus = gaus * random_dp_mask.unsqueeze(1).unsqueeze(2)
27
+ for i in range(B):
28
+ if random_dp_mask[i] == 0:
29
+ text[i] = ""
30
+
31
+ return gaus, text
32
+
33
+ def forward(self, x, gaus, text):
34
+ """
35
+ Input:
36
+ x: [B,G,512]
37
+ gaus: [B,G,16]
38
+ text: list of [B]
39
+ """
40
+ B, G = x.shape[:2]
41
+ gaus, text = self.random_mask_gaus_text(gaus, text)
42
+ lang_emb = self.text_to_embedding(text)
43
+ cond = self.cond_from_gaus_lang_f(gaus, lang_emb)
44
+
45
+ return self.get_loss(x, cond)
46
+
47
+ def step(self, batch, stage):
48
+ x, gaus, text = batch
49
+ loss = self(x, gaus, text)
50
+ self.log(f"{stage}/loss", loss, on_step=stage == "train", prog_bar=True)
51
+ return loss
52
+
53
+ def get_loss(self, x0, cond, t=None, noisy_in=False, beta_in=None, e_rand_in=None):
54
+ B, G, D = x0.shape
55
+ if not noisy_in:
56
+ if t is None:
57
+ t = self.var_sched.uniform_sample_t(B)
58
+ x_noisy, beta, e_rand = self.add_noise(x0, t)
59
+ else:
60
+ x_noisy = x0
61
+ beta = beta_in
62
+ e_rand = e_rand_in
63
+ e_theta = self.net(x_noisy, beta, cond)
64
+ loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
65
+ return loss
66
+
67
+ def cond_from_gaus_lang_f(self, gaus, lang_f):
68
+ gaus = nputil.np2th(gaus).to(self.device)
69
+ G = gaus.shape[1]
70
+ lang_f = nputil.np2th(lang_f).to(self.device)
71
+ assert gaus.ndim == 3
72
+ if lang_f.ndim == 2:
73
+ lang_f = lang_f.unsqueeze(1)
74
+ lang_f = lang_f.expand(-1, G, -1)
75
+ return torch.cat([gaus, lang_f], -1)
76
+
77
+ def generate_null_cond(self, B, G):
78
+ text = ["" for _ in range(B)]
79
+ lang_emb = self.text_to_embedding(text)
80
+ gaus = torch.zeros(B, G, 16, dtype=torch.float, device=self.device)
81
+ return self.cond_from_gaus_lang_f(gaus, lang_emb)
82
+
83
+ @torch.no_grad()
84
+ def sample(
85
+ self,
86
+ num_samples_or_cond,
87
+ return_traj=False,
88
+ return_cond=False,
89
+ classifier_free_guidance=False,
90
+ free_guidance_weight=0.7,
91
+ ):
92
+
93
+ if isinstance(num_samples_or_cond, int):
94
+ batch_size = num_samples_or_cond
95
+ ds = self._build_dataset("val")
96
+ batch_gaus = []
97
+ batch_text = []
98
+ for i in range(batch_size):
99
+ _, gaus, text = ds[i]
100
+ batch_gaus.append(gaus)
101
+ batch_text.append(text)
102
+
103
+ batch_gaus = torch.stack(batch_gaus, 0)
104
+ lang_emb = self.text_to_embedding(batch_text)
105
+ cond = self.cond_from_gaus_lang_f(batch_gaus, lang_emb).to(self.device)
106
+
107
+ elif isinstance(num_samples_or_cond, np.ndarray) or isinstance(
108
+ num_samples_or_cond, torch.Tensor
109
+ ):
110
+ cond = nputil.np2th(num_samples_or_cond).to(self.device)
111
+ batch_size = len(cond)
112
+
113
+ G = cond.shape[1]
114
+ if classifier_free_guidance:
115
+ null_cond = self.generate_null_cond(batch_size, G)
116
+
117
+ x_T = torch.randn([batch_size, 16, 512]).to(self.device)
118
+ traj = {self.var_sched.num_steps: x_T}
119
+ for t in range(self.var_sched.num_steps, 0, -1):
120
+ z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
121
+ alpha = self.var_sched.alphas[t]
122
+ alpha_bar = self.var_sched.alpha_bars[t]
123
+ sigma = self.var_sched.get_sigmas(t, flexibility=0)
124
+
125
+ c0 = 1.0 / torch.sqrt(alpha)
126
+ c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
127
+
128
+ x_t = traj[t]
129
+
130
+ beta = self.var_sched.betas[[t] * batch_size]
131
+ e_theta = self.net(x_t, beta=beta, context=cond)
132
+
133
+ if classifier_free_guidance:
134
+ null_e_theta = self.net(x_t, beta=beta, context=null_cond)
135
+ w = free_guidance_weight
136
+ e_theta = (1 + w) * e_theta - w * null_e_theta
137
+
138
+ x_next = c0 * (x_t - c1 * e_theta) + sigma * z
139
+ traj[t - 1] = x_next.detach()
140
+
141
+ traj[t] = traj[t].cpu()
142
+
143
+ if not return_traj:
144
+ del traj[t]
145
+
146
+ if return_traj:
147
+ if return_cond:
148
+ return traj, cond
149
+ return traj
150
+ else:
151
+ if return_cond:
152
+ return traj[0], cond
153
+ return traj[0]
154
+
155
+ def validation(self):
156
+ vis_num_shapes = 4
157
+ vis_gt_sj = []
158
+ vis_gaus = []
159
+ vis_texts = []
160
+ ds = self._build_dataset("val")
161
+ vis_indices = [18453, 13036, 13204, 48244]
162
+ for i in vis_indices:
163
+ sj, gaus, text = ds[i]
164
+ vis_gt_sj.append(sj)
165
+ vis_gaus.append(gaus)
166
+ vis_texts.append(text)
167
+
168
+ vis_gt_sj = torch.stack(vis_gt_sj, 0)
169
+ vis_gaus = torch.stack(vis_gaus, 0).to(self.device)
170
+ vis_lang_f = self.text_to_embedding(vis_texts)
171
+ vis_cond = self.cond_from_gaus_lang_f(vis_gaus, vis_lang_f)
172
+ pred_sj = self.sample(vis_cond)
173
+
174
+ if not hasattr(self, "spaghetti"):
175
+ self.spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag)
176
+ spaghetti = self.spaghetti
177
+
178
+ if not hasattr(self, "mesher"):
179
+ self.mesher = load_mesher(self.device)
180
+ mesher = self.mesher
181
+
182
+ gt_zcs = generate_zc_from_sj_gaus(spaghetti, vis_gt_sj, vis_gaus)
183
+ pred_zcs = generate_zc_from_sj_gaus(spaghetti, pred_sj, vis_gaus)
184
+
185
+ wandb_logger = self.get_wandb_logger()
186
+ for i in range(vis_num_shapes):
187
+ gaus_img = visutil.render_gaussians(vis_gaus[i], resolution=(256, 256))
188
+ vert, face = get_mesh_from_spaghetti(spaghetti, mesher, gt_zcs[i], res=128)
189
+ gt_mesh_img = visutil.render_mesh(vert, face, resolution=(256, 256))
190
+ img = [gaus_img, gt_mesh_img]
191
+ try:
192
+ vert, face = get_mesh_from_spaghetti(spaghetti, mesher, pred_zcs[i])
193
+ pred_mesh_img = visutil.render_mesh(vert, face, resolution=(256, 256))
194
+ img.append(pred_mesh_img)
195
+ except Exception as e:
196
+ print(e)
197
+ img = imageutil.merge_images(img)
198
+ img = imageutil.draw_text(
199
+ img, vis_texts[i], font_size=14, max_seq_length=50
200
+ )
201
+ wandb_logger.log_image("vis", [img])
salad/models/phase1.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from salad.models.base_model import BaseModel
4
+ from salad.utils import nputil, thutil
5
+ from salad.utils.spaghetti_util import clip_eigenvalues, project_eigenvectors
6
+
7
+ class Phase1Model(BaseModel):
8
+ def __init__(self, network, variance_schedule, **kwargs):
9
+ super().__init__(network, variance_schedule, **kwargs)
10
+
11
+ @torch.no_grad()
12
+ def sample(
13
+ self,
14
+ batch_size=0,
15
+ return_traj=False,
16
+ ):
17
+ x_T = torch.randn([batch_size, 16, 16]).to(self.device)
18
+
19
+ traj = {self.var_sched.num_steps: x_T}
20
+ for t in range(self.var_sched.num_steps, 0, -1):
21
+ z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
22
+ alpha = self.var_sched.alphas[t]
23
+ alpha_bar = self.var_sched.alpha_bars[t]
24
+ sigma = self.var_sched.get_sigmas(t, flexibility=0)
25
+
26
+ c0 = 1.0 / torch.sqrt(alpha)
27
+ c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
28
+
29
+ x_t = traj[t]
30
+
31
+ beta = self.var_sched.betas[[t] * batch_size]
32
+ e_theta = self.net(x_t, beta=beta)
33
+ # print(e_theta.norm(-1).mean())
34
+
35
+ x_next = c0 * (x_t - c1 * e_theta) + sigma * z
36
+ traj[t - 1] = x_next.detach()
37
+
38
+ traj[t] = traj[t].cpu()
39
+
40
+ if not return_traj:
41
+ del traj[t]
42
+ if return_traj:
43
+ return traj
44
+ else:
45
+ return traj[0]
46
+
47
+ def sampling_gaussians(self, num_shapes):
48
+ """
49
+ Return:
50
+ ldm_gaus: np.ndarray
51
+ gt_gaus: np.ndarray
52
+ """
53
+ ldm_gaus = self.sample(num_shapes)
54
+
55
+ if self.hparams.get("global_normalization"):
56
+ if not hasattr(self, "data_val"):
57
+ self._build_dataset("val")
58
+ if self.hparams.get("global_normalization") == "partial":
59
+ ldm_gaus = self.data_val.unnormalize_global_static(ldm_gaus, slice(12,None))
60
+ elif self.hparams.get("global_normalization") == "all":
61
+ ldm_gaus = self.data_val.unnormalize_global_static(ldm_gaus, slice(None))
62
+
63
+ ldm_gaus = clip_eigenvalues(ldm_gaus)
64
+ ldm_gaus = project_eigenvectors(ldm_gaus)
65
+ return ldm_gaus
salad/models/phase2.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from salad.models.base_model import BaseModel
8
+ from salad.utils import imageutil, nputil, sysutil, thutil, visutil
9
+ from salad.utils.spaghetti_util import (clip_eigenvalues,
10
+ generate_zc_from_sj_gaus,
11
+ get_mesh_from_spaghetti, load_mesher,
12
+ load_spaghetti, project_eigenvectors)
13
+
14
+
15
+ class Phase2Model(BaseModel):
16
+ def __init__(self, network, variance_schedule, **kwargs):
17
+ super().__init__(network, variance_schedule, **kwargs)
18
+
19
+ def forward(self, x, cond):
20
+ return self.get_loss(x, cond)
21
+
22
+ def step(self, batch, stage: str):
23
+ x, cond = batch
24
+ loss = self(x, cond)
25
+ self.log(f"{stage}/loss", loss, on_step=stage == "train", prog_bar=True)
26
+ return loss
27
+
28
+ def get_loss(self, x0, cond, t=None, noisy_in=False, beta_in=None, e_rand_in=None):
29
+ B, G, D = x0.shape
30
+
31
+ if not noisy_in:
32
+ if t is None:
33
+ t = self.var_sched.uniform_sample_t(B)
34
+ x_noisy, beta, e_rand = self.add_noise(x0, t)
35
+ else:
36
+ x_noisy = x0
37
+ beta = beta_in
38
+ e_rand = e_rand_in
39
+
40
+ e_theta = self.net(x_noisy, beta, cond)
41
+ loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
42
+ return loss
43
+
44
+ @torch.no_grad()
45
+ def sample(
46
+ self,
47
+ num_samples_or_gaus: Union[torch.Tensor, np.ndarray, int],
48
+ return_traj=False,
49
+ classifier_free_guidance=None,
50
+ free_guidance_weight=-0.7,
51
+ augment_condition_in_test=False,
52
+ return_cond=False,
53
+ ):
54
+ if isinstance(num_samples_or_gaus, int):
55
+ batch_size = num_samples_or_gaus
56
+ ds = self._build_dataset("val")
57
+ cond = torch.stack([ds[i][1] for i in range(batch_size)], 0)
58
+
59
+ elif isinstance(num_samples_or_gaus, np.ndarray) or isinstance(
60
+ num_samples_or_gaus, torch.Tensor
61
+ ):
62
+ cond = nputil.np2th(num_samples_or_gaus)
63
+ if cond.dim() == 2:
64
+ cond = cond[None]
65
+ batch_size = len(cond)
66
+ else:
67
+ raise ValueError(
68
+ "'num_samples_or_gaus' should be int, torch.Tensor or np.ndarray."
69
+ )
70
+
71
+ x_T = torch.randn([batch_size, 16, 512]).to(self.device)
72
+ cond = cond.to(self.device)
73
+
74
+ traj = {self.var_sched.num_steps: x_T}
75
+ for t in range(self.var_sched.num_steps, 0, -1):
76
+ z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
77
+ alpha = self.var_sched.alphas[t]
78
+ alpha_bar = self.var_sched.alpha_bars[t]
79
+ sigma = self.var_sched.get_sigmas(t, flexibility=0)
80
+
81
+ c0 = 1.0 / torch.sqrt(alpha)
82
+ c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
83
+
84
+ x_t = traj[t]
85
+
86
+ beta = self.var_sched.betas[[t] * batch_size]
87
+ e_theta = self.net(x_t, beta=beta, context=cond)
88
+
89
+ x_next = c0 * (x_t - c1 * e_theta) + sigma * z
90
+ traj[t - 1] = x_next.detach()
91
+
92
+ traj[t] = traj[t].cpu()
93
+
94
+ if not return_traj:
95
+ del traj[t]
96
+
97
+ if return_traj:
98
+ if return_cond:
99
+ return traj, cond
100
+ return traj
101
+ else:
102
+ if return_cond:
103
+ return traj[0], cond
104
+ return traj[0]
105
+
106
+ def validation(self):
107
+ latent_ds = self._build_dataset("val")
108
+ vis_num_shapes = 3
109
+ num_variations = 3
110
+ sysutil.clean_gpu()
111
+
112
+ if not hasattr(self, "spaghetti"):
113
+ spaghetti = load_spaghetti(
114
+ self.device,
115
+ self.hparams.spaghetti_tag
116
+ if self.hparams.get("spaghetti_tag")
117
+ else "chairs_large",
118
+ )
119
+ self.spaghetti = spaghetti
120
+ else:
121
+ spaghetti = self.spaghetti
122
+
123
+ if not hasattr(self, "mesher"):
124
+ mesher = load_mesher(self.device)
125
+ self.mesher = mesher
126
+ else:
127
+ mesher = self.mesher
128
+
129
+ """======== Sampling ========"""
130
+ gt_zs = []
131
+ gt_gaus = []
132
+
133
+ gt_zs, gt_gaus = zip(*[latent_ds[i + 3] for i in range(vis_num_shapes)])
134
+ gt_zs, gt_gaus = list(map(lambda x: torch.stack(x), [gt_zs, gt_gaus]))
135
+ if self.hparams.get("sj_global_normalization"):
136
+ gt_zs = thutil.th2np(gt_zs)
137
+ gt_zs = latent_ds.unnormalize_sj_global_static(gt_zs)
138
+ gt_zs = nputil.np2th(gt_zs).to(self.device)
139
+
140
+ gt_gaus_repeated = gt_gaus.repeat_interleave(num_variations, 0)
141
+ clean_ldm_zs, clean_gaus = self.sample(gt_gaus_repeated, return_cond=True)
142
+ clean_gaus = project_eigenvectors(clip_eigenvalues(clean_gaus))
143
+ clean_zcs = generate_zc_from_sj_gaus(spaghetti, clean_ldm_zs, clean_gaus)
144
+ gt_zcs = generate_zc_from_sj_gaus(spaghetti, gt_zs, gt_gaus)
145
+ sysutil.clean_gpu()
146
+
147
+ """=========================="""
148
+
149
+ """ Spaghetti Decoding """
150
+ wandb_logger = self.get_wandb_logger()
151
+ resolution = (256, 256)
152
+ for i in range(vis_num_shapes):
153
+ img_per_shape = []
154
+ gaus_img = visutil.render_gaussians(gt_gaus[i], resolution=resolution)
155
+ vert, face = get_mesh_from_spaghetti(spaghetti, mesher, gt_zcs[i], res=128)
156
+ gt_mesh_img = visutil.render_mesh(vert, face, resolution=resolution)
157
+ gt_img = imageutil.merge_images([gaus_img, gt_mesh_img])
158
+ gt_img = imageutil.draw_text(gt_img, "GT", font_size=24)
159
+ img_per_shape.append(gt_img)
160
+ for j in range(num_variations):
161
+ try:
162
+ gaus_img = visutil.render_gaussians(
163
+ clean_gaus[i * num_variations + j], resolution=resolution
164
+ )
165
+ vert, face = get_mesh_from_spaghetti(
166
+ spaghetti, mesher, clean_zcs[i * num_variations + j], res=128
167
+ )
168
+ mesh_img = visutil.render_mesh(vert, face, resolution=resolution)
169
+ pred_img = imageutil.merge_images([gaus_img, mesh_img])
170
+ pred_img = imageutil.draw_text(
171
+ pred_img, f"{j}-th clean gaus", font_size=24
172
+ )
173
+ img_per_shape.append(pred_img)
174
+ except Exception as e:
175
+ print(e)
176
+
177
+ try:
178
+ image = imageutil.merge_images(img_per_shape)
179
+ wandb_logger.log_image("visualization", [image])
180
+ except Exception as e:
181
+ print(e)
182
+
183
+ """ ================== """
salad/spaghetti/.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ /assets/*
2
+ !/assets/readme_resources/
3
+ !/assets/ui_resources/
4
+ !/assets/splits/
5
+ !/assets/mesh/
6
+ *.vtk
7
+ .idea/
8
+ __pycache__/
9
+ **_ig_**