Spaces commit ad7bc89 (parent: 2954c0e): "first commit"

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- Dockerfile +64 -0
- README.md +4 -4
- app.py +395 -0
- assets/car0_mesh_centered_flipped.obj +0 -0
- assets/chair191_mesh_centered_flipped.obj +0 -0
- assets/motorcycle12_mesh_centered_flipped.obj +0 -0
- assets/motorcycle29_mesh_centered_flipped.obj +0 -0
- assets/plane.obj +14 -0
- assets/teddybear0_mesh_centered_flipped.obj +0 -0
- assets/teddybear31_mesh_centered_flipped.obj +0 -0
- configs/train_co3d_concept.yaml +198 -0
- pretrained_models/car0/checkpoints/step=000001600.ckpt +3 -0
- pretrained_models/car0/configs/2024-04-12T21-30-20-lightning.yaml +31 -0
- pretrained_models/car0/configs/2024-04-12T21-30-20-project.yaml +168 -0
- pretrained_models/car0/configs/2024-04-13T11-31-55-lightning.yaml +31 -0
- pretrained_models/car0/configs/2024-04-13T11-31-55-project.yaml +170 -0
- pretrained_models/car0/configs/2024-04-13T11-42-30-lightning.yaml +31 -0
- pretrained_models/car0/configs/2024-04-13T11-42-30-project.yaml +170 -0
- pretrained_models/chair191/checkpoints/step=000001600.ckpt +3 -0
- pretrained_models/chair191/configs/2024-04-12T22-10-18-lightning.yaml +31 -0
- pretrained_models/chair191/configs/2024-04-12T22-10-18-project.yaml +168 -0
- pretrained_models/motorcycle12/checkpoints/step=000001600.ckpt +3 -0
- pretrained_models/motorcycle12/configs/2024-04-12T23-30-18-project.yaml +168 -0
- pretrained_models/teddybear31/checkpoints/step=000001600.ckpt +3 -0
- pretrained_models/teddybear31/configs/2024-04-12T22-50-24-lightning.yaml +31 -0
- pretrained_models/teddybear31/configs/2024-04-12T22-50-24-project.yaml +168 -0
- requirements.txt +37 -0
- sampling_for_demo.py +487 -0
- scripts.js +147 -0
- sgm/__init__.py +4 -0
- sgm/data/__init__.py +1 -0
- sgm/data/data_co3d.py +762 -0
- sgm/lr_scheduler.py +135 -0
- sgm/models/__init__.py +2 -0
- sgm/models/autoencoder.py +335 -0
- sgm/models/diffusion.py +556 -0
- sgm/modules/__init__.py +6 -0
- sgm/modules/attention.py +1202 -0
- sgm/modules/autoencoding/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/loss.py +0 -0
- sgm/modules/autoencoding/lpips/loss/LICENSE +23 -0
- sgm/modules/autoencoding/lpips/loss/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/loss/lpips.py +147 -0
- sgm/modules/autoencoding/lpips/model/LICENSE +58 -0
- sgm/modules/autoencoding/lpips/model/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/model/model.py +88 -0
- sgm/modules/autoencoding/lpips/util.py +128 -0
- sgm/modules/autoencoding/lpips/vqperceptual.py +17 -0
- sgm/modules/autoencoding/regularizers/__init__.py +31 -0
Dockerfile
ADDED
@@ -0,0 +1,64 @@
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04

ARG DEBIAN_FRONTEND=noninteractive

ENV PYTHONUNBUFFERED=1

RUN apt-get update && apt-get install --no-install-recommends -y \
    build-essential \
    wget \
    git \
    && apt-get clean && rm -rf /var/lib/apt/lists/*

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

# Set up a new user named "user" with user ID 1000
RUN useradd -m -u 1000 user
# Switch to the "user" user
USER user
# Set home to the user's home directory
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH \
    PYTHONPATH=$HOME/app \
    PYTHONUNBUFFERED=1 \
    SYSTEM=spaces

# Install miniconda
RUN mkdir -p /home/user/conda
ENV CONDA_DIR /home/user/conda
RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
    /bin/bash ~/miniconda.sh -b -p /home/user/conda

# Put conda in path so we can use conda activate
ENV PATH=$CONDA_DIR/bin:$PATH

# Activate
RUN conda init bash
RUN . /home/user/conda/bin/activate

# Install dependencies
RUN conda create -n pose python=3.8
RUN conda activate pose

RUN pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118

RUN pip install -r /code/requirements.txt

RUN conda install -c conda-forge cudatoolkit-dev -y
ENV CUDA_HOME=$CONDA_PREFIX/pkgs/cuda-toolkit/
RUN pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"

RUN wget https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P /code/pretrained_models
RUN wget https://huggingface.co/stabilityai/sdxl-vae/resolve/main/sdxl_vae.safetensors -P /code/pretrained_models

# Set the working directory to the user's home directory
WORKDIR $HOME/app

# Copy the current directory contents into the container at $HOME/app setting the owner to the user
COPY --chown=user . $HOME/app

CMD ["python", "app.py"]

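One caveat worth noting about the Dockerfile above: each RUN starts a fresh shell, so the standalone `RUN conda activate pose` does not persist into later layers or into the final CMD. The sketch below is a small, hypothetical sanity check (not part of this commit) that could be run inside the built image to confirm the runtime the Space relies on; the version expectations are read off the Dockerfile above.

```python
# sanity_check.py -- hypothetical helper, not part of this commit.
# Verifies the GPU runtime and key packages the Dockerfile is expected to provide.
import torch

def check_environment() -> None:
    print("torch:", torch.__version__)            # Dockerfile pins 2.1.0 (cu118 wheel)
    print("cuda available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("device:", torch.cuda.get_device_name(0))
    try:
        import pytorch3d  # installed from the stable branch in the Dockerfile
        print("pytorch3d:", pytorch3d.__version__)
    except ImportError as exc:
        print("pytorch3d missing:", exc)

if __name__ == "__main__":
    check_environment()
```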
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
-title:
-emoji:
+title: CustomDiffusion360
+emoji: 📷
 colorFrom: gray
 colorTo: yellow
-sdk:
-
+sdk: docker
+app_port: 7860
 app_file: app.py
 pinned: false
 ---

app.py
ADDED
@@ -0,0 +1,395 @@
import os
import gradio as gr
import plotly.graph_objects as go
import torch
import json
import glob
import numpy as np
from PIL import Image
import time
import tqdm
import copy

# Mesh imports
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.vis.plotly_vis import AxisArgs, plot_scene
from pytorch3d.transforms import Transform3d, RotateAxisAngle, Translate, Rotate

from sampling_for_demo import load_and_return_model_and_data, sample, load_base_model

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def transform_mesh(mesh, transform, scale=1.0):
    mesh = mesh.clone()
    verts = mesh.verts_packed() * scale
    verts = transform.transform_points(verts)
    mesh.offset_verts_(verts - mesh.verts_packed())
    return mesh


def get_input_pose_fig():
    global curr_camera_dict
    global obj_filename
    global plane_trans

    plane_filename = 'assets/plane.obj'

    mesh_scale = 0.75
    mesh = load_objs_as_meshes([obj_filename], device=device)
    mesh.scale_verts_(mesh_scale)

    plane = load_objs_as_meshes([plane_filename], device=device)

    ### plane
    rotate_x = RotateAxisAngle(angle=90.0, axis='X', device=device)
    plane = transform_mesh(plane, rotate_x)
    translate_y = Translate(0, plane_trans * mesh_scale, 0, device=device)
    plane = transform_mesh(plane, translate_y)

    fig = plot_scene({
        "plot": {
            "object": mesh,
        },
    },
        axis_args=AxisArgs(showgrid=True, backgroundcolor='#cccde0'),
        xaxis=dict(range=[-1, 1]),
        yaxis=dict(range=[-1, 1]),
        zaxis=dict(range=[-1, 1])
    )

    plane = plane.detach().cpu()
    verts = plane.verts_packed()
    faces = plane.faces_packed()

    fig.add_trace(
        go.Mesh3d(
            x=verts[:, 0],
            y=verts[:, 1],
            z=verts[:, 2],
            i=faces[:, 0],
            j=faces[:, 1],
            k=faces[:, 2],
            opacity=0.7,
            color='gray',
            hoverinfo='skip',
        ),
    )

    print("fig: curr camera dict")
    print(curr_camera_dict)
    camera_dict = curr_camera_dict

    fig.update_layout(scene=dict(
        xaxis=dict(showticklabels=True, visible=True),
        yaxis=dict(showticklabels=True, visible=True),
        zaxis=dict(showticklabels=True, visible=True),
    ))
    # show grid
    fig.update_layout(scene=dict(
        xaxis=dict(showgrid=True, gridwidth=1, gridcolor='black'),
        yaxis=dict(showgrid=True, gridwidth=1, gridcolor='black'),
        zaxis=dict(showgrid=True, gridwidth=1, gridcolor='black'),
        bgcolor='#dedede',
    ))

    fig.update_layout(
        camera_dict,
        width=512, height=512,
    )

    return fig


def run_inference(cam_pose_json, prompt, scale_im, scale, steps, seed):
    print("prompt is ", prompt)
    global current_data, current_model

    # run model
    images = sample(
        current_model, current_data,
        num_images=1,
        prompt=prompt,
        appendpath="",
        camera_json=cam_pose_json,
        train=False,
        scale=scale,
        scale_im=scale_im,
        beta=1.0,
        num_ref=8,
        skipreflater=False,
        num_steps=steps,
        valid=False,
        max_images=20,
        seed=seed
    )

    result = images[0]
    print(result.shape)
    result = Image.fromarray((np.clip(((result+1.0)/2.0).permute(1, 2, 0).cpu().numpy(), 0., 1.)*255).astype(np.uint8))
    print('result obtained')
    return result


def update_curr_camera_dict(camera_json):
    # TODO: this does not always update the figure, also there's always flashes
    global curr_camera_dict
    global prev_camera_dict
    if camera_json is None:
        camera_json = json.dumps(prev_camera_dict)
    camera_json = camera_json.replace("'", "\"")
    curr_camera_dict = json.loads(camera_json)  # ["scene.camera"]
    print("update curr camera dict")
    print(curr_camera_dict)
    return camera_json


MODELS_DIR = "pretrained-models/"

def select_and_load_model(category, category_single_id):
    global current_data, current_model, base_model
    current_model = None
    current_model = copy.deepcopy(base_model)

    ### choose model checkpoint and config
    delta_ckpt = glob.glob(f"{MODELS_DIR}/*{category}{category_single_id}*/checkpoints/step=*.ckpt")[0]
    print(f"Loading model from {delta_ckpt}")

    logdir = delta_ckpt.split('/checkpoints')[0]
    config = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))[-1]

    start_time = time.time()
    current_model, current_data = load_and_return_model_and_data(config, current_model,
                                                                 delta_ckpt=delta_ckpt
                                                                 )

    print(f"Time taken to load delta model: {time.time() - start_time:.2f}s")

    print("!!! model loaded")

    input_prompt = f"photo of a <new1> {category}"
    return "### Model loaded!", input_prompt


global current_data
global current_model
current_data = None
current_model = None

global base_model
BASE_CONFIG = "configs/train_co3d_concept.yaml"
BASE_CKPT = "pretrained-models/sd_xl_base_1.0.safetensors"

start_time = time.time()
base_model = load_base_model(BASE_CONFIG, ckpt=BASE_CKPT, verbose=False)
print(f"Time taken to load base model: {time.time() - start_time:.2f}s")

global curr_camera_dict
curr_camera_dict = {
    "scene.camera": {
        "up": {"x": -0.13227683305740356,
               "y": -0.9911391735076904,
               "z": -0.013464212417602539},
        "center": {"x": -0.005292057991027832,
                   "y": 0.020704858005046844,
                   "z": 0.0873757004737854},
        "eye": {"x": 0.8585731983184814,
                "y": -0.08790968358516693,
                "z": -0.40458938479423523},
    },
    "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
    "scene.aspectmode": "manual"
}

global prev_camera_dict
prev_camera_dict = copy.deepcopy(curr_camera_dict)

global obj_filename
obj_filename = "assets/car0_mesh_centered_flipped.obj"
global plane_trans
plane_trans = 0.16

my_fig = get_input_pose_fig()

scripts = open("scripts.js", "r").read()


def update_category_single_id(category):
    global curr_camera_dict
    global prev_camera_dict
    global obj_filename
    global plane_trans
    choices = None

    if category == "car":
        choices = ["0"]
        curr_camera_dict = {
            "scene.camera": {
                "up": {"x": -0.13227683305740356,
                       "y": -0.9911391735076904,
                       "z": -0.013464212417602539},
                "center": {"x": -0.005292057991027832,
                           "y": 0.020704858005046844,
                           "z": 0.0873757004737854},
                "eye": {"x": 0.8585731983184814,
                        "y": -0.08790968358516693,
                        "z": -0.40458938479423523},
            },
            "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
            "scene.aspectmode": "manual"
        }
        plane_trans = 0.16

    elif category == "chair":
        choices = ["191"]
        curr_camera_dict = {
            "scene.camera": {
                "up": {"x": 1.0477e-04,
                       "y": -9.9995e-01,
                       "z": 1.0288e-02},
                "center": {"x": 0.0539,
                           "y": 0.0015,
                           "z": 0.0007},
                "eye": {"x": 0.0410,
                        "y": -0.0091,
                        "z": -0.9991},
            },
            "scene.aspectratio": {"x": 0.9084, "y": 0.9084, "z": 0.9084},
            "scene.aspectmode": "manual"
        }
        plane_trans = 0.38

    elif category == "motorcycle":
        choices = ["12"]
        curr_camera_dict = {
            "scene.camera": {
                "up": {"x": 0.0308,
                       "y": -0.9994,
                       "z": -0.0147},
                "center": {"x": 0.0240,
                           "y": -0.0310,
                           "z": -0.0016},
                "eye": {"x": -0.0580,
                        "y": -0.0188,
                        "z": -0.9981},
            },
            "scene.aspectratio": {"x": 1.5786, "y": 1.5786, "z": 1.5786},
            "scene.aspectmode": "manual"
        }
        plane_trans = 0.16

    elif category == "teddybear":
        choices = ["31"]
        curr_camera_dict = {
            "scene.camera": {
                "up": {"x": 0.4304,
                       "y": -0.9023,
                       "z": -0.0221},
                "center": {"x": -0.0658,
                           "y": 0.2081,
                           "z": 0.0175},
                "eye": {"x": -0.4456,
                        "y": 0.0493,
                        "z": -0.8939},
            },
            "scene.aspectratio": {"x": 1.8052, "y": 1.8052, "z": 1.8052},
            "scene.aspectmode": "manual",
        }
        plane_trans = 0.23

    obj_filename = f"assets/{category}{choices[0]}_mesh_centered_flipped.obj"
    prev_camera_dict = copy.deepcopy(curr_camera_dict)
    return gr.Dropdown(choices=choices, label="Object ID", value=choices[0])


head = """
<script src="https://cdn.plot.ly/plotly-2.30.0.min.js" charset="utf-8"></script>
"""

ORIGINAL_SPACE_ID = 'customdiffusion360'
SPACE_ID = os.getenv('SPACE_ID')

SHARED_UI_WARNING = f'''## Attention - the demo requires at least 40GB VRAM for inference. Please clone this repository to run on your own machine.
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
'''

with gr.Blocks(head=head,
               css="style.css",
               js=scripts,
               title="Customizing Text-to-Image Diffusion with Camera Viewpoint Control") as demo:

    gr.HTML("""
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
        <div>
        <h1>Customizing Text-to-Image Diffusion with Camera Viewpoint Control</h1>
        </div>
        </div>
        <div>
        </br>
        </div>
        <hr></hr>
        """,
        visible=True
    )

    if SPACE_ID == ORIGINAL_SPACE_ID:
        gr.Markdown(SHARED_UI_WARNING)

    with gr.Row():
        with gr.Column(min_width=150):
            gr.Markdown("## 1. SELECT CUSTOMIZED MODEL")

            category = gr.Dropdown(choices=["car", "chair", "motorcycle", "teddybear"], label="Category", value="car")

            category_single_id = gr.Dropdown(label="Object ID", choices=["0"], type="value", value="0", visible=False)

            category.change(update_category_single_id, [category], [category_single_id])

            load_model_btn = gr.Button(value="Load Model", elem_id="load_model_button")

            load_model_status = gr.Markdown(elem_id="load_model_status", value="### Please select and load a model.")

        with gr.Column(min_width=512):
            gr.Markdown("## 2. CAMERA POSE VISUALIZATION")

            # TODO ? don't use gradio plotly element so we can remove menu buttons
            map = gr.Plot(value=my_fig, min_width=512, elem_id="map")

            ### hidden elements
            update_pose_btn = gr.Button(value="Update Camera Pose", visible=False, elem_id="update_pose_button")
            input_pose = gr.TextArea(value=curr_camera_dict, label="Input Camera Pose", visible=False, elem_id="input_pose", interactive=False)
            check_pose_btn = gr.Button(value="Check Camera Pose", visible=False, elem_id="check_pose_button")

            ## TODO: track init_camera_dict and with js?

            ### visible elements
            input_prompt = gr.Textbox(value="photo of a <new1> car", label="Prompt", interactive=True)
            scale_im = gr.Slider(value=3.5, label="Image guidance scale", minimum=0, maximum=20.0, step=0.1)
            scale = gr.Slider(value=7.5, label="Text guidance scale", minimum=0, maximum=20.0, step=0.1)
            steps = gr.Slider(value=10, label="Inference steps", minimum=1, maximum=50, step=1)
            seed = gr.Textbox(value=42, label="Seed")

            with gr.Column(min_width=50, elem_id="column_process", scale=0.3):
                run_btn = gr.Button(value="Run", elem_id="run_button", min_width=50)

        with gr.Column(min_width=512):
            gr.Markdown("## 3. OUR OUTPUT")
            result = gr.Image(show_label=False, show_download_button=True, width=512, height=512, elem_id="result")

    load_model_btn.click(select_and_load_model, [category, category_single_id], [load_model_status, input_prompt])
    load_model_btn.click(get_input_pose_fig, [], [map])

    update_pose_btn.click(update_curr_camera_dict, [input_pose], [input_pose],)  # js=send_js_camera_to_gradio)
    # check_pose_btn.click(check_curr_camera_dict, [], [input_pose])
    run_btn.click(run_inference, [input_pose, input_prompt, scale_im, scale, steps, seed], result)

    demo.load(js=scripts)


if __name__ == "__main__":
    demo.queue().launch(debug=True)

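In app.py the Plotly camera state (eye, center, up, serialized as JSON in the hidden input_pose textbox) is passed straight to sample() in sampling_for_demo.py, which is where the actual camera conversion happens. As a rough illustration of how such a viewer state can be turned into a PyTorch3D camera, here is a minimal sketch using look_at_view_transform; the function name and module are hypothetical, and the scaling conventions of the real sampling code are not reproduced here.

```python
# camera_utils_sketch.py -- illustrative only; the real conversion lives in sampling_for_demo.py.
import json
import torch
from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform

def plotly_camera_to_pytorch3d(camera_json: str, device: str = "cpu") -> FoVPerspectiveCameras:
    """Build a PyTorch3D camera from a Plotly scene.camera dict (eye/center/up)."""
    cam = json.loads(camera_json)["scene.camera"]
    eye = [[cam["eye"]["x"], cam["eye"]["y"], cam["eye"]["z"]]]
    at = [[cam["center"]["x"], cam["center"]["y"], cam["center"]["z"]]]
    up = [[cam["up"]["x"], cam["up"]["y"], cam["up"]["z"]]]
    # look_at_view_transform returns the world-to-view rotation R and translation T.
    R, T = look_at_view_transform(eye=eye, at=at, up=up, device=device)
    return FoVPerspectiveCameras(R=R, T=T, device=device)
```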
assets/car0_mesh_centered_flipped.obj
ADDED
The diff for this file is too large to render.
assets/chair191_mesh_centered_flipped.obj
ADDED
The diff for this file is too large to render.
assets/motorcycle12_mesh_centered_flipped.obj
ADDED
The diff for this file is too large to render.
assets/motorcycle29_mesh_centered_flipped.obj
ADDED
The diff for this file is too large to render.
assets/plane.obj
ADDED
@@ -0,0 +1,14 @@
o plane
v 3.000000 -3.000000 0.000000
v 3.000000 3.000000 0.000000
v -3.000000 3.000000 0.000000
v -3.000000 -3.000000 0.000000

vt 3.000000 0.000000
vt 3.000000 3.000000
vt 0.000000 3.000000
vt 0.000000 0.000000

s off
f 1/1 2/2 3/3
f 1/1 3/3 4/4

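The plane is a simple two-triangle quad used as the ground reference under the object mesh in the pose figure. A quick sketch, mirroring the loading calls already used in app.py, to load and inspect it with PyTorch3D:

```python
# inspect_plane.py -- small sketch mirroring the loading calls used in app.py.
import torch
from pytorch3d.io import load_objs_as_meshes

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
plane = load_objs_as_meshes(["assets/plane.obj"], device=device)

# The OBJ above defines 4 vertices and 2 faces (a 6x6 quad in the XY plane).
print(plane.verts_packed().shape)  # torch.Size([4, 3])
print(plane.faces_packed().shape)  # torch.Size([2, 3])
```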
assets/teddybear0_mesh_centered_flipped.obj
ADDED
The diff for this file is too large to render.
assets/teddybear31_mesh_centered_flipped.obj
ADDED
The diff for this file is too large to render.
configs/train_co3d_concept.yaml
ADDED
@@ -0,0 +1,198 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True
    trainkeys: pose
    multiplier: 0.05
    loss_rgb_lambda: 5
    loss_fg_lambda: 10
    loss_bg_lambda: 10
    log_keys:
      - txt

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: False
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10]
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        image_cross_blocks: [0, 2, 4, 6, 8, 10]
        rgb: True
        far: 2
        num_samples: 24
        not_add_context_in_triplane: False
        rgb_predict: True
        add_lora: False
        average: False
        use_prev_weights_imp_sample: True
        stratified: True
        imp_sampling_percent: 0.9

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_keys: txt,txt_ref
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11
              modifier_token: <new1>
          # crossattn and vector cond
          - is_trainable: False
            input_keys: txt,txt_ref
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              layer: penultimate
              always_return_pooled: True
              legacy: False
              modifier_token: <new1>
          # vector cond
          - is_trainable: False
            input_keys: original_size_as_tuple,original_size_as_tuple_ref
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_keys: crop_coords_top_left,crop_coords_top_left_ref
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_keys: target_size_as_tuple,target_size_as_tuple_ref
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        ckpt_path: pretrained-models/sdxl_vae.safetensors
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
          params:
            num_idx: 1000
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        sigma_sampler_config_ref:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 50

            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
          params:
            scale: 7.5

data:
  target: sgm.data.data_co3d.CustomDataDictLoader
  params:
    batch_size: 1
    num_workers: 4
    category: teddybear
    img_size: 512
    skip: 2
    num_images: 5
    mask_images: True
    single_id: 0
    bbox: True
    addreg: True
    drop_ratio: 0.25
    drop_txt: 0.1
    modifier_token: <new1>

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 1600
      save_top_k: -1
      save_on_train_epoch_end: False

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        enable_autocast: False
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 1
          n_rows: 2

  trainer:
    devices: 0,1,2,3
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_steps: 1610
    # val_check_interval: 400

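This is the base training/inference config that app.py points load_base_model at (BASE_CONFIG). As a minimal sketch of how such a config is typically consumed, assuming the vendored sgm package exposes the usual instantiate_from_config helper (as in Stability's generative-models codebase, from which this sgm tree is derived):

```python
# load_config_sketch.py -- a minimal sketch, assuming sgm.util.instantiate_from_config
# is available in the vendored sgm package; not taken verbatim from this commit.
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

config = OmegaConf.load("configs/train_co3d_concept.yaml")

# model.target / model.params mirror the YAML above: a DiffusionEngine with the
# pose-conditioned UNet, frozen CLIP conditioners, and the SDXL VAE wrapper.
model = instantiate_from_config(config.model)

# data.target builds the CO3D data module; category and single_id select the concept.
data = instantiate_from_config(config.data)
```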
pretrained_models/car0/checkpoints/step=000001600.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5b073c96fe525f9530dc69e5fd8e94d6527a7651c5bb4ede5953750fbe157ebd
size 777852660

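The checkpoint files in this commit are Git LFS pointers (roughly 778 MB each once fetched). app.py passes the resolved path to load_and_return_model_and_data as delta_ckpt, layered on top of the SDXL base weights. A rough sketch of what applying such a delta checkpoint usually looks like; the state-dict key layout is an assumption here, not read from this commit:

```python
# load_delta_sketch.py -- illustrative only; the real logic lives in sampling_for_demo.py.
import torch

def apply_delta_checkpoint(model: torch.nn.Module, delta_ckpt_path: str) -> torch.nn.Module:
    """Overlay fine-tuned (delta) weights from a Lightning-style checkpoint onto a base model."""
    ckpt = torch.load(delta_ckpt_path, map_location="cpu")
    # Lightning checkpoints usually nest weights under "state_dict"; assumed, not verified here.
    state_dict = ckpt.get("state_dict", ckpt)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    return model
```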
pretrained_models/car0/configs/2024-04-12T21-30-20-lightning.yaml
ADDED
@@ -0,0 +1,31 @@
lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 1600
      save_top_k: -1
      save_on_train_epoch_end: false
  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000
    image_logger:
      target: main.ImageLogger
      params:
        disabled: false
        enable_autocast: false
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: false
        log_first_step: false
        log_images_kwargs:
          use_ema_scope: false
          'N': 1
          n_rows: 2
  trainer:
    devices: 0,1,2,3
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_steps: 1610
    val_check_interval: 400
    accelerator: gpu

pretrained_models/car0/configs/2024-04-12T21-30-20-project.yaml
ADDED
@@ -0,0 +1,168 @@
model:
  base_learning_rate: 0.0001
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: true
    trainkeys: pose
    multiplier: 0.05
    loss_rgb_lambda: 5
    loss_fg_lambda: 10
    loss_bg_lambda: 10
    log_keys:
    - txt
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000
        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: false
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 64
        use_linear_in_transformer: true
        transformer_depth:
        - 1
        - 2
        - 10
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        image_cross_blocks:
        - 0
        - 2
        - 4
        - 6
        - 8
        - 10
        rgb: true
        far: 2
        num_samples: 24
        not_add_context_in_triplane: false
        rgb_predict: true
        add_lora: false
        average: false
        use_prev_weights_imp_sample: true
        stratified: true
        imp_sampling_percent: 0.9
        use_prev_weights_imp_sample: true
    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: false
          input_keys: txt,txt_ref
          target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
          params:
            layer: hidden
            layer_idx: 11
            modifier_token: <new1>
        - is_trainable: false
          input_keys: txt,txt_ref
          target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
          params:
            arch: ViT-bigG-14
            version: laion2b_s39b_b160k
            layer: penultimate
            always_return_pooled: true
            legacy: false
            modifier_token: <new1>
        - is_trainable: false
          input_keys: original_size_as_tuple,original_size_as_tuple_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - is_trainable: false
          input_keys: crop_coords_top_left,crop_coords_top_left_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - is_trainable: false
          input_keys: target_size_as_tuple,target_size_as_tuple_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        ckpt_path: /sensei-fs/tenants/Sensei-AdobeResearchTeam/nupkumar1/custom-pose/pretrained-models/sdxl_vae.safetensors
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
          params:
            num_idx: 1000
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        sigma_sampler_config_ref:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 50
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
          params:
            scale: 7.5
data:
  target: sgm.data.data_co3d.CustomDataDictLoader
  params:
    batch_size: 1
    num_workers: 4
    category: car
    img_size: 512
    skip: 2
    num_images: 5
    mask_images: true
    single_id: 0
    bbox: true
    addreg: true
    drop_ratio: 0.25
    drop_txt: 0.1
    modifier_token: <new1>
    categoryname: null

pretrained_models/car0/configs/2024-04-13T11-31-55-lightning.yaml
ADDED
@@ -0,0 +1,31 @@
lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 1600
      save_top_k: -1
      save_on_train_epoch_end: false
  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000
    image_logger:
      target: main.ImageLogger
      params:
        disabled: false
        enable_autocast: false
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: false
        log_first_step: false
        log_images_kwargs:
          use_ema_scope: false
          'N': 1
          n_rows: 2
  trainer:
    devices: 0,1,2,3
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_steps: 1610
    val_check_interval: 400
    accelerator: gpu

pretrained_models/car0/configs/2024-04-13T11-31-55-project.yaml
ADDED
@@ -0,0 +1,170 @@
model:
  base_learning_rate: 0.0001
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: true
    trainkeys: pose
    multiplier: 0.05
    loss_rgb_lambda: 5
    loss_fg_lambda: 10
    loss_bg_lambda: 10
    log_keys:
    - txt
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000
        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: false
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 64
        use_linear_in_transformer: true
        transformer_depth:
        - 1
        - 2
        - 10
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        image_cross_blocks:
        - 0
        - 2
        - 4
        - 6
        - 8
        - 10
        rgb: true
        far: 2
        num_samples: 24
        not_add_context_in_triplane: false
        rgb_predict: true
        add_lora: false
        average: false
        use_prev_weights_imp_sample: true
        stratified: true
        imp_sampling_percent: 0.9
        use_prev_weights_imp_sample: true
    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: false
          input_keys: txt,txt_ref
          target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
          params:
            layer: hidden
            layer_idx: 11
            modifier_token: <new1>
        - is_trainable: false
          input_keys: txt,txt_ref
          target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
          params:
            arch: ViT-bigG-14
            version: laion2b_s39b_b160k
            layer: penultimate
            always_return_pooled: true
            legacy: false
            modifier_token: <new1>
        - is_trainable: false
          input_keys: original_size_as_tuple,original_size_as_tuple_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - is_trainable: false
          input_keys: crop_coords_top_left,crop_coords_top_left_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - is_trainable: false
          input_keys: target_size_as_tuple,target_size_as_tuple_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        ckpt_path: /sensei-fs/tenants/Sensei-AdobeResearchTeam/nupkumar1/custom-pose/pretrained-models/sdxl_vae.safetensors
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
          params:
            num_idx: 1000
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        sigma_sampler_config_ref:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 50
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
          params:
            scale: 7.5
data:
  target: sgm.data.data_co3d.CustomDataDictLoader
  params:
    batch_size: 1
    num_workers: 4
    category: car
    img_size: 512
    skip: 2
    num_images: 5
    mask_images: true
    single_id: 0
    bbox: true
    addreg: true
    drop_ratio: 0.25
    drop_txt: 0.1
    modifier_token: <new1>
    categoryname: null
    --log_dir: null
    check_logs: null

pretrained_models/car0/configs/2024-04-13T11-42-30-lightning.yaml
ADDED
@@ -0,0 +1,31 @@
lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 1600
      save_top_k: -1
      save_on_train_epoch_end: false
  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000
    image_logger:
      target: main.ImageLogger
      params:
        disabled: false
        enable_autocast: false
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: false
        log_first_step: false
        log_images_kwargs:
          use_ema_scope: false
          'N': 1
          n_rows: 2
  trainer:
    devices: 0,1,2,3
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_steps: 1610
    val_check_interval: 400
    accelerator: gpu

pretrained_models/car0/configs/2024-04-13T11-42-30-project.yaml
ADDED
@@ -0,0 +1,170 @@
model:
  base_learning_rate: 0.0001
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: true
    trainkeys: pose
    multiplier: 0.05
    loss_rgb_lambda: 5
    loss_fg_lambda: 10
    loss_bg_lambda: 10
    log_keys:
    - txt
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000
        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: false
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 64
        use_linear_in_transformer: true
        transformer_depth:
        - 1
        - 2
        - 10
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        image_cross_blocks:
        - 0
        - 2
        - 4
        - 6
        - 8
        - 10
        rgb: true
        far: 2
        num_samples: 24
        not_add_context_in_triplane: false
        rgb_predict: true
        add_lora: false
        average: false
        use_prev_weights_imp_sample: true
        stratified: true
        imp_sampling_percent: 0.9
        use_prev_weights_imp_sample: true
    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: false
          input_keys: txt,txt_ref
          target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
          params:
            layer: hidden
            layer_idx: 11
            modifier_token: <new1>
        - is_trainable: false
          input_keys: txt,txt_ref
          target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
          params:
            arch: ViT-bigG-14
            version: laion2b_s39b_b160k
            layer: penultimate
            always_return_pooled: true
            legacy: false
            modifier_token: <new1>
        - is_trainable: false
          input_keys: original_size_as_tuple,original_size_as_tuple_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - is_trainable: false
          input_keys: crop_coords_top_left,crop_coords_top_left_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - is_trainable: false
          input_keys: target_size_as_tuple,target_size_as_tuple_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        ckpt_path: /sensei-fs/tenants/Sensei-AdobeResearchTeam/nupkumar1/custom-pose/pretrained-models/sdxl_vae.safetensors
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
          params:
            num_idx: 1000
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        sigma_sampler_config_ref:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 50
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
          params:
            scale: 7.5
data:
  target: sgm.data.data_co3d.CustomDataDictLoader
  params:
    batch_size: 1
    num_workers: 4
    category: car
    img_size: 512
    skip: 2
    num_images: 5
    mask_images: true
    single_id: 0
    bbox: true
    addreg: true
    drop_ratio: 0.25
    drop_txt: 0.1
    modifier_token: <new1>
    categoryname: null
    --log_dir: null
    check_logs: null

pretrained_models/chair191/checkpoints/step=000001600.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d83a24db919f95ea487572b45524ae1073a1638612a751ecb9589d2060e9b991
size 777852660

pretrained_models/chair191/configs/2024-04-12T22-10-18-lightning.yaml
ADDED
@@ -0,0 +1,31 @@
lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 1600
      save_top_k: -1
      save_on_train_epoch_end: false
  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000
    image_logger:
      target: main.ImageLogger
      params:
        disabled: false
        enable_autocast: false
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: false
        log_first_step: false
        log_images_kwargs:
          use_ema_scope: false
          'N': 1
          n_rows: 2
  trainer:
    devices: 0,1,2,3
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_steps: 1610
    val_check_interval: 400
    accelerator: gpu

pretrained_models/chair191/configs/2024-04-12T22-10-18-project.yaml
ADDED
@@ -0,0 +1,168 @@
model:
  base_learning_rate: 0.0001
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: true
    trainkeys: pose
    multiplier: 0.05
    loss_rgb_lambda: 5
    loss_fg_lambda: 10
    loss_bg_lambda: 10
    log_keys:
    - txt
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000
        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: false
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 64
        use_linear_in_transformer: true
        transformer_depth:
        - 1
        - 2
        - 10
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        image_cross_blocks:
        - 0
        - 2
        - 4
        - 6
        - 8
        - 10
        rgb: true
        far: 2
        num_samples: 24
        not_add_context_in_triplane: false
        rgb_predict: true
        add_lora: false
        average: false
        use_prev_weights_imp_sample: true
        stratified: true
        imp_sampling_percent: 0.9
        use_prev_weights_imp_sample: true
    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: false
          input_keys: txt,txt_ref
          target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
          params:
            layer: hidden
            layer_idx: 11
            modifier_token: <new1>
        - is_trainable: false
          input_keys: txt,txt_ref
          target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
          params:
            arch: ViT-bigG-14
            version: laion2b_s39b_b160k
            layer: penultimate
            always_return_pooled: true
            legacy: false
            modifier_token: <new1>
        - is_trainable: false
          input_keys: original_size_as_tuple,original_size_as_tuple_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - is_trainable: false
          input_keys: crop_coords_top_left,crop_coords_top_left_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - is_trainable: false
          input_keys: target_size_as_tuple,target_size_as_tuple_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        ckpt_path: /sensei-fs/tenants/Sensei-AdobeResearchTeam/nupkumar1/custom-pose/pretrained-models/sdxl_vae.safetensors
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
          params:
            num_idx: 1000
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        sigma_sampler_config_ref:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 50
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
          params:
            scale: 7.5
data:
  target: sgm.data.data_co3d.CustomDataDictLoader
  params:
    batch_size: 1
    num_workers: 4
    category: chair
    img_size: 512
    skip: 2
    num_images: 5
    mask_images: true
    single_id: 191
    bbox: true
    addreg: true
    drop_ratio: 0.25
    drop_txt: 0.1
    modifier_token: <new1>
    categoryname: null
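The per-concept project configs in this commit (chair191 above, motorcycle12 and teddybear31 below) differ only in the data block (category, single_id) and the VAE ckpt_path; the DiffusionEngine, UNet, conditioner, and sampler settings are identical. A minimal sketch of how such a config is consumed, using the same OmegaConf and instantiate_from_config utilities that sampling_for_demo.py in this commit relies on; checkpoint handling is left to the Space's own loaders:

# minimal sketch, not the Space's entry point (see app.py / sampling_for_demo.py)
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

config = OmegaConf.load("pretrained_models/chair191/configs/2024-04-12T22-10-18-project.yaml")
model = instantiate_from_config(config.model)  # builds sgm.models.diffusion.DiffusionEngine
model.eval()
# the fine-tuned weights in checkpoints/step=000001600.ckpt are applied later by
# load_base_model / load_delta_model in sampling_for_demo.py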
pretrained_models/motorcycle12/checkpoints/step=000001600.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11acc84e7c6fbbc9b47f7021dcdaa55032c1e48b88b6bc9a8bb8689f59521c99
size 777852660
pretrained_models/motorcycle12/configs/2024-04-12T23-30-18-project.yaml
ADDED
@@ -0,0 +1,168 @@
model:
  base_learning_rate: 0.0001
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: true
    trainkeys: pose
    multiplier: 0.05
    loss_rgb_lambda: 5
    loss_fg_lambda: 10
    loss_bg_lambda: 10
    log_keys:
    - txt
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000
        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: false
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 64
        use_linear_in_transformer: true
        transformer_depth:
        - 1
        - 2
        - 10
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        image_cross_blocks:
        - 0
        - 2
        - 4
        - 6
        - 8
        - 10
        rgb: true
        far: 2
        num_samples: 24
        not_add_context_in_triplane: false
        rgb_predict: true
        add_lora: false
        average: false
        use_prev_weights_imp_sample: true
        stratified: true
        imp_sampling_percent: 0.9
        use_prev_weights_imp_sample: true
    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: false
          input_keys: txt,txt_ref
          target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
          params:
            layer: hidden
            layer_idx: 11
            modifier_token: <new1>
        - is_trainable: false
          input_keys: txt,txt_ref
          target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
          params:
            arch: ViT-bigG-14
            version: laion2b_s39b_b160k
            layer: penultimate
            always_return_pooled: true
            legacy: false
            modifier_token: <new1>
        - is_trainable: false
          input_keys: original_size_as_tuple,original_size_as_tuple_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - is_trainable: false
          input_keys: crop_coords_top_left,crop_coords_top_left_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - is_trainable: false
          input_keys: target_size_as_tuple,target_size_as_tuple_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        ckpt_path: pretrained-models/sdxl_vae.safetensors
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
          params:
            num_idx: 1000
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        sigma_sampler_config_ref:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 50
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
          params:
            scale: 7.5
data:
  target: sgm.data.data_co3d.CustomDataDictLoader
  params:
    batch_size: 1
    num_workers: 4
    category: motorcycle
    img_size: 512
    skip: 2
    num_images: 5
    mask_images: true
    single_id: 12
    bbox: true
    addreg: true
    drop_ratio: 0.25
    drop_txt: 0.1
    modifier_token: <new1>
    categoryname: null
pretrained_models/teddybear31/checkpoints/step=000001600.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c474759a84423022acf8eddadb58c5dcb49a9c80077b8f08f3a137f44e5eb76
size 777852660
pretrained_models/teddybear31/configs/2024-04-12T22-50-24-lightning.yaml
ADDED
@@ -0,0 +1,31 @@
lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 1600
      save_top_k: -1
      save_on_train_epoch_end: false
  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000
    image_logger:
      target: main.ImageLogger
      params:
        disabled: false
        enable_autocast: false
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: false
        log_first_step: false
        log_images_kwargs:
          use_ema_scope: false
          'N': 1
          n_rows: 2
  trainer:
    devices: 0,1,2,3
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_steps: 1610
    val_check_interval: 400
    accelerator: gpu
pretrained_models/teddybear31/configs/2024-04-12T22-50-24-project.yaml
ADDED
@@ -0,0 +1,168 @@
model:
  base_learning_rate: 0.0001
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: true
    trainkeys: pose
    multiplier: 0.05
    loss_rgb_lambda: 5
    loss_fg_lambda: 10
    loss_bg_lambda: 10
    log_keys:
    - txt
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000
        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: false
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 64
        use_linear_in_transformer: true
        transformer_depth:
        - 1
        - 2
        - 10
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        image_cross_blocks:
        - 0
        - 2
        - 4
        - 6
        - 8
        - 10
        rgb: true
        far: 2
        num_samples: 24
        not_add_context_in_triplane: false
        rgb_predict: true
        add_lora: false
        average: false
        use_prev_weights_imp_sample: true
        stratified: true
        imp_sampling_percent: 0.9
        use_prev_weights_imp_sample: true
    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: false
          input_keys: txt,txt_ref
          target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
          params:
            layer: hidden
            layer_idx: 11
            modifier_token: <new1>
        - is_trainable: false
          input_keys: txt,txt_ref
          target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
          params:
            arch: ViT-bigG-14
            version: laion2b_s39b_b160k
            layer: penultimate
            always_return_pooled: true
            legacy: false
            modifier_token: <new1>
        - is_trainable: false
          input_keys: original_size_as_tuple,original_size_as_tuple_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - is_trainable: false
          input_keys: crop_coords_top_left,crop_coords_top_left_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - is_trainable: false
          input_keys: target_size_as_tuple,target_size_as_tuple_ref
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        ckpt_path: /sensei-fs/tenants/Sensei-AdobeResearchTeam/nupkumar1/custom-pose/pretrained-models/sdxl_vae.safetensors
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
          params:
            num_idx: 1000
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        sigma_sampler_config_ref:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 50
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
          params:
            scale: 7.5
data:
  target: sgm.data.data_co3d.CustomDataDictLoader
  params:
    batch_size: 1
    num_workers: 4
    category: teddybear
    img_size: 512
    skip: 2
    num_images: 5
    mask_images: true
    single_id: 31
    bbox: true
    addreg: true
    drop_ratio: 0.25
    drop_txt: 0.1
    modifier_token: <new1>
    categoryname: null
requirements.txt
ADDED
@@ -0,0 +1,37 @@
omegaconf
einops
fire
tqdm
pillow
numpy
webdataset>=0.2.33
ninja
matplotlib
torchmetrics
opencv-python==4.6.0.66
fairscale
pytorch-lightning==2.0.1
fire
fsspec
kornia==0.6.9
natsort
open-clip-torch
chardet==5.1.0
tensorboardx==2.6
pandas
pudb
pyyaml
urllib3<1.27,>=1.25.4
scipy
streamlit>=0.73.1
timm
tokenizers==0.12.1
transformers==4.19.1
triton==2.1.0
torchdata==0.7.0
wandb
invisible-watermark
xformers
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
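One gap worth noting: sampling_for_demo.py and sgm/data/data_co3d.py below import pytorch3d, which is not listed here, so it is presumably installed separately (for example, during the image build). A quick sanity check once the environment is built, assuming nothing about versions:

import torch
import pytorch3d  # not pinned in requirements.txt; assumed to be installed separately
print(torch.__version__, pytorch3d.__version__)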
sampling_for_demo.py
ADDED
@@ -0,0 +1,487 @@
import glob
import os
import sys
import copy
from typing import List

import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.renderer import look_at_view_transform
from pytorch3d.renderer.camera_utils import join_cameras_as_batch

import json

sys.path.append('./')
from sgm.util import instantiate_from_config, load_safetensors

choices = []

def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * dims_to_append]


def load_base_model(config, ckpt=None, verbose=True):
    config = OmegaConf.load(config)
    # load model
    config.model.params.network_config.params.far = 3
    config.model.params.first_stage_config.params.ckpt_path = "pretrained-models/sdxl_vae.safetensors"
    guider_config = {'target': 'sgm.modules.diffusionmodules.guiders.ScheduledCFGImgTextRef',
                     'params': {'scale': 7.5, 'scale_im': 3.5}
                     }
    config.model.params.sampler_config.params.guider_config = guider_config

    model = instantiate_from_config(config.model)

    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        if ckpt.endswith("ckpt"):
            pl_sd = torch.load(ckpt, map_location="cpu")
            if "global_step" in pl_sd:
                print(f"Global Step: {pl_sd['global_step']}")
            sd = pl_sd["state_dict"]
        elif ckpt.endswith("safetensors"):
            sd = load_safetensors(ckpt)
            if 'modifier_token' in config.data.params:
                del sd['conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight']
                del sd['conditioner.embedders.1.model.token_embedding.weight']
        else:
            raise NotImplementedError

        m, u = model.load_state_dict(sd, strict=False)

    model.cuda()
    model.eval()
    return model


def load_delta_model(model, delta_ckpt=None, verbose=True, freeze=True):
    """
    model is preloaded base stable diffusion model
    """

    msg = None
    if delta_ckpt is not None:
        pl_sd_delta = torch.load(delta_ckpt, map_location="cpu")
        sd_delta = pl_sd_delta["delta_state_dict"]

        # TODO: add new delta loading embedding stuff?

        for name, module in model.model.diffusion_model.named_modules():
            if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks':
                if hasattr(module, 'pose_emb_layers'):
                    module.register_buffer('references', sd_delta[f'model.diffusion_model.{name}.references'])
                    del sd_delta[f'model.diffusion_model.{name}.references']

        m, u = model.load_state_dict(sd_delta, strict=False)


    if len(m) > 0 and verbose:
        print("missing keys:")
    if len(u) > 0 and verbose:
        print("unexpected keys:")

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    model.cuda()
    model.eval()
    return model, msg


def get_unique_embedder_keys_from_conditioner(conditioner):
    p = [x.input_keys for x in conditioner.embedders]
    return list(set([item for sublist in p for item in sublist])) + ['jpg_ref']


def customforward(self, x, xr, context=None, contextr=None, pose=None, mask_ref=None, prev_weights=None, timesteps=None, drop_im=None):
    # note: if no context is given, cross-attention defaults to self-attention
    if not isinstance(context, list):
        context = [context]
    b, c, h, w = x.shape
    x_in = x
    fg_masks = []
    alphas = []
    rgbs = []

    x = self.norm(x)

    if not self.use_linear:
        x = self.proj_in(x)

    x = rearrange(x, "b c h w -> b (h w) c").contiguous()
    if self.use_linear:
        x = self.proj_in(x)

    prev_weights = None
    counter = 0
    for i, block in enumerate(self.transformer_blocks):
        if i > 0 and len(context) == 1:
            i = 0  # use same context for each block
        if self.image_cross and (counter % self.poscontrol_interval == 0):
            x, fg_mask, weights, alpha, rgb = block(x, context=context[i], context_ref=x, pose=pose, mask_ref=mask_ref, prev_weights=prev_weights, drop_im=drop_im)
            prev_weights = weights
            fg_masks.append(fg_mask)
            if alpha is not None:
                alphas.append(alpha)
            if rgb is not None:
                rgbs.append(rgb)
        else:
            x, _, _, _, _ = block(x, context=context[i], drop_im=drop_im)
        counter += 1
    if self.use_linear:
        x = self.proj_out(x)
    x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
    if not self.use_linear:
        x = self.proj_out(x)
    if len(fg_masks) > 0:
        if len(rgbs) <= 0:
            rgbs = None
        if len(alphas) <= 0:
            alphas = None
        return x + x_in, None, fg_masks, prev_weights, alphas, rgbs
    else:
        return x + x_in, None, None, prev_weights, None, None


def _customforward(
    self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, drop_im=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
    if context_ref is not None:
        global choices
        batch_size = x.size(0)
        # IP2P like sampling or default sampling
        if batch_size % 3 == 0:
            batch_size = batch_size // 3
            context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
            context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1), context_ref, context_ref], dim=0)
        else:
            batch_size = batch_size // 2
            context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
            context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1), context_ref], dim=0)

    fg_mask = None
    weights = None
    alphas = None
    predicted_rgb = None

    x = (
        self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None,
            additional_tokens=additional_tokens,
            n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
            if not self.disable_self_attn
            else 0,
        )
        + x
    )

    x = (
        self.attn2(
            self.norm2(x), context=context, additional_tokens=additional_tokens,
        )
        + x
    )

    if context_ref is not None:
        if self.rendered_feat is not None:
            x = self.pose_emb_layers(torch.cat([x, self.rendered_feat], dim=-1))
        else:
            xref, fg_mask, weights, alphas, predicted_rgb = self.reference_attn(x,
                                                                                context_ref,
                                                                                context,
                                                                                pose,
                                                                                prev_weights,
                                                                                mask_ref)
            self.rendered_feat = xref
            x = self.pose_emb_layers(torch.cat([x, xref], -1))

    x = self.ff(self.norm3(x)) + x
    return x, fg_mask, weights, alphas, predicted_rgb


def log_images(
    model,
    batch,
    N: int = 1,
    noise=None,
    scale_im=3.5,
    num_steps: int = 10,
    ucg_keys: List[str] = None,
    **kwargs,
):

    log = dict()
    conditioner_input_keys = [e.input_keys for e in model.conditioner.embedders]
    ucg_keys = conditioner_input_keys
    pose = batch['pose']

    c, uc = model.conditioner.get_unconditional_conditioning(
        batch,
        force_uc_zero_embeddings=ucg_keys
        if len(model.conditioner.embedders) > 0
        else [],
        force_ref_zero_embeddings=True
    )

    _, n = 1, len(pose)-1
    sampling_kwargs = {}

    if scale_im > 0:
        if uc is not None:
            if isinstance(pose, list):
                pose = pose[:N]*3
            else:
                pose = torch.cat([pose[:N]] * 3)
    else:
        if uc is not None:
            if isinstance(pose, list):
                pose = pose[:N]*2
            else:
                pose = torch.cat([pose[:N]] * 2)

    sampling_kwargs['pose'] = pose
    sampling_kwargs['drop_im'] = None
    sampling_kwargs['mask_ref'] = None

    for k in c:
        if isinstance(c[k], torch.Tensor):
            c[k], uc[k] = map(lambda y: y[k][:(n+1)*N].to('cuda'), (c, uc))

    import time
    st = time.time()
    with model.ema_scope("Plotting"):
        samples = model.sample(
            c, shape=noise.shape[1:], uc=uc, batch_size=N, num_steps=num_steps, noise=noise, **sampling_kwargs
        )
        model.clear_rendered_feat()
    samples = model.decode_first_stage(samples)
    print("Time taken for sampling", time.time() - st)
    log["samples"] = samples.cpu()

    return log


def process_camera_json(camera_json, example_cam):
    # replace all single quotes in the camera_json with double quotes
    camera_json = camera_json.replace("'", "\"")
    print("input camera json")
    print(camera_json)

    camera_dict = json.loads(camera_json)["scene.camera"]
    eye = torch.tensor([camera_dict["eye"]["x"], camera_dict["eye"]["y"], camera_dict["eye"]["z"]], dtype=torch.float32).unsqueeze(0)
    up = torch.tensor([camera_dict["up"]["x"], camera_dict["up"]["y"], camera_dict["up"]["z"]], dtype=torch.float32).unsqueeze(0)
    center = torch.tensor([camera_dict["center"]["x"], camera_dict["center"]["y"], camera_dict["center"]["z"]], dtype=torch.float32).unsqueeze(0)
    new_R, new_T = look_at_view_transform(eye=eye, at=center, up=up)

    ## temp
    # new_R = torch.tensor([[[ 0.4988, 0.2666, 0.8247],
    #                        [-0.1917, -0.8940, 0.4049],
    #                        [ 0.8453, -0.3601, -0.3948]]], dtype=torch.float32)
    # new_T = torch.tensor([[ 0.0739, -0.0013, 0.9973]], dtype=torch.float32)


    # new_R = torch.tensor([[[ 0.2530, 0.2989, 0.9201],
    #                        [-0.2652, -0.8932, 0.3631],
    #                        [ 0.9304, -0.3359, -0.1467],]], dtype=torch.float32)
    # new_T = torch.tensor([[ 0.0081, 0.0337, 1.0452]], dtype=torch.float32)


    print("focal length", example_cam.focal_length)
    print("principal point", example_cam.principal_point)

    newcam = PerspectiveCameras(R=new_R,
                                T=new_T,
                                focal_length=example_cam.focal_length,
                                principal_point=example_cam.principal_point,
                                image_size=512)

    print("input pose")
    print(newcam.get_world_to_view_transform().get_matrix())
    return newcam


def load_and_return_model_and_data(config, model,
                                   ckpt="/data/gdsu/customization3d/stable-diffusion-xl-base-1.0/sd_xl_base_1.0.safetensors",
                                   delta_ckpt=None,
                                   train=False,
                                   valid=False,
                                   far=3,
                                   num_images=1,
                                   num_ref=8,
                                   max_images=20,
                                   ):
    config = OmegaConf.load(config)
    # load data
    data = None
    # config.data.params.jitter = False
    # config.data.params.addreg = False
    # config.data.params.bbox = False

    # data = instantiate_from_config(config.data)
    # data = data.train_dataset

    # single_id = data.single_id

    # if hasattr(data, 'rotations'):
    #     total_images = len(data.rotations[data.sequence_list[single_id]])
    # else:
    #     total_images = len(data.annotations['chair'])
    # print(f"Total images in dataset: {total_images}")

    model, msg = load_delta_model(model, delta_ckpt,)

    # change forward methods to store rendered features and use the pre-calculated reference features
    def register_recr(net_):
        if net_.__class__.__name__ == 'SpatialTransformer':
            print(net_.__class__.__name__, "adding control")
            bound_method = customforward.__get__(net_, net_.__class__)
            setattr(net_, 'forward', bound_method)
            return
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                register_recr(net__)
            return

    def register_recr2(net_):
        if net_.__class__.__name__ == 'BasicTransformerBlock':
            print(net_.__class__.__name__, "adding control")
            bound_method = _customforward.__get__(net_, net_.__class__)
            setattr(net_, 'forward', bound_method)
            return
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                register_recr2(net__)
            return

    sub_nets = model.model.diffusion_model.named_children()
    for net in sub_nets:
        register_recr(net[1])
        register_recr2(net[1])

    # start sampling
    model.clear_rendered_feat()

    return model, data


def sample(model, data,
           num_images=1,
           prompt="",
           appendpath="",
           camera_json=None,
           train=False,
           scale=7.5,
           scale_im=3.5,
           beta=1.0,
           num_ref=8,
           skipreflater=False,
           num_steps=10,
           valid=False,
           max_images=20,
           seed=42,
           camera_path="pretrained-models/car0/camera.bin",
           ):

    """
    Only works with num_images=1 (because of camera_json processing)
    """

    if num_images != 1:
        print("forcing num_images to be 1")
        num_images = 1

    # set guidance scales
    model.sampler.guider.scale_im = scale_im
    model.sampler.guider.scale = scale

    seed_everything(seed)

    # load cameras
    cameras_val, cameras_train = torch.load(camera_path)
    global choices
    num_ref = 8
    max_diff = len(cameras_train)/num_ref
    choices = [int(x) for x in torch.linspace(0, len(cameras_train) - max_diff, num_ref)]
    cameras_train_final = [cameras_train[i] for i in choices]

    # start sampling
    model.clear_rendered_feat()

    if prompt == "":
        prompt = None

    noise = torch.randn(1, 4, 64, 64).to('cuda').repeat(num_images, 1, 1, 1)

    # random sample camera poses
    pose_ids = np.random.choice(len(cameras_val), num_images, replace=False)
    print(pose_ids)
    pose_ids[0] = 21

    pose = [cameras_val[i] for i in pose_ids]

    print("example camera")
    print(pose[0].R)
    print(pose[0].T)
    print(pose[0].focal_length)
    print(pose[0].principal_point)

    # prepare batches [if translating then call required functions on the target pose]
    batches = []
    for i in range(num_images):
        batch = {'pose': [pose[i]] + cameras_train_final,
                 "original_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
                 "target_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
                 "crop_coords_top_left": torch.tensor([0, 0]).reshape(-1, 2),
                 "original_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
                 "target_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
                 "crop_coords_top_left_ref": torch.tensor([0, 0]).reshape(-1, 2),
                 }
        batch_ = copy.deepcopy(batch)
        batch_["pose"][0] = process_camera_json(camera_json, pose[0])
        batch_["pose"] = [join_cameras_as_batch(batch_["pose"])]
        # print('batched')
        # print(batch_["pose"][0].get_world_to_view_transform().get_matrix())
        batches.append(batch_)

    print(f'len batches: {len(batches)}')

    image = None

    with torch.no_grad():
        for batch in batches:
            for key in batch.keys():
                if isinstance(batch[key], torch.Tensor):
                    batch[key] = batch[key].to('cuda')
                elif 'pose' in key:
                    batch[key] = [x.to('cuda') for x in batch[key]]
                else:
                    pass

            if prompt is not None:
                batch["txt"] = [prompt for _ in range(1)]
                batch["txt_ref"] = [prompt for _ in range(len(batch["pose"])-1)]

            print(batch["txt"])
            N = 1
            log_ = log_images(model, batch, N=N, noise=noise.clone()[:N], num_steps=num_steps, scale_im=scale_im)
            image = log_["samples"]

            torch.cuda.empty_cache()
            model.clear_rendered_feat()

    print("generation done")
    return image
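A minimal sketch of how the functions above fit together, under the assumption that the Gradio app drives them in roughly this order; the base SDXL checkpoint path, the per-concept camera.bin location, and the prompt are illustrative assumptions, not values pinned by this file:

# hypothetical driver; paths and prompt are illustrative
from sampling_for_demo import load_base_model, load_and_return_model_and_data, sample

project_config = "pretrained_models/chair191/configs/2024-04-12T22-10-18-project.yaml"
delta_ckpt = "pretrained_models/chair191/checkpoints/step=000001600.ckpt"

model = load_base_model(project_config, ckpt="pretrained-models/sd_xl_base_1.0.safetensors")  # assumed base weights path
model, data = load_and_return_model_and_data(project_config, model, delta_ckpt=delta_ckpt)

camera_json = '{"scene.camera": {"eye": {"x": 0.86, "y": -0.09, "z": -0.40}, "center": {"x": 0.0, "y": 0.02, "z": 0.09}, "up": {"x": -0.13, "y": -0.99, "z": -0.01}}}'
image = sample(model, data, prompt="a photo of a <new1> chair", camera_json=camera_json,
               num_steps=10, camera_path="pretrained-models/chair191/camera.bin")  # camera.bin path assumed per concept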
scripts.js
ADDED
@@ -0,0 +1,147 @@
async () => {

    globalThis.init_camera_dict = {
        "scene.camera": {
            "up": {"x": -0.13227683305740356,
                   "y": -0.9911391735076904,
                   "z": -0.013464212417602539},
            "center": {"x": -0.005292057991027832,
                       "y": 0.020704858005046844,
                       "z": 0.0873757004737854},
            "eye": {"x": 0.8585731983184814,
                    "y": -0.08790968358516693,
                    "z": -0.40458938479423523},
        },
        "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
        "scene.aspectmode": "manual"
    };

    // globalThis.restrictCamera = (data) => {
    //     var plotlyDiv = document.getElementById("map").getElementsByClassName('js-plotly-plot')[0];
    //     // var curr_eye = plotlyDiv.layout.scene.camera.eye;
    //     // var curr_center = plotlyDiv.layout.scene.camera.center;
    //     // var curr_up = plotlyDiv.layout.scene.camera.up;

    //     var curr_eye = data["scene.camera"]["eye"];
    //     var curr_center = data["scene.camera"]["center"];
    //     var curr_up = data["scene.camera"]["up"];

    //     var D = Math.sqrt((curr_eye.x - curr_center.x)**2 + (curr_eye.y - curr_center.y)**2 + (curr_eye.z - curr_center.z)**2);
    //     console.log("D", D);

    //     const max_D = 1.47;
    //     const min_D = 0.8;

    //     // calculate elevation
    //     var elevation = Math.atan2(curr_eye.y - curr_center.y, Math.sqrt((curr_eye.x - curr_center.x)**2 + (curr_eye.z - curr_center.z)**2)) * 180 / Math.PI;
    //     console.log("elevation", elevation);
    //     const max_elev = 3.2;
    //     const min_elev = -30;

    //     const eps = 0.01;

    //     if (D > max_D) {
    //         // find new_eye such that D = max_D
    //         var new_dict = {
    //             "scene.camera": {
    //                 "eye": {
    //                     "x": curr_center.x + (curr_eye.x - curr_center.x) * max_D / D - eps,
    //                     "y": curr_center.y + (curr_eye.y - curr_center.y) * max_D / D - eps,
    //                     "z": curr_center.z + (curr_eye.z - curr_center.z) * max_D / D - eps,
    //                 },
    //                 "up": curr_up,
    //                 "center": curr_center,
    //             }
    //         };

    //         Plotly.relayout(plotlyDiv, new_dict);

    //     } else if (D < min_D) {
    //         // find new_eye such that D = min_D
    //         var new_dict = {
    //             "scene.camera": {
    //                 "eye": {
    //                     "x": curr_center.x + (curr_eye.x - curr_center.x) * min_D / D - eps,
    //                     "y": curr_center.y + (curr_eye.y - curr_center.y) * min_D / D - eps,
    //                     "z": curr_center.z + (curr_eye.z - curr_center.z) * min_D / D - eps,
    //                 },
    //                 "up": curr_up,
    //                 "center": curr_center,
    //             }
    //         };

    //         Plotly.relayout(plotlyDiv, new_dict);
    //     }

    //     const eta = 0.001;
    //     if (elevation > max_elev) {
    //         // find new eye such that y elevation = max_elev
    //         var new_dict = {
    //             "scene.camera": {
    //                 "eye": {
    //                     "x": curr_eye.x,
    //                     "y": curr_center.y + (curr_eye.y - curr_center.y) * Math.tan((max_elev - eta) * Math.PI / 180),
    //                     "z": curr_eye.z,
    //                 },
    //                 "up": curr_up,
    //                 "center": curr_center,
    //             }
    //         };


    //         Plotly.relayout(plotlyDiv, new_dict);

    //     } else if (elevation < min_elev) {
    //         // find new eye such that y elevation = min_elev
    //         var new_dict = {
    //             "scene.camera": {
    //                 "eye": {
    //                     "x": curr_eye.x,
    //                     "y": curr_center.y + (curr_eye.y - curr_center.y) * Math.tan((min_elev + eta) * Math.PI / 180),
    //                     "z": curr_eye.z,
    //                 },
    //                 "up": curr_up,
    //                 "center": curr_center,
    //             }
    //         };

    //         Plotly.relayout(plotlyDiv, new_dict);
    //     }

    // }

    globalThis.latestCam = () => {
        var plotlyDiv = document.getElementById("map").getElementsByClassName('js-plotly-plot')[0];

        globalThis.prev_camera_dict = {};
        console.log("prev camera dict", globalThis.prev_camera_dict);

        // Listen for the event and log to the console
        plotlyDiv.on('plotly_relayout', function(data) {
            console.log('plotly_relayout event triggered:', data);

            if ("scene.camera.up" in data) {
                Object.assign(globalThis.prev_camera_dict, globalThis.camera_dict);
                Object.assign(globalThis.camera_dict, globalThis.init_camera_dict);
            }

            if ('scene.camera' in data) {
                Object.assign(globalThis.prev_camera_dict, globalThis.camera_dict);
                globalThis.camera_dict = data;
            }

            var camera_json = JSON.stringify(globalThis.camera_dict);
            var input_pose = document.getElementById("input_pose").getElementsByTagName("textarea")[0];
            let myEvent = new Event("input")
            input_pose.value = camera_json;
            input_pose.dispatchEvent(myEvent);

            var update_pose_btn = document.getElementById("update_pose_button");
            update_pose_btn.dispatchEvent(new Event("click"));
            // globalThis.restrictCamera(data);
        });
    }

    return latestCam(this);

}
sgm/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .models import AutoencodingEngine, DiffusionEngine
from .util import get_configs_path, instantiate_from_config

__version__ = "0.1.0"
sgm/data/__init__.py
ADDED
@@ -0,0 +1 @@
# from .dataset import StableDataModuleFromConfig
sgm/data/data_co3d.py
ADDED
@@ -0,0 +1,762 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# code taken and modified from https://github.com/amyxlase/relpose-plus-plus/blob/b33f7d5000cf2430bfcda6466c8e89bc2dcde43f/relpose/dataset/co3d_v2.py#L346)
|
2 |
+
import os.path as osp
|
3 |
+
import random
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import pytorch_lightning as pl
|
8 |
+
|
9 |
+
from PIL import Image, ImageFile
|
10 |
+
import json
|
11 |
+
import gzip
|
12 |
+
from torch.utils.data import DataLoader, Dataset
|
13 |
+
from torchvision import transforms
|
14 |
+
from pytorch3d.renderer.cameras import PerspectiveCameras
|
15 |
+
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
16 |
+
from pytorch3d.implicitron.dataset.utils import adjust_camera_to_bbox_crop_, adjust_camera_to_image_scale_
|
17 |
+
from pytorch3d.transforms import Rotate, Translate
|
18 |
+
|
19 |
+
|
20 |
+
CO3D_DIR = "data/training/"
|
21 |
+
|
22 |
+
Image.MAX_IMAGE_PIXELS = None
|
23 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
24 |
+
|
25 |
+
|
26 |
+
# Added: normalize camera poses
|
27 |
+
def intersect_skew_line_groups(p, r, mask):
|
28 |
+
# p, r both of shape (B, N, n_intersected_lines, 3)
|
29 |
+
# mask of shape (B, N, n_intersected_lines)
|
30 |
+
p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
|
31 |
+
_, p_line_intersect = _point_line_distance(
|
32 |
+
p, r, p_intersect[..., None, :].expand_as(p)
|
33 |
+
)
|
34 |
+
intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
|
35 |
+
dim=-1
|
36 |
+
)
|
37 |
+
return p_intersect, p_line_intersect, intersect_dist_squared, r
|
38 |
+
|
39 |
+
|
40 |
+
def intersect_skew_lines_high_dim(p, r, mask=None):
|
41 |
+
# Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
|
42 |
+
dim = p.shape[-1]
|
43 |
+
# make sure the heading vectors are l2-normed
|
44 |
+
if mask is None:
|
45 |
+
mask = torch.ones_like(p[..., 0])
|
46 |
+
r = torch.nn.functional.normalize(r, dim=-1)
|
47 |
+
|
48 |
+
eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
|
49 |
+
I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
|
50 |
+
sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
|
51 |
+
p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
|
52 |
+
|
53 |
+
if torch.any(torch.isnan(p_intersect)):
|
54 |
+
print(p_intersect)
|
55 |
+
assert False
|
56 |
+
return p_intersect, r
|
57 |
+
|
58 |
+
|
59 |
+
def _point_line_distance(p1, r1, p2):
|
60 |
+
df = p2 - p1
|
61 |
+
proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
|
62 |
+
line_pt_nearest = p2 - proj_vector
|
63 |
+
d = (proj_vector).norm(dim=-1)
|
64 |
+
return d, line_pt_nearest
|
65 |
+
|
66 |
+
|
67 |
+
def compute_optical_axis_intersection(cameras):
|
68 |
+
centers = cameras.get_camera_center()
|
69 |
+
principal_points = cameras.principal_point
|
70 |
+
|
71 |
+
one_vec = torch.ones((len(cameras), 1))
|
72 |
+
optical_axis = torch.cat((principal_points, one_vec), -1)
|
73 |
+
|
74 |
+
pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
|
75 |
+
|
76 |
+
pp2 = torch.zeros((pp.shape[0], 3))
|
77 |
+
for i in range(0, pp.shape[0]):
|
78 |
+
pp2[i] = pp[i][i]
|
79 |
+
|
80 |
+
directions = pp2 - centers
|
81 |
+
centers = centers.unsqueeze(0).unsqueeze(0)
|
82 |
+
directions = directions.unsqueeze(0).unsqueeze(0)
|
83 |
+
|
84 |
+
p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
|
85 |
+
p=centers, r=directions, mask=None
|
86 |
+
)
|
87 |
+
|
88 |
+
p_intersect = p_intersect.squeeze().unsqueeze(0)
|
89 |
+
dist = (p_intersect - centers).norm(dim=-1)
|
90 |
+
|
91 |
+
return p_intersect, dist, p_line_intersect, pp2, r
|
92 |
+
|
93 |
+
|
94 |
+
def normalize_cameras(cameras, scale=1.0):
|
95 |
+
"""
|
96 |
+
Normalizes cameras such that the optical axes point to the origin and the average
|
97 |
+
distance to the origin is 1.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
cameras (List[camera]).
|
101 |
+
"""
|
102 |
+
|
103 |
+
# Let distance from first camera to origin be unit
|
104 |
+
new_cameras = cameras.clone()
|
105 |
+
new_transform = new_cameras.get_world_to_view_transform()
|
106 |
+
|
107 |
+
p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
|
108 |
+
cameras
|
109 |
+
)
|
110 |
+
t = Translate(p_intersect)
|
111 |
+
|
112 |
+
# scale = dist.squeeze()[0]
|
113 |
+
scale = max(dist.squeeze())
|
114 |
+
|
115 |
+
# Degenerate case
|
116 |
+
if scale == 0:
|
117 |
+
print(cameras.T)
|
118 |
+
print(new_transform.get_matrix()[:, 3, :3])
|
119 |
+
return -1
|
120 |
+
assert scale != 0
|
121 |
+
|
122 |
+
new_transform = t.compose(new_transform)
|
123 |
+
new_cameras.R = new_transform.get_matrix()[:, :3, :3]
|
124 |
+
new_cameras.T = new_transform.get_matrix()[:, 3, :3] / scale
|
125 |
+
return new_cameras, p_intersect, p_line_intersect, pp, r
|
126 |
+
|
127 |
+
|
128 |
+
def centerandalign(cameras, scale=1.0):
|
129 |
+
"""
|
130 |
+
Normalizes cameras such that the optical axes point to the origin and the average
|
131 |
+
distance to the origin is 1.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
cameras (List[camera]).
|
135 |
+
"""
|
136 |
+
|
137 |
+
# Let distance from first camera to origin be unit
|
138 |
+
new_cameras = cameras.clone()
|
139 |
+
new_transform = new_cameras.get_world_to_view_transform()
|
140 |
+
|
141 |
+
p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
|
142 |
+
cameras
|
143 |
+
)
|
144 |
+
t = Translate(p_intersect)
|
145 |
+
|
146 |
+
centers = [cam.get_camera_center() for cam in new_cameras]
|
147 |
+
centers = torch.concat(centers, 0).cpu().numpy()
|
148 |
+
m = len(cameras)
|
149 |
+
|
150 |
+
# https://math.stackexchange.com/questions/99299/best-fitting-plane-given-a-set-of-points
|
151 |
+
A = np.hstack((centers[:m, :2], np.ones((m, 1))))
|
152 |
+
B = centers[:m, 2:]
|
153 |
+
if A.shape[0] == 2:
|
154 |
+
x = A.T @ np.linalg.inv(A @ A.T) @ B
|
155 |
+
else:
|
156 |
+
x = np.linalg.inv(A.T @ A) @ A.T @ B
|
157 |
+
a, b, c = x.flatten()
|
158 |
+
n = np.array([a, b, 1])
|
159 |
+
n /= np.linalg.norm(n)
|
160 |
+
|
161 |
+
# https://math.stackexchange.com/questions/180418/calculate-rotation-matrix-to-align-vector-a-to-vector-b-in-3d
|
162 |
+
v = np.cross(n, [0, 1, 0])
|
163 |
+
s = np.linalg.norm(v)
|
164 |
+
c = np.dot(n, [0, 1, 0])
|
165 |
+
V = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
|
166 |
+
rot = torch.from_numpy(np.eye(3) + V + V @ V * (1 - c) / s**2).float()
|
167 |
+
|
168 |
+
scale = dist.squeeze()[0]
|
169 |
+
|
170 |
+
# Degenerate case
|
171 |
+
if scale == 0:
|
172 |
+
print(cameras.T)
|
173 |
+
print(new_transform.get_matrix()[:, 3, :3])
|
174 |
+
return -1
|
175 |
+
assert scale != 0
|
176 |
+
|
177 |
+
rot = Rotate(rot.T)
|
178 |
+
|
179 |
+
new_transform = rot.compose(t).compose(new_transform)
|
180 |
+
new_cameras.R = new_transform.get_matrix()[:, :3, :3]
|
181 |
+
new_cameras.T = new_transform.get_matrix()[:, 3, :3] / scale
|
182 |
+
return new_cameras
|
183 |
+
|
184 |
+
|
185 |
+
def square_bbox(bbox, padding=0.0, astype=None):
|
186 |
+
"""
|
187 |
+
Computes a square bounding box, with optional padding parameters.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
bbox: Bounding box in xyxy format (4,).
|
191 |
+
|
192 |
+
Returns:
|
193 |
+
square_bbox in xyxy format (4,).
|
194 |
+
"""
|
195 |
+
if astype is None:
|
196 |
+
astype = type(bbox[0])
|
197 |
+
bbox = np.array(bbox)
|
198 |
+
center = ((bbox[:2] + bbox[2:]) / 2).round().astype(int)
|
199 |
+
extents = (bbox[2:] - bbox[:2]) / 2
|
200 |
+
s = (max(extents) * (1 + padding)).round().astype(int)
|
201 |
+
square_bbox = np.array(
|
202 |
+
[center[0] - s, center[1] - s, center[0] + s, center[1] + s],
|
203 |
+
dtype=astype,
|
204 |
+
)
|
205 |
+
|
206 |
+
return square_bbox
|
207 |
+
|
208 |
+
|
class Co3dDataset(Dataset):
    def __init__(
        self,
        category,
        split="train",
        skip=2,
        img_size=1024,
        num_images=4,
        mask_images=False,
        single_id=0,
        bbox=False,
        modifier_token=None,
        addreg=False,
        drop_ratio=0.5,
        drop_txt=0.1,
        categoryname=None,
        aligncameras=False,
        repeat=100,
        addlen=False,
        onlyref=False,
    ):
        """
        Args:
            category (iterable): List of categories to use. If "all" is in the list,
                all training categories are used.
            num_images (int): Default number of images in each batch.
            normalize_cameras (bool): If True, normalizes cameras so that the
                intersection of the optical axes is placed at the origin and the norm
                of the first camera translation is 1.
            mask_images (bool): If True, masks out the background of the images.
        """
        # category = CATEGORIES
        category = sorted(category.split(','))
        self.category = category
        self.single_id = single_id
        self.addlen = addlen
        self.onlyref = onlyref
        self.categoryname = categoryname
        self.bbox = bbox
        self.modifier_token = modifier_token
        self.addreg = addreg
        self.drop_txt = drop_txt
        self.skip = skip
        if self.addreg:
            with open(f'data/regularization/{category[0]}_sp_generated/caption.txt', "r") as f:
                self.regcaptions = f.read().splitlines()
            self.reglen = len(self.regcaptions)
            self.regimpath = f'data/regularization/{category[0]}_sp_generated'

        self.low_quality_translations = []
        self.rotations = {}
        self.category_map = {}
        co3d_dir = CO3D_DIR
        for c in category:
            subset = 'fewview_dev'
            category_dir = osp.join(co3d_dir, c)
            frame_file = osp.join(category_dir, "frame_annotations.jgz")
            sequence_file = osp.join(category_dir, "sequence_annotations.jgz")
            subset_lists_file = osp.join(category_dir, f"set_lists/set_lists_{subset}.json")
            bbox_file = osp.join(category_dir, f"{c}_bbox.jgz")

            with open(subset_lists_file) as f:
                subset_lists_data = json.load(f)

            with gzip.open(sequence_file, "r") as fin:
                sequence_data = json.loads(fin.read())

            with gzip.open(bbox_file, "r") as fin:
                bbox_data = json.loads(fin.read())

            with gzip.open(frame_file, "r") as fin:
                frame_data = json.loads(fin.read())

            frame_data_processed = {}
            for f_data in frame_data:
                sequence_name = f_data["sequence_name"]
                if sequence_name not in frame_data_processed:
                    frame_data_processed[sequence_name] = {}
                frame_data_processed[sequence_name][f_data["frame_number"]] = f_data

            good_quality_sequences = set()
            for seq_data in sequence_data:
                if seq_data["viewpoint_quality_score"] > 0.5:
                    good_quality_sequences.add(seq_data["sequence_name"])

            for subset in ["train"]:
                for seq_name, frame_number, filepath in subset_lists_data[subset]:
                    if seq_name not in good_quality_sequences:
                        continue

                    if seq_name not in self.rotations:
                        self.rotations[seq_name] = []
                        self.category_map[seq_name] = c

                    mask_path = filepath.replace("images", "masks").replace(".jpg", ".png")

                    frame_data = frame_data_processed[seq_name][frame_number]

                    self.rotations[seq_name].append(
                        {
                            "filepath": filepath,
                            "R": frame_data["viewpoint"]["R"],
                            "T": frame_data["viewpoint"]["T"],
                            "focal_length": frame_data["viewpoint"]["focal_length"],
                            "principal_point": frame_data["viewpoint"]["principal_point"],
                            "mask": mask_path,
                            "txt": "a car",
                            "bbox": bbox_data[mask_path]
                        }
                    )

        for seq_name in self.rotations:
            seq_data = self.rotations[seq_name]
            cameras = PerspectiveCameras(
                focal_length=[data["focal_length"] for data in seq_data],
                principal_point=[data["principal_point"] for data in seq_data],
                R=[data["R"] for data in seq_data],
                T=[data["T"] for data in seq_data],
            )

            normalized_cameras, _, _, _, _ = normalize_cameras(cameras)
            if aligncameras:
                normalized_cameras = centerandalign(cameras)

            if normalized_cameras == -1:
                print("Error in normalizing cameras: camera scale was 0")
                del self.rotations[seq_name]
                continue

            for i, data in enumerate(seq_data):
                self.rotations[seq_name][i]["R"] = normalized_cameras.R[i]
                self.rotations[seq_name][i]["T"] = normalized_cameras.T[i]
                self.rotations[seq_name][i]["R_original"] = torch.from_numpy(np.array(seq_data[i]["R"]))
                self.rotations[seq_name][i]["T_original"] = torch.from_numpy(np.array(seq_data[i]["T"]))

                # Make sure translations are not ridiculous
                if self.rotations[seq_name][i]["T"][0] + self.rotations[seq_name][i]["T"][1] + self.rotations[seq_name][i]["T"][2] > 1e5:
                    bad_seq = True
                    self.low_quality_translations.append(seq_name)
                    break

        for seq_name in self.low_quality_translations:
            if seq_name in self.rotations:
                del self.rotations[seq_name]

        self.sequence_list = list(self.rotations.keys())

        self.transform = transforms.Compose(
            [
                transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x * 2.0 - 1.0)
            ]
        )
        self.transformim = transforms.Compose(
            [
                transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(img_size),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x * 2.0 - 1.0)
            ]
        )
        self.transformmask = transforms.Compose(
            [
                transforms.Resize(img_size // 8),
                transforms.ToTensor(),
            ]
        )

        self.num_images = num_images
        self.image_size = img_size
        self.normalize_cameras = normalize_cameras
        self.mask_images = mask_images
        self.drop_ratio = drop_ratio
        self.kernel_tensor = torch.ones((1, 1, 7, 7))
        self.repeat = repeat
        print(self.sequence_list, "$$$$$$$$$$$$$$$$$$$$$")
        self.valid_ids = np.arange(0, len(self.rotations[self.sequence_list[self.single_id]]), skip).tolist()
        if split == 'test':
            self.valid_ids = list(set(np.arange(0, len(self.rotations[self.sequence_list[self.single_id]])).tolist()).difference(self.valid_ids))

        print(
            f"Low quality translation sequences, not used: {self.low_quality_translations}"
        )
        print(f"Data size: {len(self)}")

    def __len__(self):
        return (len(self.valid_ids))*self.repeat + (1 if self.addlen else 0)

    def _padded_bbox(self, bbox, w, h):
        if w < h:
            bbox = np.array([0, 0, w, h])
        else:
            bbox = np.array([0, 0, w, h])
        return square_bbox(bbox.astype(np.float32))

    def _crop_bbox(self, bbox, w, h):
        bbox = square_bbox(bbox.astype(np.float32))

        side_length = bbox[2] - bbox[0]
        center = (bbox[:2] + bbox[2:]) / 2
        extent = side_length / 2

        # Final coordinates need to be integer for cropping.
        ul = (center - extent).round().astype(int)
        lr = ul + np.round(2 * extent).astype(int)
        return np.concatenate((ul, lr))

    def _crop_image(self, image, bbox, white_bg=False):
        if white_bg:
            # Only support PIL Images
            image_crop = Image.new(
                "RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255)
            )
            image_crop.paste(image, (-bbox[0], -bbox[1]))
        else:
            image_crop = transforms.functional.crop(
                image,
                top=bbox[1],
                left=bbox[0],
                height=bbox[3] - bbox[1],
                width=bbox[2] - bbox[0],
            )
        return image_crop

    def __getitem__(self, index, specific_id=None, validation=False):
        sequence_name = self.sequence_list[self.single_id]

        metadata = self.rotations[sequence_name]

        if validation:
            drop_text = False
            drop_im = False
        else:
            drop_im = np.random.uniform(0, 1) < self.drop_ratio
            if not drop_im:
                drop_text = np.random.uniform(0, 1) < self.drop_txt
            else:
                drop_text = False

        size = self.image_size

        # sample reference ids
        listofindices = self.valid_ids.copy()
        max_diff = len(listofindices) // (self.num_images-1)
        if (index*self.skip) % len(metadata) in listofindices:
            listofindices.remove((index*self.skip) % len(metadata))
        references = np.random.choice(np.arange(0, len(listofindices)+1, max_diff), self.num_images-1, replace=False)
        rem = np.random.randint(0, max_diff)
        references = [listofindices[(x + rem) % len(listofindices)] for x in references]
        ids = [(index*self.skip) % len(metadata)] + references

        # special case to save features corresponding to ref image as part of model buffer
        if self.onlyref:
            ids = references + [(index*self.skip) % len(metadata)]
        if specific_id is not None:  # remove this later
            ids = specific_id

        # get data
        batch = self.get_data(index=self.single_id, ids=ids)

        # text prompt
        if self.modifier_token is not None:
            name = self.category[0] if self.categoryname is None else self.categoryname
            batch['txt'] = [f'photo of a {self.modifier_token} {name}' for _ in range(len(batch['txt']))]

        # replace with regularization image if drop_im
        if drop_im and self.addreg:
            select_id = np.random.randint(0, self.reglen)
            batch["image"] = [self.transformim(Image.open(f'{self.regimpath}/images/{select_id}.png').convert('RGB'))]
            batch['txt'] = [self.regcaptions[select_id]]
            batch["original_size_as_tuple"] = torch.ones_like(batch["original_size_as_tuple"])*1024

        # create camera class and adjust intrinsics for crop
        cameras = [PerspectiveCameras(R=batch['R'][i].unsqueeze(0),
                                      T=batch['T'][i].unsqueeze(0),
                                      focal_length=batch['focal_lengths'][i].unsqueeze(0),
                                      principal_point=batch['principal_points'][i].unsqueeze(0),
                                      image_size=self.image_size
                                      )
                   for i in range(len(ids))]
        for i, cam in enumerate(cameras):
            adjust_camera_to_bbox_crop_(cam, batch["original_size_as_tuple"][i, :2], batch["crop_coords"][i])
            adjust_camera_to_image_scale_(cam, batch["original_size_as_tuple"][i, 2:], torch.tensor([self.image_size, self.image_size]))

        # create mask and dilated mask for mask based losses
        batch["depth"] = batch["mask"].clone()
        batch["mask"] = torch.clamp(torch.nn.functional.conv2d(batch["mask"], self.kernel_tensor, padding='same'), 0, 1)
        if not self.mask_images:
            batch["mask"] = [None for i in range(len(ids))]

        # special case to save features corresponding to zero image
        if index == self.__len__()-1 and self.addlen:
            batch["image"][0] *= 0.

        return {"jpg": batch["image"][0],
                "txt": batch["txt"][0] if not drop_text else "",
                "jpg_ref": batch["image"][1:] if not drop_im else torch.stack([2*torch.rand_like(batch["image"][0])-1. for _ in range(len(ids)-1)], dim=0),
                "txt_ref": batch["txt"][1:] if not drop_im else ["" for _ in range(len(ids)-1)],
                "pose": cameras,
                "mask": batch["mask"][0] if not drop_im else torch.ones_like(batch["mask"][0]),
                "mask_ref": batch["masks_padding"][1:],
                "depth": batch["depth"][0] if len(batch["depth"]) > 0 else None,
                "filepaths": batch["filepaths"],
                "original_size_as_tuple": batch["original_size_as_tuple"][0][2:],
                "target_size_as_tuple": torch.ones_like(batch["original_size_as_tuple"][0][2:])*size,
                "crop_coords_top_left": torch.zeros_like(batch["crop_coords"][0][:2]),
                "original_size_as_tuple_ref": batch["original_size_as_tuple"][1:][:, 2:],
                "target_size_as_tuple_ref": torch.ones_like(batch["original_size_as_tuple"][1:][:, 2:])*size,
                "crop_coords_top_left_ref": torch.zeros_like(batch["crop_coords"][1:][:, :2]),
                "drop_im": torch.Tensor([1-drop_im*1.])
                }

    def get_data(self, index=None, sequence_name=None, ids=(0, 1)):
        if sequence_name is None:
            sequence_name = self.sequence_list[index]
        metadata = self.rotations[sequence_name]
        category = self.category_map[sequence_name]
        annos = [metadata[i] for i in ids]
        images = []
        rotations = []
        translations = []
        focal_lengths = []
        principal_points = []
        txts = []
        masks = []
        filepaths = []
        images_transformed = []
        masks_transformed = []
        original_size_as_tuple = []
        crop_parameters = []
        masks_padding = []
        depths = []

        for counter, anno in enumerate(annos):
            filepath = anno["filepath"]
            filepaths.append(filepath)
            image = Image.open(osp.join(CO3D_DIR, filepath)).convert("RGB")

            mask_name = osp.basename(filepath.replace(".jpg", ".png"))

            mask_path = osp.join(
                CO3D_DIR, category, sequence_name, "masks", mask_name
            )
            mask = Image.open(mask_path).convert("L")

            if mask.size != image.size:
                mask = mask.resize(image.size)

            mask_padded = Image.fromarray((np.ones_like(mask) > 0))
            mask = Image.fromarray((np.array(mask) > 125))
            masks.append(mask)

            # crop image around object
            w, h = image.width, image.height
            bbox = np.array(anno["bbox"])
            if len(bbox) == 0:
                bbox = np.array([0, 0, w, h])

            if self.bbox and counter > 0:
                bbox = self._crop_bbox(bbox, w, h)
            else:
                bbox = self._padded_bbox(None, w, h)
            image = self._crop_image(image, bbox)
            mask = self._crop_image(mask, bbox)
            mask_padded = self._crop_image(mask_padded, bbox)
            masks_padding.append(self.transformmask(mask_padded))
            images_transformed.append(self.transform(image))
            masks_transformed.append(self.transformmask(mask))

            crop_parameters.append(torch.tensor([bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]]).int())
            original_size_as_tuple.append(torch.tensor([w, h, bbox[2] - bbox[0], bbox[3] - bbox[1]]))
            images.append(image)
            rotations.append(anno["R"])
            translations.append(anno["T"])
            focal_lengths.append(torch.tensor(anno["focal_length"]))
            principal_points.append(torch.tensor(anno["principal_point"]))
            txts.append(anno["txt"])

        images = images_transformed
        batch = {
            "model_id": sequence_name,
            "category": category,
            "original_size_as_tuple": torch.stack(original_size_as_tuple),
            "crop_coords": torch.stack(crop_parameters),
            "n": len(metadata),
            "ind": torch.tensor(ids),
            "txt": txts,
            "filepaths": filepaths,
            "masks_padding": torch.stack(masks_padding) if len(masks_padding) > 0 else [],
            "depth": torch.stack(depths) if len(depths) > 0 else [],
        }

        batch["R"] = torch.stack(rotations)
        batch["T"] = torch.stack(translations)
        batch["focal_lengths"] = torch.stack(focal_lengths)
        batch["principal_points"] = torch.stack(principal_points)

        # Add images
        if self.transform is None:
            batch["image"] = images
        else:
            batch["image"] = torch.stack(images)
        batch["mask"] = torch.stack(masks_transformed)

        return batch

    @staticmethod
    def collate_fn(batch):
        """A function to collate the data across batches. This function must be passed to pytorch's DataLoader to collate batches.
        Args:
            batch(list): List of objects returned by this class' __getitem__ function. This is given by pytorch's dataloader that calls __getitem__
                multiple times and expects a collated batch.
        Returns:
            dict: The collated dictionary representing the data in the batch.
        """
        result = {
            "jpg": [],
            "txt": [],
            "jpg_ref": [],
            "txt_ref": [],
            "pose": [],
            "original_size_as_tuple": [],
            "original_size_as_tuple_ref": [],
            "crop_coords_top_left": [],
            "crop_coords_top_left_ref": [],
            "target_size_as_tuple_ref": [],
            "target_size_as_tuple": [],
            "drop_im": [],
            "mask_ref": [],
        }
        if batch[0]["mask"] is not None:
            result["mask"] = []
        if batch[0]["depth"] is not None:
            result["depth"] = []

        for batch_obj in batch:
            for key in result.keys():
                result[key].append(batch_obj[key])
        for key in result.keys():
            if not (key == 'pose' or 'txt' in key or 'size_as_tuple_ref' in key or 'coords_top_left_ref' in key):
                result[key] = torch.stack(result[key], dim=0)
            elif 'txt_ref' in key:
                result[key] = [item for sublist in result[key] for item in sublist]
            elif 'size_as_tuple_ref' in key or 'coords_top_left_ref' in key:
                result[key] = torch.cat(result[key], dim=0)
            elif 'pose' in key:
                result[key] = [join_cameras_as_batch(cameras) for cameras in result[key]]

        return result


class CustomDataDictLoader(pl.LightningDataModule):
    def __init__(
        self,
        category,
        batch_size,
        mask_images=False,
        skip=1,
        img_size=1024,
        num_images=4,
        num_workers=0,
        shuffle=True,
        single_id=0,
        modifier_token=None,
        bbox=False,
        addreg=False,
        drop_ratio=0.5,
        jitter=False,
        drop_txt=0.1,
        categoryname=None,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.train_dataset = Co3dDataset(category,
                                         img_size=img_size,
                                         mask_images=mask_images,
                                         skip=skip,
                                         num_images=num_images,
                                         single_id=single_id,
                                         modifier_token=modifier_token,
                                         bbox=bbox,
                                         addreg=addreg,
                                         drop_ratio=drop_ratio,
                                         drop_txt=drop_txt,
                                         categoryname=categoryname,
                                         )
        self.val_dataset = Co3dDataset(category,
                                       img_size=img_size,
                                       mask_images=mask_images,
                                       skip=skip,
                                       num_images=2,
                                       single_id=single_id,
                                       modifier_token=modifier_token,
                                       bbox=bbox,
                                       addreg=addreg,
                                       drop_ratio=0.,
                                       drop_txt=0.,
                                       categoryname=categoryname,
                                       repeat=1,
                                       addlen=True,
                                       onlyref=True,
                                       )
        self.test_dataset = Co3dDataset(category,
                                        img_size=img_size,
                                        mask_images=mask_images,
                                        split="test",
                                        skip=skip,
                                        num_images=2,
                                        single_id=single_id,
                                        modifier_token=modifier_token,
                                        bbox=False,
                                        addreg=addreg,
                                        drop_ratio=0.,
                                        drop_txt=0.,
                                        categoryname=categoryname,
                                        repeat=1,
                                        )
        self.collate_fn = Co3dDataset.collate_fn

    def prepare_data(self):
        pass

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=True,
        )

    def test_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=True
        )
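The camera-alignment block at the top of data_co3d.py rests on two standard results: a least-squares fit of the plane z = a·x + b·y + c to the camera centers, and Rodrigues' formula for the rotation that takes the fitted plane's normal direction [a, b, 1] onto the +y axis. A minimal self-contained sketch of that math (an editorial illustration with made-up camera centers, not part of the repository files):

import numpy as np

# Fit a plane z = a*x + b*y + c to the camera centers by least squares,
# then rotate its normal direction n onto the +y axis with Rodrigues' formula.
centers = np.array([[0.0, 1.0, 0.2], [1.0, 1.1, 0.1], [-1.0, 0.9, 0.3], [0.5, 1.2, 0.15]])
A = np.hstack((centers[:, :2], np.ones((len(centers), 1))))
B = centers[:, 2:]
a, b, c = (np.linalg.inv(A.T @ A) @ A.T @ B).flatten()
n = np.array([a, b, 1.0])
n /= np.linalg.norm(n)

v = np.cross(n, [0, 1, 0])            # rotation axis (unnormalized)
s = np.linalg.norm(v)                 # sin(theta)
cth = np.dot(n, [0, 1, 0])            # cos(theta)
V = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
R = np.eye(3) + V + V @ V * (1 - cth) / s**2

print(np.round(R @ n, 6))             # ~ [0, 1, 0]: the fitted plane normal now points "up"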
sgm/lr_scheduler.py
ADDED
@@ -0,0 +1,135 @@
import numpy as np


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """

    def __init__(
        self,
        warm_up_steps,
        lr_min,
        lr_max,
        lr_start,
        max_decay_steps,
        verbosity_interval=0,
    ):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.0
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (
                self.lr_max - self.lr_start
            ) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (
                self.lr_max_decay_steps - self.lr_warm_up_steps
            )
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                1 + np.cos(t * np.pi)
            )
            self.last_lr = lr
            return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """

    def __init__(
        self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
    ):
        assert (
            len(warm_up_steps)
            == len(f_min)
            == len(f_max)
            == len(f_start)
            == len(cycle_lengths)
        )
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.0
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        interval = 0
        for cl in self.cum_cycles[1:]:
            if n <= cl:
                return interval
            interval += 1

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                cycle
            ] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (
                self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
            )
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                1 + np.cos(t * np.pi)
            )
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                cycle
            ] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
                self.cycle_lengths[cycle] - n
            ) / (self.cycle_lengths[cycle])
            self.last_f = f
            return f
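These schedulers return learning-rate multipliers rather than learning rates, which is why they are meant to be used with a base_lr of 1.0 relative to the optimizer's lr; DiffusionEngine.configure_optimizers in sgm/models/diffusion.py (added later in this commit) wraps them in torch's LambdaLR with a per-step interval. A minimal sketch of that wiring, with made-up warm-up and cycle values rather than the ones from the shipped configs:

import torch
from torch.optim.lr_scheduler import LambdaLR
from sgm.lr_scheduler import LambdaLinearScheduler

# multiplier ramps linearly from f_start to f_max over warm_up_steps, then stays at f_min == f_max
sched = LambdaLinearScheduler(
    warm_up_steps=[100], f_min=[1.0], f_max=[1.0], f_start=[1e-6], cycle_lengths=[10000]
)
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-5)            # base learning rate scaled by the multiplier
lr_sched = LambdaLR(opt, lr_lambda=sched.schedule)

for step in range(3):
    opt.step()
    lr_sched.step()
    print(step, opt.param_groups[0]["lr"])           # ramps from ~1e-11 toward 1e-5 over 100 steps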
sgm/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .autoencoder import AutoencodingEngine
from .diffusion import DiffusionEngine
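These exports are normally reached through the YAML configs rather than direct imports: sgm.util.instantiate_from_config resolves a dotted "target" path and passes "params" as keyword arguments. A small illustration of that mechanism (the target and parameter values below are illustrative, not taken from the shipped configs):

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

# A config node of the form {"target": "<dotted.path.Class>", "params": {...}}
# is turned into an instance; the same pattern builds DiffusionEngine and its submodules.
cfg = OmegaConf.create({
    "target": "sgm.lr_scheduler.LambdaLinearScheduler",
    "params": {"warm_up_steps": [100], "f_min": [1.0], "f_max": [1.0],
               "f_start": [1.0e-6], "cycle_lengths": [10000]},
})
sched = instantiate_from_config(cfg)   # resolves the dotted path and calls it with params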
sgm/models/autoencoder.py
ADDED
@@ -0,0 +1,335 @@
import re
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Tuple, Union

import pytorch_lightning as pl
import torch
from omegaconf import ListConfig
from packaging import version
from safetensors.torch import load_file as load_safetensors

from ..modules.diffusionmodules.model import Decoder, Encoder
from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..modules.ema import LitEma
from ..util import default, get_obj_from_str, instantiate_from_config


class AbstractAutoencoder(pl.LightningModule):
    """
    This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
    unCLIP models, etc. Hence, it is fairly general, and specific features
    (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,
        monitor: Union[None, str] = None,
        input_key: str = "jpg",
        ckpt_path: Union[None, str] = None,
        ignore_keys: Union[Tuple, list, ListConfig] = (),
    ):
        super().__init__()
        self.input_key = input_key
        self.use_ema = ema_decay is not None
        if monitor is not None:
            self.monitor = monitor

        if self.use_ema:
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        if version.parse(torch.__version__) >= version.parse("2.0.0"):
            self.automatic_optimization = False

    def init_from_ckpt(
        self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
    ) -> None:
        if path.endswith("ckpt"):
            sd = torch.load(path, map_location="cpu")["state_dict"]
        elif path.endswith("safetensors"):
            sd = load_safetensors(path)
        else:
            raise NotImplementedError

        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if re.match(ik, k):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    @abstractmethod
    def get_input(self, batch) -> Any:
        raise NotImplementedError()

    def on_train_batch_end(self, *args, **kwargs):
        # for EMA computation
        if self.use_ema:
            self.model_ema(self)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    @abstractmethod
    def encode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("encode()-method of abstract base class called")

    @abstractmethod
    def decode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("decode()-method of abstract base class called")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        print(f"loading >>> {cfg['target']} <<< optimizer from config")
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self) -> Any:
        raise NotImplementedError()


class AutoencodingEngine(AbstractAutoencoder):
    """
    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
    (we also restore them explicitly as special cases for legacy reasons).
    Regularizations such as KL or VQ are moved to the regularizer class.
    """

    def __init__(
        self,
        *args,
        encoder_config: Dict,
        decoder_config: Dict,
        loss_config: Dict,
        regularizer_config: Dict,
        optimizer_config: Union[Dict, None] = None,
        lr_g_factor: float = 1.0,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # todo: add options to freeze encoder/decoder
        self.encoder = instantiate_from_config(encoder_config)
        self.decoder = instantiate_from_config(decoder_config)
        self.loss = instantiate_from_config(loss_config)
        self.regularization = instantiate_from_config(regularizer_config)
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.Adam"}
        )
        self.lr_g_factor = lr_g_factor

    def get_input(self, batch: Dict) -> torch.Tensor:
        # assuming unified data format, dataloader returns a dict.
        # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc)
        return batch[self.input_key]

    def get_autoencoder_params(self) -> list:
        params = (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.regularization.get_trainable_parameters())
            + list(self.loss.get_trainable_autoencoder_parameters())
        )
        return params

    def get_discriminator_params(self) -> list:
        params = list(self.loss.get_trainable_parameters())  # e.g., discriminator
        return params

    def get_last_layer(self):
        return self.decoder.get_last_layer()

    def encode(self, x: Any, return_reg_log: bool = False) -> Any:
        z = self.encoder(x)
        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: Any) -> torch.Tensor:
        x = self.decoder(z)
        return x

    def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        z, reg_log = self.encode(x, return_reg_log=True)
        dec = self.decode(z)
        return z, dec, reg_log

    def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
        x = self.get_input(batch)
        z, xrec, regularization_log = self(x)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(
                regularization_log,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )

            self.log_dict(
                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
            )
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(
                regularization_log,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
            )
            return discloss

    def validation_step(self, batch, batch_idx) -> Dict:
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
            log_dict.update(log_dict_ema)
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
        x = self.get_input(batch)

        z, xrec, regularization_log = self(x)
        aeloss, log_dict_ae = self.loss(
            regularization_log,
            x,
            xrec,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )

        discloss, log_dict_disc = self.loss(
            regularization_log,
            x,
            xrec,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )
        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        log_dict_ae.update(log_dict_disc)
        self.log_dict(log_dict_ae)
        return log_dict_ae

    def configure_optimizers(self) -> Any:
        ae_params = self.get_autoencoder_params()
        disc_params = self.get_discriminator_params()

        opt_ae = self.instantiate_optimizer_from_config(
            ae_params,
            default(self.lr_g_factor, 1.0) * self.learning_rate,
            self.optimizer_config,
        )
        opt_disc = self.instantiate_optimizer_from_config(
            disc_params, self.learning_rate, self.optimizer_config
        )

        return [opt_ae, opt_disc], []

    @torch.no_grad()
    def log_images(self, batch: Dict, **kwargs) -> Dict:
        log = dict()
        x = self.get_input(batch)
        _, xrec, _ = self(x)
        log["inputs"] = x
        log["reconstructions"] = xrec
        with self.ema_scope():
            _, xrec_ema, _ = self(x)
            log["reconstructions_ema"] = xrec_ema
        return log


class AutoencoderKL(AutoencodingEngine):
    def __init__(self, embed_dim: int, **kwargs):
        ddconfig = kwargs.pop("ddconfig")
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", ())
        super().__init__(
            encoder_config={"target": "torch.nn.Identity"},
            decoder_config={"target": "torch.nn.Identity"},
            regularizer_config={"target": "torch.nn.Identity"},
            loss_config=kwargs.pop("lossconfig"),
            **kwargs,
        )
        assert ddconfig["double_z"]
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def encode(self, x):
        assert (
            not self.training
        ), f"{self.__class__.__name__} only supports inference currently"
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, **decoder_kwargs):
        z = self.post_quant_conv(z)
        dec = self.decoder(z, **decoder_kwargs)
        return dec


class AutoencoderKLInferenceWrapper(AutoencoderKL):
    def encode(self, x):
        return super().encode(x).sample()


class IdentityFirstStage(AbstractAutoencoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_input(self, x: Any) -> Any:
        return x

    def encode(self, x: Any, *args, **kwargs) -> Any:
        return x

    def decode(self, x: Any, *args, **kwargs) -> Any:
        return x
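DiffusionEngine in the next file keeps an AutoencoderKL like the one above as a frozen first stage: encode_first_stage multiplies the latent by scale_factor and decode_first_stage divides it back out before decoding. A hedged sketch of that round trip (the `ae` object and the scale_factor value 0.13025 are assumptions; the latter is the value SDXL-style configs commonly use):

import torch

@torch.no_grad()
def roundtrip(ae, img, scale_factor=0.13025):
    # `ae` is assumed to be an AutoencoderKL restored from a checkpoint;
    # encode() asserts the module is not in training mode.
    ae.eval()
    posterior = ae.encode(img)                  # DiagonalGaussianDistribution over latents
    z = scale_factor * posterior.sample()       # scaled latent, as fed to the UNet
    rec = ae.decode(z / scale_factor)           # back to image space, roughly in [-1, 1]
    return z, rec

# usage sketch: z, rec = roundtrip(ae, torch.randn(1, 3, 1024, 1024))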
sgm/models/diffusion.py
ADDED
@@ -0,0 +1,556 @@
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple, Union, DefaultDict

import pytorch_lightning as pl
import torch
from omegaconf import ListConfig, OmegaConf
from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange
import math
import torch.nn as nn
from ..modules import UNCONDITIONAL_CONFIG
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma
from ..util import (
    default,
    disabled_train,
    get_obj_from_str,
    instantiate_from_config,
    log_txt_as_img,
)


import collections
from functools import partial


def save_activations(
    activations: DefaultDict,
    name: str,
    module: nn.Module,
    inp: Tuple,
    out: torch.Tensor
) -> None:
    """PyTorch Forward hook to save outputs at each forward
    pass. Mutates specified dict objects with each fwd pass.
    """
    if isinstance(out, tuple):
        if out[1] is None:
            activations[name].append(out[0].detach())

class DiffusionEngine(pl.LightningModule):
    def __init__(
        self,
        network_config,
        denoiser_config,
        first_stage_config,
        conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        network_wrapper: Union[None, str] = None,
        ckpt_path: Union[None, str] = None,
        use_ema: bool = False,
        ema_decay_rate: float = 0.9999,
        scale_factor: float = 1.0,
        disable_first_stage_autocast=False,
        input_key: str = "jpg",
        log_keys: Union[List, None] = None,
        no_cond_log: bool = False,
        compile_model: bool = False,
        trainkeys='pose',
        multiplier=0.05,
        loss_rgb_lambda=20.,
        loss_fg_lambda=10.,
        loss_bg_lambda=20.,
    ):
        super().__init__()
        self.log_keys = log_keys
        self.input_key = input_key
        self.trainkeys = trainkeys
        self.multiplier = multiplier
        self.loss_rgb_lambda = loss_rgb_lambda
        self.loss_fg_lambda = loss_fg_lambda
        self.loss_bg_lambda = loss_bg_lambda
        self.rgb = network_config.params.rgb
        self.rgb_predict = network_config.params.rgb_predict
        self.add_token = ('modifier_token' in conditioner_config.params.emb_models[1].params)
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.AdamW"}
        )
        model = instantiate_from_config(network_config)
        self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
            model, compile_model=compile_model
        )

        self.denoiser = instantiate_from_config(denoiser_config)
        self.sampler = (
            instantiate_from_config(sampler_config)
            if sampler_config is not None
            else None
        )
        self.conditioner = instantiate_from_config(
            default(conditioner_config, UNCONDITIONAL_CONFIG)
        )
        self.scheduler_config = scheduler_config
        self._init_first_stage(first_stage_config)

        self.loss_fn = (
            instantiate_from_config(loss_fn_config)
            if loss_fn_config is not None
            else None
        )

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model, decay=ema_decay_rate)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.scale_factor = scale_factor
        self.disable_first_stage_autocast = disable_first_stage_autocast
        self.no_cond_log = no_cond_log

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

        blocks = []
        if self.trainkeys == 'poseattn':
            for x in self.model.diffusion_model.named_parameters():
                if not ('pose' in x[0] or 'transformer_blocks' in x[0]):
                    x[1].requires_grad = False
                else:
                    if 'pose' in x[0]:
                        x[1].requires_grad = True
                        blocks.append(x[0].split('.pose')[0])

            blocks = set(blocks)
            for x in self.model.diffusion_model.named_parameters():
                if 'transformer_blocks' in x[0]:
                    reqgrad = False
                    for each in blocks:
                        if each in x[0] and ('attn1' in x[0] or 'attn2' in x[0] or 'pose' in x[0]):
                            reqgrad = True
                            x[1].requires_grad = True
                    if not reqgrad:
                        x[1].requires_grad = False
        elif self.trainkeys == 'pose':
            for x in self.model.diffusion_model.named_parameters():
                if not ('pose' in x[0]):
                    x[1].requires_grad = False
                else:
                    x[1].requires_grad = True
        elif self.trainkeys == 'all':
            for x in self.model.diffusion_model.named_parameters():
                x[1].requires_grad = True

        self.model = self.model.to(memory_format=torch.channels_last)

    def register_activation_hooks(
        self,
    ) -> None:
        self.activations_dict = collections.defaultdict(list)
        handles = []
        for name, module in self.model.diffusion_model.named_modules():
            if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks':
                if hasattr(module, 'pose_emb_layers'):
                    handle = module.register_forward_hook(
                        partial(save_activations, self.activations_dict, name)
                    )
                    handles.append(handle)
        self.handles = handles

    def clear_rendered_feat(self,):
        for name, module in self.model.diffusion_model.named_modules():
            if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks':
                if hasattr(module, 'pose_emb_layers'):
                    module.rendered_feat = None

    def remove_activation_hooks(
        self, handles
    ) -> None:
        for handle in handles:
            handle.remove()

    def init_from_ckpt(
        self,
        path: str,
    ) -> None:
        if path.endswith("ckpt"):
            sd = torch.load(path, map_location="cpu")["state_dict"]
        elif path.endswith("safetensors"):
            sd = load_safetensors(path)
        else:
            raise NotImplementedError

        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def _init_first_stage(self, config):
        model = instantiate_from_config(config).eval()
        model.train = disabled_train
        for param in model.parameters():
            param.requires_grad = False
        self.first_stage_model = model

    def get_input(self, batch):
        return batch[self.input_key], batch[self.input_key + '_ref'] if self.input_key + '_ref' in batch else None, batch['pose'] if 'pose' in batch else None, batch['mask'] if "mask" in batch else None, batch['mask_ref'] if "mask_ref" in batch else None, batch['depth'] if "depth" in batch else None, batch['drop_im'] if "drop_im" in batch else 0.

    @torch.no_grad()
    def decode_first_stage(self, z):
        z = 1.0 / self.scale_factor * z
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            out = self.first_stage_model.decode(z)
        return out

    @torch.no_grad()
    def encode_first_stage(self, x):
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            z = self.first_stage_model.encode(x)
        z = self.scale_factor * z
        return z

    def forward(self, x, x_rgb, xr, pose, mask, mask_ref, opacity, drop_im, batch):
        loss, loss_fg, loss_bg, loss_rgb = self.loss_fn(self.model, self.denoiser, self.conditioner, x, x_rgb, xr, pose, mask, mask_ref, opacity, batch)
        loss_mean = loss.mean()
        loss_dict = {"loss": loss_mean.item()}
        if self.rgb and self.global_step > 0:
            loss_fg = (loss_fg.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12)
            loss_bg = (loss_bg.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12)
            loss_mean += self.loss_fg_lambda*loss_fg
            loss_mean += self.loss_bg_lambda*loss_bg
            loss_dict["loss_fg"] = loss_fg.item()
            loss_dict["loss_bg"] = loss_bg.item()
            if self.rgb_predict and loss_rgb.mean() > 0:
                loss_rgb = (loss_rgb.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12)
                loss_mean += self.loss_rgb_lambda*loss_rgb
                loss_dict["loss_rgb"] = loss_rgb.item()
        return loss_mean, loss_dict

    def shared_step(self, batch: Dict) -> Any:
        x, xr, pose, mask, mask_ref, opacity, drop_im = self.get_input(batch)
        x_rgb = x.clone().detach()
        x = self.encode_first_stage(x)
        x = x.to(memory_format=torch.channels_last)
        if xr is not None:
            b, n = xr.shape[0], xr.shape[1]
            xr = rearrange(self.encode_first_stage(rearrange(xr, "b n ... -> (b n) ...")), "(b n) ... -> b n ...", b=b, n=n)
            xr = drop_im.reshape(b, 1, 1, 1, 1)*xr + (1-drop_im.reshape(b, 1, 1, 1, 1))*torch.zeros_like(xr)
        batch["global_step"] = self.global_step
        loss, loss_dict = self(x, x_rgb, xr, pose, mask, mask_ref, opacity, drop_im, batch)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)

        self.log_dict(
            loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
        )

        self.log(
            "global_step",
            self.global_step,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

        if self.scheduler_config is not None:
            lr = self.optimizers().param_groups[0]["lr"]
            self.log(
                "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
            )
        return loss

    def validation_step(self, batch, batch_idx):
        # print("validation data", len(self.trainer.val_dataloaders))
        loss, loss_dict = self.shared_step(batch)
        return loss

    def on_train_start(self, *args, **kwargs):
        if self.sampler is None or self.loss_fn is None:
            raise ValueError("Sampler and loss function need to be set for training.")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
        optimizer.zero_grad(set_to_none=True)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self):
        lr = self.learning_rate
        params = []
        blocks = []
        lowlrparams = []
        if self.trainkeys == 'poseattn':
            lowlrparams = []
            for x in self.model.diffusion_model.named_parameters():
                if ('pose' in x[0]):
                    params += [x[1]]
                    blocks.append(x[0].split('.pose')[0])
                    print(x[0])
            blocks = set(blocks)
            for x in self.model.diffusion_model.named_parameters():
                if 'transformer_blocks' in x[0]:
                    for each in blocks:
                        if each in x[0] and not ('pose' in x[0]) and ('attn1' in x[0] or 'attn2' in x[0]):
                            lowlrparams += [x[1]]
        elif self.trainkeys == 'pose':
            for x in self.model.diffusion_model.named_parameters():
                if ('pose' in x[0]):
                    params += [x[1]]
                    print(x[0])
        elif self.trainkeys == 'all':
            lowlrparams = []
            for x in self.model.diffusion_model.named_parameters():
                if ('pose' in x[0]):
                    params += [x[1]]
                    print(x[0])
                else:
                    lowlrparams += [x[1]]

        for i, embedder in enumerate(self.conditioner.embedders[:2]):
            if embedder.is_trainable:
                params = params + list(embedder.parameters())
            if self.add_token:
                if i == 0:
                    for name, param in embedder.transformer.get_input_embeddings().named_parameters():
                        param.requires_grad = True
                        print(name, "conditional model param")
                        params += [param]
                else:
                    for name, param in embedder.model.token_embedding.named_parameters():
                        param.requires_grad = True
                        print(name, "conditional model param")
                        params += [param]

        if len(lowlrparams) > 0:
            print("different optimizer groups")
            opt = self.instantiate_optimizer_from_config([{'params': params}, {'params': lowlrparams, 'lr': self.multiplier*lr}], lr, self.optimizer_config)
        else:
            opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)
            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1,
                }
            ]
            return [opt], scheduler
        return opt

    @torch.no_grad()
    def sample(
        self,
        cond: Dict,
        uc: Union[Dict, None] = None,
        batch_size: int = 16,
        num_steps=None,
        randn=None,
        shape: Union[None, Tuple, List] = None,
        return_rgb=False,
        mask=None,
        init_im=None,
        **kwargs,
    ):
        if randn is None:
            randn = torch.randn(batch_size, *shape)

        denoiser = lambda input, sigma, c: self.denoiser(
            self.model, input, sigma, c, **kwargs
        )
        if mask is not None:
            samples, rgb_list = self.sampler(denoiser, randn.to(self.device), cond, uc=uc, mask=mask, init_im=init_im, num_steps=num_steps)
        else:
            samples, rgb_list = self.sampler(denoiser, randn.to(self.device), cond, uc=uc, num_steps=num_steps)
        if return_rgb:
            return samples, rgb_list
        return samples

    @torch.no_grad()
    def samplemulti(
        self,
        cond,
        uc=None,
        batch_size: int = 16,
        num_steps=None,
        randn=None,
        shape: Union[None, Tuple, List] = None,
        return_rgb=False,
        mask=None,
        init_im=None,
        multikwargs=None,
    ):
        if randn is None:
            randn = torch.randn(batch_size, *shape)

        samples, rgb_list = self.sampler(self.denoiser, self.model, randn.to(self.device), cond, uc=uc, num_steps=num_steps, multikwargs=multikwargs)
        if return_rgb:
            return samples, rgb_list
        return samples

    @torch.no_grad()
    def log_conditionings(self, batch: Dict, n: int, refernce: bool = True) -> Dict:
        """
        Defines heuristics to log different conditionings.
        These can be lists of strings (text-to-image), tensors, ints, ...
        """
        image_h, image_w = batch[self.input_key].shape[2:]
        log = dict()

        for embedder in self.conditioner.embedders:
            if refernce:
                check = (embedder.input_keys[0] in self.log_keys)
            else:
                check = (embedder.input_key in self.log_keys)
            if (
                (self.log_keys is None) or check
            ) and not self.no_cond_log:
                if refernce:
                    x = batch[embedder.input_keys[0]][:n]
                else:
                    x = batch[embedder.input_key][:n]
                if isinstance(x, torch.Tensor):
                    if x.dim() == 1:
                        # class-conditional, convert integer to string
                        x = [str(x[i].item()) for i in range(x.shape[0])]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
                    elif x.dim() == 2:
                        # size and crop cond and the like
                        x = [
                            "x".join([str(xx) for xx in x[i].tolist()])
                            for i in range(x.shape[0])
                        ]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    else:
                        raise NotImplementedError()
                elif isinstance(x, (List, ListConfig)):
                    if isinstance(x[0], str):
                        # strings
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    else:
                        raise NotImplementedError()
                else:
                    raise NotImplementedError()
                if refernce:
                    log[embedder.input_keys[0]] = xc
                else:
                    log[embedder.input_key] = xc
        return log

    @torch.no_grad()
    def log_images(
        self,
        batch: Dict,
        N: int = 8,
        sample: bool = True,
        ucg_keys: List[str] = None,
        **kwargs,
481 |
+
) -> Dict:
|
482 |
+
log = dict()
|
483 |
+
|
484 |
+
x, xr, pose, mask, mask_ref, depth, drop_im = self.get_input(batch)
|
485 |
+
|
486 |
+
if xr is not None:
|
487 |
+
conditioner_input_keys = [e.input_keys for e in self.conditioner.embedders]
|
488 |
+
else:
|
489 |
+
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
|
490 |
+
|
491 |
+
if ucg_keys:
|
492 |
+
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
|
493 |
+
"Each defined ucg key for sampling must be in the provided conditioner input keys,"
|
494 |
+
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
|
495 |
+
)
|
496 |
+
else:
|
497 |
+
ucg_keys = conditioner_input_keys
|
498 |
+
|
499 |
+
c, uc = self.conditioner.get_unconditional_conditioning(
|
500 |
+
batch,
|
501 |
+
force_uc_zero_embeddings=ucg_keys
|
502 |
+
if len(self.conditioner.embedders) > 0
|
503 |
+
else [],
|
504 |
+
)
|
505 |
+
|
506 |
+
N = min(x.shape[0], N)
|
507 |
+
x = x.to(self.device)[:N]
|
508 |
+
zr = None
|
509 |
+
if xr is not None:
|
510 |
+
xr = xr.to(self.device)[:N]
|
511 |
+
b, n = xr.shape[0], xr.shape[1]
|
512 |
+
log["reference"] = rearrange(xr, "b n ... -> (b n) ...", b=b, n=n)
|
513 |
+
zr = rearrange(self.encode_first_stage(rearrange(xr, "b n ... -> (b n) ...", b=b, n=n)), "(b n) ... -> b n ...", b=b, n=n)
|
514 |
+
|
515 |
+
log["inputs"] = x
|
516 |
+
b = x.shape[0]
|
517 |
+
if mask is not None:
|
518 |
+
log["mask"] = mask
|
519 |
+
if depth is not None:
|
520 |
+
log["depth"] = depth
|
521 |
+
z = self.encode_first_stage(x)
|
522 |
+
|
523 |
+
if uc is not None:
|
524 |
+
if xr is not None:
|
525 |
+
zr = torch.cat([torch.zeros_like(zr), zr])
|
526 |
+
drop_im = torch.cat([drop_im, drop_im])
|
527 |
+
if isinstance(pose, list):
|
528 |
+
pose = pose[:N]*2
|
529 |
+
else:
|
530 |
+
pose = torch.cat([pose[:N]] * 2)
|
531 |
+
|
532 |
+
sampling_kwargs = {'input_ref':zr}
|
533 |
+
sampling_kwargs['pose'] = pose
|
534 |
+
sampling_kwargs['mask_ref'] = None
|
535 |
+
sampling_kwargs['drop_im'] = drop_im
|
536 |
+
|
537 |
+
log["reconstructions"] = self.decode_first_stage(z)
|
538 |
+
log.update(self.log_conditionings(batch, N, refernce=True if xr is not None else False))
|
539 |
+
|
540 |
+
for k in c:
|
541 |
+
if isinstance(c[k], torch.Tensor):
|
542 |
+
if xr is not None:
|
543 |
+
c[k], uc[k] = map(lambda y: y[k][:(n+1)*N].to(self.device), (c, uc))
|
544 |
+
else:
|
545 |
+
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
|
546 |
+
if sample:
|
547 |
+
with self.ema_scope("Plotting"):
|
548 |
+
samples, rgb_list = self.sample(
|
549 |
+
c, shape=z.shape[1:], uc=uc, batch_size=N, return_rgb=True, **sampling_kwargs
|
550 |
+
)
|
551 |
+
samples = self.decode_first_stage(samples)
|
552 |
+
log["samples"] = samples
|
553 |
+
if len(rgb_list) > 0:
|
554 |
+
size = int(math.sqrt(rgb_list[0].size(1)))
|
555 |
+
log["predicted_rgb"] = rgb_list[0].reshape(-1, size, size, 3).permute(0, 3, 1, 2)
|
556 |
+
return log
|
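Note on the snippet above: `configure_optimizers` keeps two groups of weights, the newly introduced 'pose' modules trained at the full learning rate and the reused attn1/attn2 projections trained at `self.multiplier` times that rate. Below is a minimal, self-contained sketch of that two-group pattern with a stock PyTorch optimizer; the parameter lists and the concrete numbers are illustrative assumptions, not values taken from this repository.

# Illustrative sketch (not part of the repository): the two-parameter-group
# pattern used by configure_optimizers(), mapped onto torch.optim.AdamW.
import torch

pose_params = [torch.nn.Parameter(torch.zeros(4, 4))]   # stands in for the 'pose' modules
attn_params = [torch.nn.Parameter(torch.zeros(4, 4))]   # stands in for reused attn1/attn2 weights
base_lr, multiplier = 1e-5, 0.01                         # assumed values for the example

opt = torch.optim.AdamW(
    [
        {"params": pose_params},                              # trained at the full learning rate
        {"params": attn_params, "lr": multiplier * base_lr},  # fine-tuned gently
    ],
    lr=base_lr,
)
# Groups without an explicit "lr" inherit base_lr, which is why
# instantiate_optimizer_from_config can pass lr=lr alongside the group list.
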
sgm/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .encoders.modules import GeneralConditioner

UNCONDITIONAL_CONFIG = {
    "target": "sgm.modules.GeneralConditioner",
    "params": {"emb_models": []},
}

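`UNCONDITIONAL_CONFIG` follows the repository-wide convention of describing an object as a `target` dotted import path plus a `params` dict. For illustration, here is a minimal re-implementation of how such a config is typically resolved; the helper name `build_from_config` is made up for this sketch, while the repository itself uses its `get_obj_from_str` / `instantiate_from_config` utilities for the same job.

# Sketch only: resolve a {"target": ..., "params": ...} config into an object.
import importlib

def build_from_config(config: dict):
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}))

# e.g. conditioner = build_from_config(UNCONDITIONAL_CONFIG)
#      -> GeneralConditioner(emb_models=[])
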
sgm/modules/attention.py
ADDED
@@ -0,0 +1,1202 @@
import logging
import math
import itertools
from inspect import isfunction
from typing import Any, Optional

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from packaging import version
from torch import nn
from .diffusionmodules.util import checkpoint
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

from ..modules.diffusionmodules.util import zero_module
from ..modules.nerfsd_pytorch3d import NerfSDModule, VolRender

logpy = logging.getLogger(__name__)

if version.parse(torch.__version__) >= version.parse("2.0.0"):
    SDP_IS_AVAILABLE = True
    from torch.backends.cuda import SDPBackend, sdp_kernel

    BACKEND_MAP = {
        SDPBackend.MATH: {
            "enable_math": True,
            "enable_flash": False,
            "enable_mem_efficient": False,
        },
        SDPBackend.FLASH_ATTENTION: {
            "enable_math": False,
            "enable_flash": True,
            "enable_mem_efficient": False,
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False,
            "enable_flash": False,
            "enable_mem_efficient": True,
        },
        None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
    }
else:
    from contextlib import nullcontext

    SDP_IS_AVAILABLE = False
    sdp_kernel = nullcontext
    BACKEND_MAP = {}
    logpy.warn(
        f"No SDP backend available, likely because you are running in pytorch "
        f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
        f"You might want to consider upgrading."
    )

try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILABLE = True
except:
    XFORMERS_IS_AVAILABLE = False
    logpy.warn("no module 'xformers'. Processing without...")


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )

        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(
            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)


class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)

        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        return x + h_


class _TruncExp(Function):  # pylint: disable=abstract-method
    # Implementation from torch-ngp:
    # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):  # pylint: disable=arguments-differ
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):  # pylint: disable=arguments-differ
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(-15, 15))


trunc_exp = _TruncExp.apply
"""Same as torch.exp, but with the backward pass clipped to prevent vanishing/exploding
gradients."""


class CrossAttention(nn.Module):
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        backend=None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        self.backend = backend

    def forward(
        self,
        x,
        context=None,
        mask=None,
        additional_tokens=None,
        n_times_crossframe_attn_in_self=0,
    ):
        h = self.heads

        if additional_tokens is not None:
            # get the number of masked tokens at the beginning of the output sequence
            n_tokens_to_mask = additional_tokens.shape[1]
            # add additional token
            x = torch.cat([additional_tokens, x], dim=1)

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        if n_times_crossframe_attn_in_self:
            # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
            assert x.shape[0] % n_times_crossframe_attn_in_self == 0
            n_cp = x.shape[0] // n_times_crossframe_attn_in_self
            k = repeat(
                k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
            )
            v = repeat(
                v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
            )

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        ## old
        """
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        del q, k

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', sim, v)
        """
        ## new
        with sdp_kernel(**BACKEND_MAP[self.backend]):
            # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask
            )  # scale is dim_head ** -0.5 per default

        del q, k, v
        out = rearrange(out, "b h n d -> b n (h d)", h=h)

        if additional_tokens is not None:
            # remove additional token
            out = out[:, n_tokens_to_mask:]
        return self.to_out(out)


class MemoryEfficientCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(
        self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, add_lora=False, **kwargs
    ):
        super().__init__()
        logpy.debug(
            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
            f"context_dim is {context_dim} and using {heads} heads with a "
            f"dimension of {dim_head}."
        )
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head
        self.add_lora = add_lora

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        if add_lora:
            r = 32
            self.to_q_attn3_down = nn.Linear(query_dim, r, bias=False)
            self.to_q_attn3_up = zero_module(nn.Linear(r, inner_dim, bias=False))
            self.to_k_attn3_down = nn.Linear(context_dim, r, bias=False)
            self.to_k_attn3_up = zero_module(nn.Linear(r, inner_dim, bias=False))
            self.to_v_attn3_down = nn.Linear(context_dim, r, bias=False)
            self.to_v_attn3_up = zero_module(nn.Linear(r, inner_dim, bias=False))
            self.to_o_attn3_down = nn.Linear(inner_dim, r, bias=False)
            self.to_o_attn3_up = zero_module(nn.Linear(r, query_dim, bias=False))
            self.dropoutq = nn.Dropout(0.1)
            self.dropoutk = nn.Dropout(0.1)
            self.dropoutv = nn.Dropout(0.1)
            self.dropouto = nn.Dropout(0.1)

            nn.init.normal_(self.to_q_attn3_down.weight, std=1 / r)
            nn.init.normal_(self.to_k_attn3_down.weight, std=1 / r)
            nn.init.normal_(self.to_v_attn3_down.weight, std=1 / r)
            nn.init.normal_(self.to_o_attn3_down.weight, std=1 / r)

        self.attention_op: Optional[Any] = None

    def forward(
        self,
        x,
        context=None,
        mask=None,
        additional_tokens=None,
        n_times_crossframe_attn_in_self=0,
    ):
        if additional_tokens is not None:
            # get the number of masked tokens at the beginning of the output sequence
            n_tokens_to_mask = additional_tokens.shape[1]
            # add additional token
            x = torch.cat([additional_tokens, x], dim=1)

        context_k = context  # b, n, c, h, w

        q = self.to_q(x)
        context = default(context, x)
        context_k = default(context_k, x)
        k = self.to_k(context_k)
        v = self.to_v(context_k)
        if self.add_lora:
            q += self.dropoutq(self.to_q_attn3_up(self.to_q_attn3_down(x)))
            k += self.dropoutk(self.to_k_attn3_up(self.to_k_attn3_down(context_k)))
            v += self.dropoutv(self.to_v_attn3_up(self.to_v_attn3_down(context_k)))

        if n_times_crossframe_attn_in_self:
            # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
            assert x.shape[0] % n_times_crossframe_attn_in_self == 0
            # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
            k = repeat(
                k[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )
            v = repeat(
                v[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )

        b, _, _ = q.shape
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        attn_bias = None

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=attn_bias, op=self.attention_op
        )

        # TODO: Use this directly in the attention operation, as a bias
        if exists(mask):
            raise NotImplementedError
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        if additional_tokens is not None:
            # remove additional token
            out = out[:, n_tokens_to_mask:]
        final = self.to_out(out)
        if self.add_lora:
            final += self.dropouto(self.to_o_attn3_up(self.to_o_attn3_down(out)))
        return final


class BasicTransformerBlock(nn.Module):
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention,  # ampere
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
        attn_mode="softmax",
        sdp_backend=None,
        image_cross=False,
        far=2,
        num_samples=32,
        add_lora=False,
        rgb_predict=False,
        mode='pixel-nerf',
        average=False,
        num_freqs=16,
        use_prev_weights_imp_sample=False,
        imp_sample_next_step=False,
        stratified=False,
        imp_sampling_percent=0.9,
        near_plane=0.
    ):

        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        self.add_lora = add_lora
        self.image_cross = image_cross
        self.rgb_predict = rgb_predict
        self.use_prev_weights_imp_sample = use_prev_weights_imp_sample
        self.imp_sample_next_step = imp_sample_next_step
        self.rendered_feat = None
        if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
            logpy.warn(
                f"Attention mode '{attn_mode}' is not available. Falling "
                f"back to native attention. This is not a problem in "
                f"Pytorch >= 2.0. FYI, you are running with PyTorch "
                f"version {torch.__version__}."
            )
            attn_mode = "softmax"
        elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
            logpy.warn(
                "We do not support vanilla attention anymore, as it is too "
                "expensive. Sorry."
            )
            if not XFORMERS_IS_AVAILABLE:
                assert (
                    False
                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
            else:
                logpy.info("Falling back to xformers efficient attention.")
                attn_mode = "softmax-xformers"
        attn_cls = self.ATTENTION_MODES[attn_mode]
        if version.parse(torch.__version__) >= version.parse("2.0.0"):
            assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
        else:
            assert sdp_backend is None
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            add_lora=self.add_lora,
            context_dim=context_dim if self.disable_self_attn else None,
            backend=sdp_backend,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            add_lora=self.add_lora,
            backend=sdp_backend,
        )  # is self-attn if context is none
        if image_cross:
            self.pose_emb_layers = nn.Linear(2*dim, dim, bias=False)
            nn.init.eye_(self.pose_emb_layers.weight)
            self.pose_featurenerf = NerfSDModule(
                mode=mode,
                out_channels=dim,
                far_plane=far,
                num_samples=num_samples,
                rgb_predict=rgb_predict,
                average=average,
                num_freqs=num_freqs,
                stratified=stratified,
                imp_sampling_percent=imp_sampling_percent,
                near_plane=near_plane,
            )

            self.renderer = VolRender()

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint
        if self.checkpoint:
            logpy.debug(f"{self.__class__.__name__} is using checkpointing")

    def forward(
        self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
    ):
        kwargs = {"x": x}

        if context is not None:
            kwargs.update({"context": context})

        if context_ref is not None:
            kwargs.update({"context_ref": context_ref})

        if pose is not None:
            kwargs.update({"pose": pose})

        if mask_ref is not None:
            kwargs.update({"mask_ref": mask_ref})

        if prev_weights is not None:
            kwargs.update({"prev_weights": prev_weights})

        if additional_tokens is not None:
            kwargs.update({"additional_tokens": additional_tokens})

        if n_times_crossframe_attn_in_self:
            kwargs.update(
                {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
            )

        # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
        return checkpoint(
            self._forward, (x, context, context_ref, pose, mask_ref, prev_weights), self.parameters(), self.checkpoint
        )

    def reference_attn(self, x, context_ref, context, pose, prev_weights, mask_ref):
        feats, sigmas, dists, _, predicted_rgb, sigmas_uniform, dists_uniform = self.pose_featurenerf(
            pose,
            context_ref,
            mask_ref,
            prev_weights=prev_weights if self.use_prev_weights_imp_sample else None,
            imp_sample_next_step=self.imp_sample_next_step)

        b, hw, d = feats.size()[:3]
        feats = rearrange(feats, "b hw d ... -> b (hw d) ...")

        feats = (
            self.attn2(
                self.norm2(feats), context=context,
            )
            + feats
        )

        feats = rearrange(feats, "b (hw d) ... -> b hw d ...", hw=hw, d=d)

        sigmas_ = trunc_exp(sigmas)
        if sigmas_uniform is not None:
            sigmas_uniform = trunc_exp(sigmas_uniform)

        context_ref, fg_mask, alphas, weights_uniform, predicted_rgb = self.renderer(feats, sigmas_, dists, densities_uniform=sigmas_uniform, dists_uniform=dists_uniform, return_weights_uniform=True, rgb=F.sigmoid(predicted_rgb) if predicted_rgb is not None else None)
        if self.use_prev_weights_imp_sample:
            prev_weights = weights_uniform

        return context_ref, fg_mask, prev_weights, alphas, predicted_rgb

    def _forward(
        self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
    ):
        fg_mask = None
        weights = None
        alphas = None
        predicted_rgb = None
        xref = None

        x = (
            self.attn1(
                self.norm1(x),
                context=context if self.disable_self_attn else None,
                additional_tokens=additional_tokens,
                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
                if not self.disable_self_attn
                else 0,
            )
            + x
        )
        x = (
            self.attn2(
                self.norm2(x), context=context, additional_tokens=additional_tokens
            )
            + x
        )
        with torch.amp.autocast(device_type='cuda', dtype=torch.float32):
            if context_ref is not None:
                xref, fg_mask, weights, alphas, predicted_rgb = self.reference_attn(
                    x,
                    rearrange(context_ref, "(b n) ... -> b n ...", b=x.size(0), n=context_ref.size(0) // x.size(0)),
                    context,
                    pose,
                    prev_weights,
                    mask_ref)
                x = self.pose_emb_layers(torch.cat([x, xref], -1))

        x = self.ff(self.norm3(x)) + x
        return x, fg_mask, weights, alphas, predicted_rgb


class BasicTransformerSingleLayerBlock(nn.Module):
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention  # on the A100s not quite as fast as the above version
        # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        attn_mode="softmax",
    ):
        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(
            self._forward, (x, context), self.parameters(), self.checkpoint
        )

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x), context=context) + x
        x = self.ff(self.norm2(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,
        attn_type="softmax",
        use_checkpoint=True,
        # sdp_backend=SDPBackend.FLASH_ATTENTION
        sdp_backend=None,
        image_cross=True,
        rgb_predict=False,
        far=2,
        num_samples=32,
        add_lora=False,
        mode='feature-nerf',
        average=False,
        num_freqs=16,
        use_prev_weights_imp_sample=False,
        stratified=False,
        poscontrol_interval=4,
        imp_sampling_percent=0.9,
        near_plane=0.
    ):
        super().__init__()
        logpy.debug(
            f"constructing {self.__class__.__name__} of depth {depth} w/ "
            f"{in_channels} channels and {n_heads} heads."
        )
        from omegaconf import ListConfig

        if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
            context_dim = [context_dim]
        if exists(context_dim) and isinstance(context_dim, list):
            if depth != len(context_dim):
                logpy.warn(
                    f"{self.__class__.__name__}: Found context dims "
                    f"{context_dim} of depth {len(context_dim)}, which does not "
                    f"match the specified 'depth' of {depth}. Setting context_dim "
                    f"to {depth * [context_dim[0]]} now."
                )
                # depth does not match context dims.
                assert all(
                    map(lambda x: x == context_dim[0], context_dim)
                ), "need homogenous context_dim to match depth automatically"
                context_dim = depth * [context_dim[0]]
        elif context_dim is None:
            context_dim = [None] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.image_cross = image_cross
        self.poscontrol_interval = poscontrol_interval

        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    attn_mode=attn_type,
                    checkpoint=use_checkpoint,
                    sdp_backend=sdp_backend,
                    image_cross=self.image_cross and (d % poscontrol_interval == 0),
                    far=far,
                    num_samples=num_samples,
                    add_lora=add_lora and self.image_cross and (d % poscontrol_interval == 0),
                    rgb_predict=rgb_predict,
                    mode=mode,
                    average=average,
                    num_freqs=num_freqs,
                    use_prev_weights_imp_sample=use_prev_weights_imp_sample,
                    imp_sample_next_step=(use_prev_weights_imp_sample and self.image_cross and (d % poscontrol_interval == 0) and depth >= poscontrol_interval and d < (depth // poscontrol_interval) * poscontrol_interval ),
                    stratified=stratified,
                    imp_sampling_percent=imp_sampling_percent,
                    near_plane=near_plane,
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, xr, context=None, contextr=None, pose=None, mask_ref=None, prev_weights=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if xr is None:
            if not isinstance(context, list):
                context = [context]
            b, c, h, w = x.shape
            x_in = x
            x = self.norm(x)
            if not self.use_linear:
                x = self.proj_in(x)
            x = rearrange(x, "b c h w -> b (h w) c").contiguous()
            if self.use_linear:
                x = self.proj_in(x)
            for i, block in enumerate(self.transformer_blocks):
                if i > 0 and len(context) == 1:
                    i = 0  # use same context for each block
                x, _, _, _, _ = block(x, context=context[i])
            if self.use_linear:
                x = self.proj_out(x)
            x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
            if not self.use_linear:
                x = self.proj_out(x)
            return x + x_in, None, None, None, None, None
        else:
            if not isinstance(context, list):
                context = [context]
                contextr = [contextr]
            b, c, h, w = x.shape
            b1, _, _, _ = xr.shape
            x_in = x
            xr_in = xr
            fg_masks = []
            alphas = []
            rgbs = []

            x = self.norm(x)
            with torch.no_grad():
                xr = self.norm(xr)

            if not self.use_linear:
                x = self.proj_in(x)
                with torch.no_grad():
                    xr = self.proj_in(xr)

            x = rearrange(x, "b c h w -> b (h w) c").contiguous()
            xr = rearrange(xr, "b1 c h w -> b1 (h w) c").contiguous()
            if self.use_linear:
                x = self.proj_in(x)
                with torch.no_grad():
                    xr = self.proj_in(xr)

            prev_weights = None
            counter = 0
            for i, block in enumerate(self.transformer_blocks):
                if i > 0 and len(context) == 1:
                    i = 0  # use same context for each block
                if self.image_cross and (counter % self.poscontrol_interval == 0):
                    with torch.no_grad():
                        xr, _, _, _, _ = block(xr, context=contextr[i])
                    x, fg_mask, weights, alpha, rgb = block(x, context=context[i], context_ref=xr.detach(), pose=pose, mask_ref=mask_ref, prev_weights=prev_weights)
                    prev_weights = weights
                    fg_masks.append(fg_mask)
                    if alpha is not None:
                        alphas.append(alpha)
                    if rgb is not None:
                        rgbs.append(rgb)
                else:
                    with torch.no_grad():
                        xr, _, _, _, _ = block(xr, context=contextr[i])
                    x, _, _, _, _ = block(x, context=context[i])
                counter += 1
            if self.use_linear:
                x = self.proj_out(x)
                with torch.no_grad():
                    xr = self.proj_out(xr)
            x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
            xr = rearrange(xr, "b1 (h w) c -> b1 c h w", h=h, w=w).contiguous()
            if not self.use_linear:
                x = self.proj_out(x)
                with torch.no_grad():
                    xr = self.proj_out(xr)
            if len(fg_masks) > 0:
                if len(rgbs) <= 0:
                    rgbs = None
                if len(alphas) <= 0:
                    alphas = None
                return x + x_in, (xr + xr_in).detach(), fg_masks, prev_weights, alphas, rgbs
            else:
                return x + x_in, (xr + xr_in).detach(), None, prev_weights, None, None


def benchmark_attn():
    # Lets define a helpful benchmarking function:
    # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
    device = "cuda" if torch.cuda.is_available() else "cpu"
    import torch.nn.functional as F
    import torch.utils.benchmark as benchmark

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    # Lets define the hyper-parameters of our input
    batch_size = 32
    max_sequence_len = 1024
    num_heads = 32
    embed_dimension = 32

    dtype = torch.float16

    query = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )
    key = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )
    value = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )

    print(f"q/k/v shape:", query.shape, key.shape, value.shape)

    # Lets explore the speed of each of the 3 implementations
    from torch.backends.cuda import SDPBackend, sdp_kernel

    # Helpful arguments mapper
    backend_map = {
        SDPBackend.MATH: {
            "enable_math": True,
            "enable_flash": False,
            "enable_mem_efficient": False,
        },
        SDPBackend.FLASH_ATTENTION: {
            "enable_math": False,
            "enable_flash": True,
            "enable_mem_efficient": False,
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False,
            "enable_flash": False,
            "enable_mem_efficient": True,
        },
    }

    from torch.profiler import ProfilerActivity, profile, record_function

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    print(
        f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
    )
    with profile(
        activities=activities, record_shapes=False, profile_memory=True
    ) as prof:
        with record_function("Default detailed stats"):
            for _ in range(25):
                o = F.scaled_dot_product_attention(query, key, value)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    print(
        f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
    )
    with sdp_kernel(**backend_map[SDPBackend.MATH]):
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("Math implmentation stats"):
                for _ in range(25):
                    o = F.scaled_dot_product_attention(query, key, value)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
        try:
            print(
                f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
            )
        except RuntimeError:
            print("FlashAttention is not supported. See warnings for reasons.")
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("FlashAttention stats"):
                for _ in range(25):
                    o = F.scaled_dot_product_attention(query, key, value)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
        try:
            print(
                f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
            )
        except RuntimeError:
            print("EfficientAttention is not supported. See warnings for reasons.")
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("EfficientAttention stats"):
                for _ in range(25):
                    o = F.scaled_dot_product_attention(query, key, value)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


def run_model(model, x, context):
    return model(x, context)


def benchmark_transformer_blocks():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    import torch.utils.benchmark as benchmark

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    checkpoint = True
    compile = False

    batch_size = 32
    h, w = 64, 64
    context_len = 77
    embed_dimension = 1024
    context_dim = 1024
    d_head = 64

    transformer_depth = 4

    n_heads = embed_dimension // d_head

    dtype = torch.float16

    model_native = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        use_checkpoint=checkpoint,
        attn_type="softmax",
        depth=transformer_depth,
        sdp_backend=SDPBackend.FLASH_ATTENTION,
    ).to(device)
    model_efficient_attn = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        depth=transformer_depth,
        use_checkpoint=checkpoint,
        attn_type="softmax-xformers",
    ).to(device)
    if not checkpoint and compile:
        print("compiling models")
        model_native = torch.compile(model_native)
        model_efficient_attn = torch.compile(model_efficient_attn)

    x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
    c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)

    from torch.profiler import ProfilerActivity, profile, record_function

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    with torch.autocast("cuda"):
        print(
            f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
        )
        print(
            f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
        )

        print(75 * "+")
        print("NATIVE")
        print(75 * "+")
        torch.cuda.reset_peak_memory_stats()
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("NativeAttention stats"):
                for _ in range(25):
                    model_native(x, c)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")

        print(75 * "+")
        print("Xformers")
        print(75 * "+")
        torch.cuda.reset_peak_memory_stats()
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("xformers stats"):
                for _ in range(25):
                    model_efficient_attn(x, c)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")


def test01():
    # conv1x1 vs linear
    from ..util import count_params

    conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
    print(count_params(conv))
    linear = torch.nn.Linear(3, 32).cuda()
    print(count_params(linear))

    print(conv.weight.shape)

    # use same initialization
    linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
    linear.bias = torch.nn.Parameter(conv.bias)

    print(linear.weight.shape)

    x = torch.randn(11, 3, 64, 64).cuda()

    xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
    print(xr.shape)
    out_linear = linear(xr)
    print(out_linear.mean(), out_linear.shape)

    out_conv = conv(x)
    print(out_conv.mean(), out_conv.shape)
    print("done with test01.\n")


def test02():
    # try cosine flash attention
    import time

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    print("testing cosine flash attention...")
    DIM = 1024
    SEQLEN = 4096
    BS = 16

    print(" softmax (vanilla) first...")
    model = BasicTransformerBlock(
        dim=DIM,
        n_heads=16,
        d_head=64,
        dropout=0.0,
        context_dim=None,
        attn_mode="softmax",
    ).cuda()
    try:
        x = torch.randn(BS, SEQLEN, DIM).cuda()
        tic = time.time()
        y = model(x)
        toc = time.time()
        print(y.shape, toc - tic)
    except RuntimeError as e:
        # likely oom
        print(str(e))

    print("\n now flash-cosine...")
    model = BasicTransformerBlock(
        dim=DIM,
        n_heads=16,
        d_head=64,
        dropout=0.0,
        context_dim=None,
        attn_mode="flash-cosine",
    ).cuda()
    x = torch.randn(BS, SEQLEN, DIM).cuda()
    tic = time.time()
    y = model(x)
    toc = time.time()
    print(y.shape, toc - tic)
    print("done with test02.\n")


if __name__ == "__main__":
    # test01()
    # test02()
    # test03()

    # benchmark_attn()
    benchmark_transformer_blocks()

    print("done.")

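One detail worth calling out from `MemoryEfficientCrossAttention` above: when `add_lora=True`, each q/k/v/out projection gets a low-rank residual branch whose up-projection is wrapped in `zero_module`, so the adapted layer starts out exactly equal to the frozen base layer and only drifts as the down/up pair is trained. A standalone sketch of that pattern follows; the class name and dimensions are illustrative, not taken from the file.

# Illustrative sketch (not part of attention.py): the zero-initialized low-rank
# update applied to an attention projection, mirroring to_q + to_q_attn3_down/up.
import torch
import torch.nn as nn

class LowRankAdaptedLinear(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 32, p_drop: float = 0.1):
        super().__init__()
        self.base = base                                    # frozen pretrained projection
        self.down = nn.Linear(base.in_features, r, bias=False)
        self.up = nn.Linear(r, base.out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / r)        # same init as the *_attn3_down layers
        nn.init.zeros_(self.up.weight)                       # plays the role of zero_module(...)
        self.drop = nn.Dropout(p_drop)

    def forward(self, x):
        # base output plus a rank-r correction that is exactly zero at initialization
        return self.base(x) + self.drop(self.up(self.down(x)))

# q_proj = LowRankAdaptedLinear(nn.Linear(320, 320, bias=False))  # hypothetical usage

Zero-initializing the up-projection is what makes it safe to bolt these branches onto a pretrained attention layer without disturbing its initial outputs.
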
sgm/modules/autoencoding/__init__.py
ADDED
File without changes
sgm/modules/autoencoding/lpips/__init__.py
ADDED
File without changes
sgm/modules/autoencoding/lpips/loss.py
ADDED
File without changes
sgm/modules/autoencoding/lpips/loss/LICENSE
ADDED
@@ -0,0 +1,23 @@
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

sgm/modules/autoencoding/lpips/loss/__init__.py
ADDED
File without changes
|
sgm/modules/autoencoding/lpips/loss/lpips.py
ADDED
@@ -0,0 +1,147 @@
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""

from collections import namedtuple

import torch
import torch.nn as nn
from torchvision import models

from ..util import get_ckpt_path


class LPIPS(nn.Module):
    # Learned perceptual metric
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
        self.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = get_ckpt_path(name)
        model.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        return model

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
                outs1[kk]
            )
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [
            spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
            for kk in range(len(self.chns))
        ]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer(
            "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        self.register_buffer(
            "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = (
            [
                nn.Dropout(),
            ]
            if (use_dropout)
            else []
        )
        layers += [
            nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
        ]
        self.model = nn.Sequential(*layers)


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple(
            "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
        )
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)

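For context, a minimal usage sketch of the LPIPS metric defined above (illustrative only; it assumes the `sgm` package from this repo is importable and that inputs are image batches scaled to [-1, 1]):

import torch

from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS

lpips = LPIPS().eval()                      # downloads the vgg_lpips checkpoint on first use
img_a = torch.rand(2, 3, 256, 256) * 2 - 1  # toy image batch in [-1, 1]
img_b = torch.rand(2, 3, 256, 256) * 2 - 1
dist = lpips(img_a, img_b)                  # per-image perceptual distance, shape (2, 1, 1, 1)
print(dist.flatten())
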
sgm/modules/autoencoding/lpips/model/LICENSE
ADDED
@@ -0,0 +1,58 @@
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


--------------------------- LICENSE FOR pix2pix --------------------------------
BSD License

For pix2pix software
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

----------------------------- LICENSE FOR DCGAN --------------------------------
BSD License

For dcgan.torch software

Copyright (c) 2015, Facebook, Inc. All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

sgm/modules/autoencoding/lpips/model/__init__.py
ADDED
File without changes
|
sgm/modules/autoencoding/lpips/model/model.py
ADDED
@@ -0,0 +1,88 @@
import functools

import torch.nn as nn

from ..util import ActNorm


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if (
            type(norm_layer) == functools.partial
        ):  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(
                    ndf * nf_mult_prev,
                    ndf * nf_mult,
                    kernel_size=kw,
                    stride=2,
                    padding=padw,
                    bias=use_bias,
                ),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(
                ndf * nf_mult_prev,
                ndf * nf_mult,
                kernel_size=kw,
                stride=1,
                padding=padw,
                bias=use_bias,
            ),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        ]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)

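A short sketch of how the PatchGAN discriminator above could be exercised (illustrative only; the batch size and resolution are arbitrary):

import torch

from sgm.modules.autoencoding.lpips.model.model import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
disc.apply(weights_init)                    # DCGAN-style init for conv/batch-norm layers
logits = disc(torch.randn(2, 3, 256, 256))  # one-channel patch-wise real/fake logit map
print(logits.shape)
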
sgm/modules/autoencoding/lpips/util.py
ADDED
@@ -0,0 +1,128 @@
import hashlib
import os

import requests
import torch
import torch.nn as nn
from tqdm import tqdm

URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}

CKPT_MAP = {"vgg_lpips": "vgg.pth"}

MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path


class ActNorm(nn.Module):
    def __init__(
        self, num_features, logdet=False, affine=True, allow_reverse_init=False
    ):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h

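A small sketch of the data-dependent initialization performed by `ActNorm` above (illustrative only): the first forward pass in training mode sets `loc`/`scale` from the batch statistics so the output is roughly zero-mean, unit-variance.

import torch

from sgm.modules.autoencoding.lpips.util import ActNorm

norm = ActNorm(num_features=8, logdet=True)
x = torch.randn(4, 8, 16, 16) * 3.0 + 1.0
h, logdet = norm(x)                     # first call initializes loc/scale from this batch
print(h.mean().item(), h.std().item())  # close to 0 and 1 after initialization
print(logdet.shape)                     # per-sample log-determinant, shape (4,)
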
sgm/modules/autoencoding/lpips/vqperceptual.py
ADDED
@@ -0,0 +1,17 @@
import torch
import torch.nn.functional as F


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real))
        + torch.mean(torch.nn.functional.softplus(logits_fake))
    )
    return d_loss

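A quick numeric check of the two discriminator losses above (illustrative only); both reward high logits on real samples and low logits on fake samples:

import torch

from sgm.modules.autoencoding.lpips.vqperceptual import hinge_d_loss, vanilla_d_loss

logits_real = torch.tensor([2.0, 0.5, -1.0])
logits_fake = torch.tensor([-2.0, 0.0, 1.5])
print(hinge_d_loss(logits_real, logits_fake))    # hinge formulation
print(vanilla_d_loss(logits_real, logits_fake))  # non-saturating softplus formulation
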
sgm/modules/autoencoding/regularizers/__init__.py
ADDED
@@ -0,0 +1,31 @@
from abc import abstractmethod
from typing import Any, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ....modules.distributions.distributions import \
    DiagonalGaussianDistribution
from .base import AbstractRegularizer


class DiagonalGaussianRegularizer(AbstractRegularizer):
    def __init__(self, sample: bool = True):
        super().__init__()
        self.sample = sample

    def get_trainable_parameters(self) -> Any:
        yield from ()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        log = dict()
        posterior = DiagonalGaussianDistribution(z)
        if self.sample:
            z = posterior.sample()
        else:
            z = posterior.mode()
        kl_loss = posterior.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        log["kl_loss"] = kl_loss
        return z, log

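Finally, a minimal sketch of applying `DiagonalGaussianRegularizer` above to an encoder output whose channel dimension stacks mean and log-variance (illustrative only; channel counts are arbitrary, and it assumes the rest of the `sgm` package from this commit is importable):

import torch

from sgm.modules.autoencoding.regularizers import DiagonalGaussianRegularizer

reg = DiagonalGaussianRegularizer(sample=True)
z_params = torch.randn(2, 8, 32, 32)  # 4 mean channels + 4 log-variance channels
z, log = reg(z_params)
print(z.shape)                        # (2, 4, 32, 32): sampled latent
print(log["kl_loss"])                 # scalar KL term to add to the training loss
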