ashh757 commited on
Commit
2e968f0
·
verified ·
1 Parent(s): 254a1c6

Upload 21 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Stage 1: Build the dependencies
2
+ FROM python:3.12-bullseye AS builder
3
+
4
+ # Install required system packages
5
+ RUN apt-get update && apt-get install -y --no-install-recommends \
6
+ git \
7
+ build-essential \
8
+ cmake \
9
+ libopenblas-dev \
10
+ libomp-dev \
11
+ && apt-get clean \
12
+ && rm -rf /var/lib/apt/lists/*
13
+
14
+ # Set the working directory to /app
15
+ WORKDIR /app
16
+
17
+ # Copy requirements and install dependencies
18
+ COPY requirements.txt /app/
19
+
20
+ # Install Python dependencies and torchmcubes
21
+ RUN pip install --upgrade pip setuptools wheel \
22
+ && pip install -r requirements.txt \
23
+ && pip install git+https://github.com/tatsy/torchmcubes.git@3aef8afa5f21b113afc4f4ea148baee850cbd472 \
24
+ && rm -rf ~/.cache/pip
25
+
26
+ # Copy the application files
27
+ COPY . /app
28
+
29
+ # Configure Git to treat the directory as safe before switching to the final stage
30
+ RUN git config --global --add safe.directory /app
31
+
32
+ # Stage 2: Final image
33
+ FROM python:3.12-slim-bullseye
34
+
35
+ # Set up a new user named "user"
36
+ RUN useradd user
37
+
38
+ # Set the home environment variable and PATH
39
+ ENV HOME=/home/user \
40
+ PATH=/home/user/.local/bin:$PATH
41
+
42
+ # Set the working directory to the user's home directory
43
+ WORKDIR $HOME/app
44
+
45
+ # Copy the application files and installed packages from the builder stage
46
+ COPY --from=builder /app $HOME/app
47
+ COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
48
+ COPY --from=builder /usr/local/bin /usr/local/bin
49
+
50
+ # Change ownership of the app directory to the user
51
+ RUN chown -R user:user $HOME/app
52
+
53
+ # Install git in the final stage
54
+ RUN apt-get update && apt-get install -y --no-install-recommends git \
55
+ && apt-get clean \
56
+ && rm -rf /var/lib/apt/lists/*
57
+
58
+ # Expose secrets at build time and store them in a file
59
+ RUN --mount=type=secret,id=AWS_ACCESS_KEY_ID,mode=0444,required=true \
60
+ git config --global --add safe.directory $HOME/app && \
61
+ git init && \
62
+ git remote add sec1 $(cat /run/secrets/AWS_ACCESS_KEY_ID)
63
+
64
+ RUN --mount=type=secret,id=AWS_SECRET_ACCESS_KEY,mode=0444,required=true \
65
+ git config --global --add safe.directory $HOME/app && \
66
+ git init && \
67
+ git remote add sec2 $(cat /run/secrets/AWS_SECRET_ACCESS_KEY)
68
+
69
+ RUN --mount=type=secret,id=AWS_DEFAULT_REGION,mode=0444,required=true \
70
+ git config --global --add safe.directory $HOME/app && \
71
+ git init && \
72
+ git remote add sec3 $(cat /run/secrets/AWS_DEFAULT_REGION)
73
+
74
+ # Switch to the "user" user
75
+ USER user
76
+
77
+ EXPOSE 7960
78
+
79
+ # Set the entry point to run the FastAPI application
80
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7960"]
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form, Query
2
+ from fastapi.responses import FileResponse
3
+ import os
4
+
5
+ from main import load_model, generate_mesh
6
+
7
+ ## create a new FASTAPI app instance
8
+ app=FastAPI()
9
+
10
+ model = load_model()
11
+
12
+ @app.get("/")
13
+ def home():
14
+ return {"message":"Hello World"}
15
+
16
+ # Define a function to handle the GET request at `/generate`
17
+ @app.post("/generate")
18
+ async def generate(image: UploadFile = File(...),
19
+ no_remove_bg: bool = Form(True),
20
+ foreground_ratio: float = Form(0.85),
21
+ render: bool = Form(False),
22
+ mc_resolution: int = Form(256),
23
+ bake_texture_flag: bool = Form(False),
24
+ texture_resolution: int = Form(2048),
25
+ bucket_name: str= Form('BUCKET_NAME'),
26
+ input_folder: str= Form('INPUT_IMAGES'),
27
+ output_folder: str= Form('OUTPUT_MESHES'),
28
+ input_s3_id: str= Form('input_image.png'),
29
+ output_s3_id: str= Form('output_mesh.obj')
30
+ ):
31
+
32
+ # Save the uploaded image to a temporary location
33
+ temp_image_path = f"tmp/output/{image.filename}"
34
+ with open(temp_image_path, "wb") as f:
35
+ f.write(await image.read())
36
+
37
+ # Call the `generate_mesh` function with customized parameters
38
+ output_file_path, output_video_path = generate_mesh(
39
+ image_path=temp_image_path,
40
+ output_dir='tmp/output/',
41
+ no_remove_bg=no_remove_bg,
42
+ foreground_ratio=foreground_ratio,
43
+ render=render,
44
+ mc_resolution=mc_resolution,
45
+ bake_texture_flag=bake_texture_flag,
46
+ texture_resolution=texture_resolution,
47
+ model=model,
48
+ bucket_name=bucket_name,
49
+ input_folder=input_folder,
50
+ output_folder=output_folder,
51
+ input_s3_id=input_s3_id,
52
+ output_s3_id=output_s3_id
53
+ )
54
+
55
+ if output_video_path==None:
56
+ ## return the generate text in Json reposne
57
+ return FileResponse(output_file_path, media_type='application/octet-stream', filename="output_mesh.obj")
58
+ else:
59
+ return (FileResponse(output_file_path, media_type='application/octet-stream', filename="output_mesh.obj"),
60
+ FileResponse(output_video_path, media_type='application/octet-stream', filename="output_video.mp4"))
main.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from utils import process_image, run_model
3
+ from typing import Tuple, Optional
4
+ from PIL import Image
5
+ from boto3 import Session
6
+ import torch
7
+ import pickle
8
+ import datetime
9
+ import gzip
10
+
11
+ # Retrieve credentials from environment variables
12
+ session = Session(
13
+ aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
14
+ aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
15
+ region_name=os.getenv('AWS_DEFAULT_REGION')
16
+ )
17
+ s3 = session.client('s3')
18
+
19
+ def load_model():
20
+ with gzip.open('model_quantized_compressed.pkl.gz', 'rb') as f_in:
21
+ model_data = f_in.read()
22
+
23
+ model = pickle.loads(model_data)
24
+ print("Model Loaded")
25
+ return model
26
+
27
+ def upload_to_s3(file_path, bucket_name, s3_key):
28
+ with open(file_path, 'rb') as f:
29
+ s3.upload_fileobj(f, bucket_name, s3_key)
30
+ s3_url = f's3://{bucket_name}/{s3_key}'
31
+ return s3_url
32
+
33
+ def generate_mesh(image_path:str,
34
+ output_dir:str ='tmp/output/',
35
+ no_remove_bg:bool =True,
36
+ foreground_ratio:float =0.85 ,
37
+ render:bool =False ,
38
+ mc_resolution:int =256 ,
39
+ bake_texture_flag:bool =False ,
40
+ texture_resolution:int =2048,
41
+ model=None,
42
+ bucket_name:str=None,
43
+ input_folder:str=None,
44
+ output_folder:str=None,
45
+ input_s3_id:str='input_image.png',
46
+ output_s3_id:str='output_mesh.obj'
47
+ ) -> Tuple[Optional[str], Optional[str]] :
48
+
49
+ print('Process start')
50
+ image = process_image(image_path=image_path,
51
+ output_dir=output_dir ,
52
+ no_remove_bg=no_remove_bg ,
53
+ foreground_ratio=foreground_ratio)
54
+ print('Process end')
55
+
56
+ print('Run start')
57
+ output_file_path ,output_video_path = run_model(model=model,
58
+ image=image,
59
+ output_dir=output_dir ,
60
+ device="cuda:0" if torch.cuda.is_available() else "cpu",
61
+ render=render ,
62
+ mc_resolution=mc_resolution ,
63
+ model_save_format='obj',
64
+ bake_texture_flag=bake_texture_flag ,
65
+ texture_resolution=texture_resolution)
66
+ print('Run end')
67
+
68
+ print('Uploading to bucket...')
69
+ # Upload the input image and generated mesh file to S3
70
+ if input_folder != None:
71
+ input_s3_key = input_folder + '/' + input_s3_id
72
+ else:
73
+ input_s3_key = input_s3_id
74
+
75
+ if output_folder != None:
76
+ output_s3_key = output_folder + '/' + output_s3_id
77
+ else:
78
+ output_s3_key = output_s3_id
79
+ print(input_s3_key ,output_s3_key)
80
+
81
+ input_s3_url = upload_to_s3(image_path, bucket_name, input_s3_key)
82
+ output_s3_url = upload_to_s3(output_file_path, bucket_name, output_s3_key)
83
+
84
+ print(f'Files uploaded to S3:\nInput Image: {input_s3_url}\nOutput Mesh: {output_s3_url}')
85
+
86
+ return output_file_path ,output_video_path
model_quantized_compressed.pkl.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc8f71c24a7db66a83bd712463be983e9bff078b05153fb60dba3e84d5781955
3
+ size 320691353
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ omegaconf==2.3.0
2
+ Pillow==10.1.0
3
+ einops==0.7.0
4
+ transformers==4.35.0
5
+ trimesh==4.0.5
6
+ rembg
7
+ huggingface-hub
8
+ imageio[ffmpeg]
9
+ xatlas==0.0.9
10
+ moderngl==5.10.0
11
+ torch==2.3.1
12
+ numpy==1.26.4
13
+ boto3==1.34.161
14
+ uvicorn==0.30.6
15
+ fastapi==0.112.2
16
+ python-multipart==0.0.9
tmp/output/mesh.obj ADDED
The diff for this file is too large to render. See raw diff
 
tmp/output/poly_fox.png ADDED
tmp/output/processed_input.png ADDED
tmp/output/trash.txt ADDED
File without changes
tsr/bake_texture.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import xatlas
4
+ import trimesh
5
+ import moderngl
6
+ from PIL import Image
7
+
8
+
9
+ def make_atlas(mesh, texture_resolution, texture_padding):
10
+ atlas = xatlas.Atlas()
11
+ atlas.add_mesh(mesh.vertices, mesh.faces)
12
+ options = xatlas.PackOptions()
13
+ options.resolution = texture_resolution
14
+ options.padding = texture_padding
15
+ options.bilinear = True
16
+ atlas.generate(pack_options=options)
17
+ vmapping, indices, uvs = atlas[0]
18
+ return {
19
+ "vmapping": vmapping,
20
+ "indices": indices,
21
+ "uvs": uvs,
22
+ }
23
+
24
+
25
+ def rasterize_position_atlas(
26
+ mesh, atlas_vmapping, atlas_indices, atlas_uvs, texture_resolution, texture_padding
27
+ ):
28
+ ctx = moderngl.create_context(standalone=True)
29
+ basic_prog = ctx.program(
30
+ vertex_shader="""
31
+ #version 330
32
+ in vec2 in_uv;
33
+ in vec3 in_pos;
34
+ out vec3 v_pos;
35
+ void main() {
36
+ v_pos = in_pos;
37
+ gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
38
+ }
39
+ """,
40
+ fragment_shader="""
41
+ #version 330
42
+ in vec3 v_pos;
43
+ out vec4 o_col;
44
+ void main() {
45
+ o_col = vec4(v_pos, 1.0);
46
+ }
47
+ """,
48
+ )
49
+ gs_prog = ctx.program(
50
+ vertex_shader="""
51
+ #version 330
52
+ in vec2 in_uv;
53
+ in vec3 in_pos;
54
+ out vec3 vg_pos;
55
+ void main() {
56
+ vg_pos = in_pos;
57
+ gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
58
+ }
59
+ """,
60
+ geometry_shader="""
61
+ #version 330
62
+ uniform float u_resolution;
63
+ uniform float u_dilation;
64
+ layout (triangles) in;
65
+ layout (triangle_strip, max_vertices = 12) out;
66
+ in vec3 vg_pos[];
67
+ out vec3 vf_pos;
68
+ void lineSegment(int aidx, int bidx) {
69
+ vec2 a = gl_in[aidx].gl_Position.xy;
70
+ vec2 b = gl_in[bidx].gl_Position.xy;
71
+ vec3 aCol = vg_pos[aidx];
72
+ vec3 bCol = vg_pos[bidx];
73
+
74
+ vec2 dir = normalize((b - a) * u_resolution);
75
+ vec2 offset = vec2(-dir.y, dir.x) * u_dilation / u_resolution;
76
+
77
+ gl_Position = vec4(a + offset, 0.0, 1.0);
78
+ vf_pos = aCol;
79
+ EmitVertex();
80
+ gl_Position = vec4(a - offset, 0.0, 1.0);
81
+ vf_pos = aCol;
82
+ EmitVertex();
83
+ gl_Position = vec4(b + offset, 0.0, 1.0);
84
+ vf_pos = bCol;
85
+ EmitVertex();
86
+ gl_Position = vec4(b - offset, 0.0, 1.0);
87
+ vf_pos = bCol;
88
+ EmitVertex();
89
+ }
90
+ void main() {
91
+ lineSegment(0, 1);
92
+ lineSegment(1, 2);
93
+ lineSegment(2, 0);
94
+ EndPrimitive();
95
+ }
96
+ """,
97
+ fragment_shader="""
98
+ #version 330
99
+ in vec3 vf_pos;
100
+ out vec4 o_col;
101
+ void main() {
102
+ o_col = vec4(vf_pos, 1.0);
103
+ }
104
+ """,
105
+ )
106
+ uvs = atlas_uvs.flatten().astype("f4")
107
+ pos = mesh.vertices[atlas_vmapping].flatten().astype("f4")
108
+ indices = atlas_indices.flatten().astype("i4")
109
+ vbo_uvs = ctx.buffer(uvs)
110
+ vbo_pos = ctx.buffer(pos)
111
+ ibo = ctx.buffer(indices)
112
+ vao_content = [
113
+ vbo_uvs.bind("in_uv", layout="2f"),
114
+ vbo_pos.bind("in_pos", layout="3f"),
115
+ ]
116
+ basic_vao = ctx.vertex_array(basic_prog, vao_content, ibo)
117
+ gs_vao = ctx.vertex_array(gs_prog, vao_content, ibo)
118
+ fbo = ctx.framebuffer(
119
+ color_attachments=[
120
+ ctx.texture((texture_resolution, texture_resolution), 4, dtype="f4")
121
+ ]
122
+ )
123
+ fbo.use()
124
+ fbo.clear(0.0, 0.0, 0.0, 0.0)
125
+ gs_prog["u_resolution"].value = texture_resolution
126
+ gs_prog["u_dilation"].value = texture_padding
127
+ gs_vao.render()
128
+ basic_vao.render()
129
+
130
+ fbo_bytes = fbo.color_attachments[0].read()
131
+ fbo_np = np.frombuffer(fbo_bytes, dtype="f4").reshape(
132
+ texture_resolution, texture_resolution, 4
133
+ )
134
+ return fbo_np
135
+
136
+
137
+ def positions_to_colors(model, scene_code, positions_texture, texture_resolution):
138
+ positions = torch.tensor(positions_texture.reshape(-1, 4)[:, :-1])
139
+ with torch.no_grad():
140
+ queried_grid = model.renderer.query_triplane(
141
+ model.decoder,
142
+ positions,
143
+ scene_code,
144
+ )
145
+ rgb_f = queried_grid["color"].numpy().reshape(-1, 3)
146
+ rgba_f = np.insert(rgb_f, 3, positions_texture.reshape(-1, 4)[:, -1], axis=1)
147
+ rgba_f[rgba_f[:, -1] == 0.0] = [0, 0, 0, 0]
148
+ return rgba_f.reshape(texture_resolution, texture_resolution, 4)
149
+
150
+
151
+ def bake_texture(mesh, model, scene_code, texture_resolution):
152
+ texture_padding = round(max(2, texture_resolution / 256))
153
+ atlas = make_atlas(mesh, texture_resolution, texture_padding)
154
+ positions_texture = rasterize_position_atlas(
155
+ mesh,
156
+ atlas["vmapping"],
157
+ atlas["indices"],
158
+ atlas["uvs"],
159
+ texture_resolution,
160
+ texture_padding,
161
+ )
162
+ colors_texture = positions_to_colors(
163
+ model, scene_code, positions_texture, texture_resolution
164
+ )
165
+ return {
166
+ "vmapping": atlas["vmapping"],
167
+ "indices": atlas["indices"],
168
+ "uvs": atlas["uvs"],
169
+ "colors": colors_texture,
170
+ }
tsr/models/isosurface.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchmcubes import marching_cubes
7
+
8
+
9
+ class IsosurfaceHelper(nn.Module):
10
+ points_range: Tuple[float, float] = (0, 1)
11
+
12
+ @property
13
+ def grid_vertices(self) -> torch.FloatTensor:
14
+ raise NotImplementedError
15
+
16
+
17
+ class MarchingCubeHelper(IsosurfaceHelper):
18
+ def __init__(self, resolution: int) -> None:
19
+ super().__init__()
20
+ self.resolution = resolution
21
+ self.mc_func: Callable = marching_cubes
22
+ self._grid_vertices: Optional[torch.FloatTensor] = None
23
+
24
+ @property
25
+ def grid_vertices(self) -> torch.FloatTensor:
26
+ if self._grid_vertices is None:
27
+ # keep the vertices on CPU so that we can support very large resolution
28
+ x, y, z = (
29
+ torch.linspace(*self.points_range, self.resolution),
30
+ torch.linspace(*self.points_range, self.resolution),
31
+ torch.linspace(*self.points_range, self.resolution),
32
+ )
33
+ x, y, z = torch.meshgrid(x, y, z, indexing="ij")
34
+ verts = torch.cat(
35
+ [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
36
+ ).reshape(-1, 3)
37
+ self._grid_vertices = verts
38
+ return self._grid_vertices
39
+
40
+ def forward(
41
+ self,
42
+ level: torch.FloatTensor,
43
+ ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
44
+ level = -level.view(self.resolution, self.resolution, self.resolution)
45
+ try:
46
+ v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
47
+ except AttributeError:
48
+ print("torchmcubes was not compiled with CUDA support, use CPU version instead.")
49
+ v_pos, t_pos_idx = self.mc_func(level.detach().cpu(), 0.0)
50
+ v_pos = v_pos[..., [2, 1, 0]]
51
+ v_pos = v_pos / (self.resolution - 1.0)
52
+ return v_pos.to(level.device), t_pos_idx.to(level.device)
tsr/models/nerf_renderer.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, reduce
7
+
8
+ from ..utils import (
9
+ BaseModule,
10
+ chunk_batch,
11
+ get_activation,
12
+ rays_intersect_bbox,
13
+ scale_tensor,
14
+ )
15
+
16
+
17
+ class TriplaneNeRFRenderer(BaseModule):
18
+ @dataclass
19
+ class Config(BaseModule.Config):
20
+ radius: float
21
+
22
+ feature_reduction: str = "concat"
23
+ density_activation: str = "trunc_exp"
24
+ density_bias: float = -1.0
25
+ color_activation: str = "sigmoid"
26
+ num_samples_per_ray: int = 128
27
+ randomized: bool = False
28
+
29
+ cfg: Config
30
+
31
+ def configure(self) -> None:
32
+ assert self.cfg.feature_reduction in ["concat", "mean"]
33
+ self.chunk_size = 0
34
+
35
+ def set_chunk_size(self, chunk_size: int):
36
+ assert (
37
+ chunk_size >= 0
38
+ ), "chunk_size must be a non-negative integer (0 for no chunking)."
39
+ self.chunk_size = chunk_size
40
+
41
+ def query_triplane(
42
+ self,
43
+ decoder: torch.nn.Module,
44
+ positions: torch.Tensor,
45
+ triplane: torch.Tensor,
46
+ ) -> Dict[str, torch.Tensor]:
47
+ input_shape = positions.shape[:-1]
48
+ positions = positions.view(-1, 3)
49
+
50
+ # positions in (-radius, radius)
51
+ # normalized to (-1, 1) for grid sample
52
+ positions = scale_tensor(
53
+ positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
54
+ )
55
+
56
+ def _query_chunk(x):
57
+ indices2D: torch.Tensor = torch.stack(
58
+ (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
59
+ dim=-3,
60
+ )
61
+ out: torch.Tensor = F.grid_sample(
62
+ rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
63
+ rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
64
+ align_corners=False,
65
+ mode="bilinear",
66
+ )
67
+ if self.cfg.feature_reduction == "concat":
68
+ out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
69
+ elif self.cfg.feature_reduction == "mean":
70
+ out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
71
+ else:
72
+ raise NotImplementedError
73
+
74
+ net_out: Dict[str, torch.Tensor] = decoder(out)
75
+ return net_out
76
+
77
+ if self.chunk_size > 0:
78
+ net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
79
+ else:
80
+ net_out = _query_chunk(positions)
81
+
82
+ net_out["density_act"] = get_activation(self.cfg.density_activation)(
83
+ net_out["density"] + self.cfg.density_bias
84
+ )
85
+ net_out["color"] = get_activation(self.cfg.color_activation)(
86
+ net_out["features"]
87
+ )
88
+
89
+ net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}
90
+
91
+ return net_out
92
+
93
+ def _forward(
94
+ self,
95
+ decoder: torch.nn.Module,
96
+ triplane: torch.Tensor,
97
+ rays_o: torch.Tensor,
98
+ rays_d: torch.Tensor,
99
+ **kwargs,
100
+ ):
101
+ rays_shape = rays_o.shape[:-1]
102
+ rays_o = rays_o.view(-1, 3)
103
+ rays_d = rays_d.view(-1, 3)
104
+ n_rays = rays_o.shape[0]
105
+
106
+ t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
107
+ t_near, t_far = t_near[rays_valid], t_far[rays_valid]
108
+
109
+ t_vals = torch.linspace(
110
+ 0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
111
+ )
112
+ t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
113
+ z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None] # (N_rays, N_samples)
114
+
115
+ xyz = (
116
+ rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
117
+ ) # (N_rays, N_sample, 3)
118
+
119
+ mlp_out = self.query_triplane(
120
+ decoder=decoder,
121
+ positions=xyz,
122
+ triplane=triplane,
123
+ )
124
+
125
+ eps = 1e-10
126
+ # deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples)
127
+ deltas = t_vals[1:] - t_vals[:-1] # (N_rays, N_samples)
128
+ alpha = 1 - torch.exp(
129
+ -deltas * mlp_out["density_act"][..., 0]
130
+ ) # (N_rays, N_samples)
131
+ accum_prod = torch.cat(
132
+ [
133
+ torch.ones_like(alpha[:, :1]),
134
+ torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
135
+ ],
136
+ dim=-1,
137
+ )
138
+ weights = alpha * accum_prod # (N_rays, N_samples)
139
+ comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2) # (N_rays, 3)
140
+ opacity_ = weights.sum(dim=-1) # (N_rays)
141
+
142
+ comp_rgb = torch.zeros(
143
+ n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
144
+ )
145
+ opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
146
+ comp_rgb[rays_valid] = comp_rgb_
147
+ opacity[rays_valid] = opacity_
148
+
149
+ comp_rgb += 1 - opacity[..., None]
150
+ comp_rgb = comp_rgb.view(*rays_shape, 3)
151
+
152
+ return comp_rgb
153
+
154
+ def forward(
155
+ self,
156
+ decoder: torch.nn.Module,
157
+ triplane: torch.Tensor,
158
+ rays_o: torch.Tensor,
159
+ rays_d: torch.Tensor,
160
+ ) -> Dict[str, torch.Tensor]:
161
+ if triplane.ndim == 4:
162
+ comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
163
+ else:
164
+ comp_rgb = torch.stack(
165
+ [
166
+ self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
167
+ for i in range(triplane.shape[0])
168
+ ],
169
+ dim=0,
170
+ )
171
+
172
+ return comp_rgb
173
+
174
+ def train(self, mode=True):
175
+ self.randomized = mode and self.cfg.randomized
176
+ return super().train(mode=mode)
177
+
178
+ def eval(self):
179
+ self.randomized = False
180
+ return super().eval()
tsr/models/network_utils.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from ..utils import BaseModule
9
+
10
+
11
+ class TriplaneUpsampleNetwork(BaseModule):
12
+ @dataclass
13
+ class Config(BaseModule.Config):
14
+ in_channels: int
15
+ out_channels: int
16
+
17
+ cfg: Config
18
+
19
+ def configure(self) -> None:
20
+ self.upsample = nn.ConvTranspose2d(
21
+ self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
22
+ )
23
+
24
+ def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
25
+ triplanes_up = rearrange(
26
+ self.upsample(
27
+ rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
28
+ ),
29
+ "(B Np) Co Hp Wp -> B Np Co Hp Wp",
30
+ Np=3,
31
+ )
32
+ return triplanes_up
33
+
34
+
35
+ class NeRFMLP(BaseModule):
36
+ @dataclass
37
+ class Config(BaseModule.Config):
38
+ in_channels: int
39
+ n_neurons: int
40
+ n_hidden_layers: int
41
+ activation: str = "relu"
42
+ bias: bool = True
43
+ weight_init: Optional[str] = "kaiming_uniform"
44
+ bias_init: Optional[str] = None
45
+
46
+ cfg: Config
47
+
48
+ def configure(self) -> None:
49
+ layers = [
50
+ self.make_linear(
51
+ self.cfg.in_channels,
52
+ self.cfg.n_neurons,
53
+ bias=self.cfg.bias,
54
+ weight_init=self.cfg.weight_init,
55
+ bias_init=self.cfg.bias_init,
56
+ ),
57
+ self.make_activation(self.cfg.activation),
58
+ ]
59
+ for i in range(self.cfg.n_hidden_layers - 1):
60
+ layers += [
61
+ self.make_linear(
62
+ self.cfg.n_neurons,
63
+ self.cfg.n_neurons,
64
+ bias=self.cfg.bias,
65
+ weight_init=self.cfg.weight_init,
66
+ bias_init=self.cfg.bias_init,
67
+ ),
68
+ self.make_activation(self.cfg.activation),
69
+ ]
70
+ layers += [
71
+ self.make_linear(
72
+ self.cfg.n_neurons,
73
+ 4, # density 1 + features 3
74
+ bias=self.cfg.bias,
75
+ weight_init=self.cfg.weight_init,
76
+ bias_init=self.cfg.bias_init,
77
+ )
78
+ ]
79
+ self.layers = nn.Sequential(*layers)
80
+
81
+ def make_linear(
82
+ self,
83
+ dim_in,
84
+ dim_out,
85
+ bias=True,
86
+ weight_init=None,
87
+ bias_init=None,
88
+ ):
89
+ layer = nn.Linear(dim_in, dim_out, bias=bias)
90
+
91
+ if weight_init is None:
92
+ pass
93
+ elif weight_init == "kaiming_uniform":
94
+ torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ if bias:
99
+ if bias_init is None:
100
+ pass
101
+ elif bias_init == "zero":
102
+ torch.nn.init.zeros_(layer.bias)
103
+ else:
104
+ raise NotImplementedError
105
+
106
+ return layer
107
+
108
+ def make_activation(self, activation):
109
+ if activation == "relu":
110
+ return nn.ReLU(inplace=True)
111
+ elif activation == "silu":
112
+ return nn.SiLU(inplace=True)
113
+ else:
114
+ raise NotImplementedError
115
+
116
+ def forward(self, x):
117
+ inp_shape = x.shape[:-1]
118
+ x = x.reshape(-1, x.shape[-1])
119
+
120
+ features = self.layers(x)
121
+ features = features.reshape(*inp_shape, -1)
122
+ out = {"density": features[..., 0:1], "features": features[..., 1:4]}
123
+
124
+ return out
tsr/models/tokenizers/image.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from huggingface_hub import hf_hub_download
7
+ from transformers.models.vit.modeling_vit import ViTModel
8
+
9
+ from ...utils import BaseModule
10
+
11
+
12
+ class DINOSingleImageTokenizer(BaseModule):
13
+ @dataclass
14
+ class Config(BaseModule.Config):
15
+ pretrained_model_name_or_path: str = "facebook/dino-vitb16"
16
+ enable_gradient_checkpointing: bool = False
17
+
18
+ cfg: Config
19
+
20
+ def configure(self) -> None:
21
+ self.model: ViTModel = ViTModel(
22
+ ViTModel.config_class.from_pretrained(
23
+ hf_hub_download(
24
+ repo_id=self.cfg.pretrained_model_name_or_path,
25
+ filename="config.json",
26
+ )
27
+ )
28
+ )
29
+
30
+ if self.cfg.enable_gradient_checkpointing:
31
+ self.model.encoder.gradient_checkpointing = True
32
+
33
+ self.register_buffer(
34
+ "image_mean",
35
+ torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
36
+ persistent=False,
37
+ )
38
+ self.register_buffer(
39
+ "image_std",
40
+ torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
41
+ persistent=False,
42
+ )
43
+
44
+ def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
45
+ packed = False
46
+ if images.ndim == 4:
47
+ packed = True
48
+ images = images.unsqueeze(1)
49
+
50
+ batch_size, n_input_views = images.shape[:2]
51
+ images = (images - self.image_mean) / self.image_std
52
+ out = self.model(
53
+ rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
54
+ )
55
+ local_features, global_features = out.last_hidden_state, out.pooler_output
56
+ local_features = local_features.permute(0, 2, 1)
57
+ local_features = rearrange(
58
+ local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
59
+ )
60
+ if packed:
61
+ local_features = local_features.squeeze(1)
62
+
63
+ return local_features
64
+
65
+ def detokenize(self, *args, **kwargs):
66
+ raise NotImplementedError
tsr/models/tokenizers/triplane.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange, repeat
7
+
8
+ from ...utils import BaseModule
9
+
10
+
11
+ class Triplane1DTokenizer(BaseModule):
12
+ @dataclass
13
+ class Config(BaseModule.Config):
14
+ plane_size: int
15
+ num_channels: int
16
+
17
+ cfg: Config
18
+
19
+ def configure(self) -> None:
20
+ self.embeddings = nn.Parameter(
21
+ torch.randn(
22
+ (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
23
+ dtype=torch.float32,
24
+ )
25
+ * 1
26
+ / math.sqrt(self.cfg.num_channels)
27
+ )
28
+
29
+ def forward(self, batch_size: int) -> torch.Tensor:
30
+ return rearrange(
31
+ repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
32
+ "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
33
+ )
34
+
35
+ def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
36
+ batch_size, Ct, Nt = tokens.shape
37
+ assert Nt == self.cfg.plane_size**2 * 3
38
+ assert Ct == self.cfg.num_channels
39
+ return rearrange(
40
+ tokens,
41
+ "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
42
+ Np=3,
43
+ Hp=self.cfg.plane_size,
44
+ Wp=self.cfg.plane_size,
45
+ )
tsr/models/transformer/attention.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
+ from typing import Optional
40
+
41
+ import torch
42
+ import torch.nn.functional as F
43
+ from torch import nn
44
+
45
+
46
+ class Attention(nn.Module):
47
+ r"""
48
+ A cross attention layer.
49
+
50
+ Parameters:
51
+ query_dim (`int`):
52
+ The number of channels in the query.
53
+ cross_attention_dim (`int`, *optional*):
54
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
55
+ heads (`int`, *optional*, defaults to 8):
56
+ The number of heads to use for multi-head attention.
57
+ dim_head (`int`, *optional*, defaults to 64):
58
+ The number of channels in each head.
59
+ dropout (`float`, *optional*, defaults to 0.0):
60
+ The dropout probability to use.
61
+ bias (`bool`, *optional*, defaults to False):
62
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
63
+ upcast_attention (`bool`, *optional*, defaults to False):
64
+ Set to `True` to upcast the attention computation to `float32`.
65
+ upcast_softmax (`bool`, *optional*, defaults to False):
66
+ Set to `True` to upcast the softmax computation to `float32`.
67
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
68
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
69
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
70
+ The number of groups to use for the group norm in the cross attention.
71
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
72
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
73
+ norm_num_groups (`int`, *optional*, defaults to `None`):
74
+ The number of groups to use for the group norm in the attention.
75
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
76
+ The number of channels to use for the spatial normalization.
77
+ out_bias (`bool`, *optional*, defaults to `True`):
78
+ Set to `True` to use a bias in the output linear layer.
79
+ scale_qk (`bool`, *optional*, defaults to `True`):
80
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
81
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
82
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
83
+ `added_kv_proj_dim` is not `None`.
84
+ eps (`float`, *optional*, defaults to 1e-5):
85
+ An additional value added to the denominator in group normalization that is used for numerical stability.
86
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
87
+ A factor to rescale the output by dividing it with this value.
88
+ residual_connection (`bool`, *optional*, defaults to `False`):
89
+ Set to `True` to add the residual connection to the output.
90
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
91
+ Set to `True` if the attention block is loaded from a deprecated state dict.
92
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
93
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
94
+ `AttnProcessor` otherwise.
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ query_dim: int,
100
+ cross_attention_dim: Optional[int] = None,
101
+ heads: int = 8,
102
+ dim_head: int = 64,
103
+ dropout: float = 0.0,
104
+ bias: bool = False,
105
+ upcast_attention: bool = False,
106
+ upcast_softmax: bool = False,
107
+ cross_attention_norm: Optional[str] = None,
108
+ cross_attention_norm_num_groups: int = 32,
109
+ added_kv_proj_dim: Optional[int] = None,
110
+ norm_num_groups: Optional[int] = None,
111
+ out_bias: bool = True,
112
+ scale_qk: bool = True,
113
+ only_cross_attention: bool = False,
114
+ eps: float = 1e-5,
115
+ rescale_output_factor: float = 1.0,
116
+ residual_connection: bool = False,
117
+ _from_deprecated_attn_block: bool = False,
118
+ processor: Optional["AttnProcessor"] = None,
119
+ out_dim: int = None,
120
+ ):
121
+ super().__init__()
122
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
123
+ self.query_dim = query_dim
124
+ self.cross_attention_dim = (
125
+ cross_attention_dim if cross_attention_dim is not None else query_dim
126
+ )
127
+ self.upcast_attention = upcast_attention
128
+ self.upcast_softmax = upcast_softmax
129
+ self.rescale_output_factor = rescale_output_factor
130
+ self.residual_connection = residual_connection
131
+ self.dropout = dropout
132
+ self.fused_projections = False
133
+ self.out_dim = out_dim if out_dim is not None else query_dim
134
+
135
+ # we make use of this private variable to know whether this class is loaded
136
+ # with an deprecated state dict so that we can convert it on the fly
137
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
138
+
139
+ self.scale_qk = scale_qk
140
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
141
+
142
+ self.heads = out_dim // dim_head if out_dim is not None else heads
143
+ # for slice_size > 0 the attention score computation
144
+ # is split across the batch axis to save memory
145
+ # You can set slice_size with `set_attention_slice`
146
+ self.sliceable_head_dim = heads
147
+
148
+ self.added_kv_proj_dim = added_kv_proj_dim
149
+ self.only_cross_attention = only_cross_attention
150
+
151
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
152
+ raise ValueError(
153
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
154
+ )
155
+
156
+ if norm_num_groups is not None:
157
+ self.group_norm = nn.GroupNorm(
158
+ num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
159
+ )
160
+ else:
161
+ self.group_norm = None
162
+
163
+ self.spatial_norm = None
164
+
165
+ if cross_attention_norm is None:
166
+ self.norm_cross = None
167
+ elif cross_attention_norm == "layer_norm":
168
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
169
+ elif cross_attention_norm == "group_norm":
170
+ if self.added_kv_proj_dim is not None:
171
+ # The given `encoder_hidden_states` are initially of shape
172
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
173
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
174
+ # before the projection, so we need to use `added_kv_proj_dim` as
175
+ # the number of channels for the group norm.
176
+ norm_cross_num_channels = added_kv_proj_dim
177
+ else:
178
+ norm_cross_num_channels = self.cross_attention_dim
179
+
180
+ self.norm_cross = nn.GroupNorm(
181
+ num_channels=norm_cross_num_channels,
182
+ num_groups=cross_attention_norm_num_groups,
183
+ eps=1e-5,
184
+ affine=True,
185
+ )
186
+ else:
187
+ raise ValueError(
188
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
189
+ )
190
+
191
+ linear_cls = nn.Linear
192
+
193
+ self.linear_cls = linear_cls
194
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
195
+
196
+ if not self.only_cross_attention:
197
+ # only relevant for the `AddedKVProcessor` classes
198
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
199
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
200
+ else:
201
+ self.to_k = None
202
+ self.to_v = None
203
+
204
+ if self.added_kv_proj_dim is not None:
205
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
206
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
207
+
208
+ self.to_out = nn.ModuleList([])
209
+ self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
210
+ self.to_out.append(nn.Dropout(dropout))
211
+
212
+ # set attention processor
213
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
214
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
215
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
216
+ if processor is None:
217
+ processor = (
218
+ AttnProcessor2_0()
219
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
220
+ else AttnProcessor()
221
+ )
222
+ self.set_processor(processor)
223
+
224
+ def set_processor(self, processor: "AttnProcessor") -> None:
225
+ self.processor = processor
226
+
227
+ def forward(
228
+ self,
229
+ hidden_states: torch.FloatTensor,
230
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
231
+ attention_mask: Optional[torch.FloatTensor] = None,
232
+ **cross_attention_kwargs,
233
+ ) -> torch.Tensor:
234
+ r"""
235
+ The forward method of the `Attention` class.
236
+
237
+ Args:
238
+ hidden_states (`torch.Tensor`):
239
+ The hidden states of the query.
240
+ encoder_hidden_states (`torch.Tensor`, *optional*):
241
+ The hidden states of the encoder.
242
+ attention_mask (`torch.Tensor`, *optional*):
243
+ The attention mask to use. If `None`, no mask is applied.
244
+ **cross_attention_kwargs:
245
+ Additional keyword arguments to pass along to the cross attention.
246
+
247
+ Returns:
248
+ `torch.Tensor`: The output of the attention layer.
249
+ """
250
+ # The `Attention` class can call different attention processors / attention functions
251
+ # here we simply pass along all tensors to the selected processor class
252
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
253
+ return self.processor(
254
+ self,
255
+ hidden_states,
256
+ encoder_hidden_states=encoder_hidden_states,
257
+ attention_mask=attention_mask,
258
+ **cross_attention_kwargs,
259
+ )
260
+
261
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
262
+ r"""
263
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
264
+ is the number of heads initialized while constructing the `Attention` class.
265
+
266
+ Args:
267
+ tensor (`torch.Tensor`): The tensor to reshape.
268
+
269
+ Returns:
270
+ `torch.Tensor`: The reshaped tensor.
271
+ """
272
+ head_size = self.heads
273
+ batch_size, seq_len, dim = tensor.shape
274
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
275
+ tensor = tensor.permute(0, 2, 1, 3).reshape(
276
+ batch_size // head_size, seq_len, dim * head_size
277
+ )
278
+ return tensor
279
+
280
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
281
+ r"""
282
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
283
+ the number of heads initialized while constructing the `Attention` class.
284
+
285
+ Args:
286
+ tensor (`torch.Tensor`): The tensor to reshape.
287
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
288
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
289
+
290
+ Returns:
291
+ `torch.Tensor`: The reshaped tensor.
292
+ """
293
+ head_size = self.heads
294
+ batch_size, seq_len, dim = tensor.shape
295
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
296
+ tensor = tensor.permute(0, 2, 1, 3)
297
+
298
+ if out_dim == 3:
299
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
300
+
301
+ return tensor
302
+
303
+ def get_attention_scores(
304
+ self,
305
+ query: torch.Tensor,
306
+ key: torch.Tensor,
307
+ attention_mask: torch.Tensor = None,
308
+ ) -> torch.Tensor:
309
+ r"""
310
+ Compute the attention scores.
311
+
312
+ Args:
313
+ query (`torch.Tensor`): The query tensor.
314
+ key (`torch.Tensor`): The key tensor.
315
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
316
+
317
+ Returns:
318
+ `torch.Tensor`: The attention probabilities/scores.
319
+ """
320
+ dtype = query.dtype
321
+ if self.upcast_attention:
322
+ query = query.float()
323
+ key = key.float()
324
+
325
+ if attention_mask is None:
326
+ baddbmm_input = torch.empty(
327
+ query.shape[0],
328
+ query.shape[1],
329
+ key.shape[1],
330
+ dtype=query.dtype,
331
+ device=query.device,
332
+ )
333
+ beta = 0
334
+ else:
335
+ baddbmm_input = attention_mask
336
+ beta = 1
337
+
338
+ attention_scores = torch.baddbmm(
339
+ baddbmm_input,
340
+ query,
341
+ key.transpose(-1, -2),
342
+ beta=beta,
343
+ alpha=self.scale,
344
+ )
345
+ del baddbmm_input
346
+
347
+ if self.upcast_softmax:
348
+ attention_scores = attention_scores.float()
349
+
350
+ attention_probs = attention_scores.softmax(dim=-1)
351
+ del attention_scores
352
+
353
+ attention_probs = attention_probs.to(dtype)
354
+
355
+ return attention_probs
356
+
357
+ def prepare_attention_mask(
358
+ self,
359
+ attention_mask: torch.Tensor,
360
+ target_length: int,
361
+ batch_size: int,
362
+ out_dim: int = 3,
363
+ ) -> torch.Tensor:
364
+ r"""
365
+ Prepare the attention mask for the attention computation.
366
+
367
+ Args:
368
+ attention_mask (`torch.Tensor`):
369
+ The attention mask to prepare.
370
+ target_length (`int`):
371
+ The target length of the attention mask. This is the length of the attention mask after padding.
372
+ batch_size (`int`):
373
+ The batch size, which is used to repeat the attention mask.
374
+ out_dim (`int`, *optional*, defaults to `3`):
375
+ The output dimension of the attention mask. Can be either `3` or `4`.
376
+
377
+ Returns:
378
+ `torch.Tensor`: The prepared attention mask.
379
+ """
380
+ head_size = self.heads
381
+ if attention_mask is None:
382
+ return attention_mask
383
+
384
+ current_length: int = attention_mask.shape[-1]
385
+ if current_length != target_length:
386
+ if attention_mask.device.type == "mps":
387
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
388
+ # Instead, we can manually construct the padding tensor.
389
+ padding_shape = (
390
+ attention_mask.shape[0],
391
+ attention_mask.shape[1],
392
+ target_length,
393
+ )
394
+ padding = torch.zeros(
395
+ padding_shape,
396
+ dtype=attention_mask.dtype,
397
+ device=attention_mask.device,
398
+ )
399
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
400
+ else:
401
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
402
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
403
+ # remaining_length: int = target_length - current_length
404
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
405
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
406
+
407
+ if out_dim == 3:
408
+ if attention_mask.shape[0] < batch_size * head_size:
409
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
410
+ elif out_dim == 4:
411
+ attention_mask = attention_mask.unsqueeze(1)
412
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
413
+
414
+ return attention_mask
415
+
416
+ def norm_encoder_hidden_states(
417
+ self, encoder_hidden_states: torch.Tensor
418
+ ) -> torch.Tensor:
419
+ r"""
420
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
421
+ `Attention` class.
422
+
423
+ Args:
424
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
425
+
426
+ Returns:
427
+ `torch.Tensor`: The normalized encoder hidden states.
428
+ """
429
+ assert (
430
+ self.norm_cross is not None
431
+ ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
432
+
433
+ if isinstance(self.norm_cross, nn.LayerNorm):
434
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
435
+ elif isinstance(self.norm_cross, nn.GroupNorm):
436
+ # Group norm norms along the channels dimension and expects
437
+ # input to be in the shape of (N, C, *). In this case, we want
438
+ # to norm along the hidden dimension, so we need to move
439
+ # (batch_size, sequence_length, hidden_size) ->
440
+ # (batch_size, hidden_size, sequence_length)
441
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
442
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
443
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
444
+ else:
445
+ assert False
446
+
447
+ return encoder_hidden_states
448
+
449
+ @torch.no_grad()
450
+ def fuse_projections(self, fuse=True):
451
+ is_cross_attention = self.cross_attention_dim != self.query_dim
452
+ device = self.to_q.weight.data.device
453
+ dtype = self.to_q.weight.data.dtype
454
+
455
+ if not is_cross_attention:
456
+ # fetch weight matrices.
457
+ concatenated_weights = torch.cat(
458
+ [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
459
+ )
460
+ in_features = concatenated_weights.shape[1]
461
+ out_features = concatenated_weights.shape[0]
462
+
463
+ # create a new single projection layer and copy over the weights.
464
+ self.to_qkv = self.linear_cls(
465
+ in_features, out_features, bias=False, device=device, dtype=dtype
466
+ )
467
+ self.to_qkv.weight.copy_(concatenated_weights)
468
+
469
+ else:
470
+ concatenated_weights = torch.cat(
471
+ [self.to_k.weight.data, self.to_v.weight.data]
472
+ )
473
+ in_features = concatenated_weights.shape[1]
474
+ out_features = concatenated_weights.shape[0]
475
+
476
+ self.to_kv = self.linear_cls(
477
+ in_features, out_features, bias=False, device=device, dtype=dtype
478
+ )
479
+ self.to_kv.weight.copy_(concatenated_weights)
480
+
481
+ self.fused_projections = fuse
482
+
483
+
484
+ class AttnProcessor:
485
+ r"""
486
+ Default processor for performing attention-related computations.
487
+ """
488
+
489
+ def __call__(
490
+ self,
491
+ attn: Attention,
492
+ hidden_states: torch.FloatTensor,
493
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
494
+ attention_mask: Optional[torch.FloatTensor] = None,
495
+ ) -> torch.Tensor:
496
+ residual = hidden_states
497
+
498
+ input_ndim = hidden_states.ndim
499
+
500
+ if input_ndim == 4:
501
+ batch_size, channel, height, width = hidden_states.shape
502
+ hidden_states = hidden_states.view(
503
+ batch_size, channel, height * width
504
+ ).transpose(1, 2)
505
+
506
+ batch_size, sequence_length, _ = (
507
+ hidden_states.shape
508
+ if encoder_hidden_states is None
509
+ else encoder_hidden_states.shape
510
+ )
511
+ attention_mask = attn.prepare_attention_mask(
512
+ attention_mask, sequence_length, batch_size
513
+ )
514
+
515
+ if attn.group_norm is not None:
516
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
517
+ 1, 2
518
+ )
519
+
520
+ query = attn.to_q(hidden_states)
521
+
522
+ if encoder_hidden_states is None:
523
+ encoder_hidden_states = hidden_states
524
+ elif attn.norm_cross:
525
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
526
+ encoder_hidden_states
527
+ )
528
+
529
+ key = attn.to_k(encoder_hidden_states)
530
+ value = attn.to_v(encoder_hidden_states)
531
+
532
+ query = attn.head_to_batch_dim(query)
533
+ key = attn.head_to_batch_dim(key)
534
+ value = attn.head_to_batch_dim(value)
535
+
536
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
537
+ hidden_states = torch.bmm(attention_probs, value)
538
+ hidden_states = attn.batch_to_head_dim(hidden_states)
539
+
540
+ # linear proj
541
+ hidden_states = attn.to_out[0](hidden_states)
542
+ # dropout
543
+ hidden_states = attn.to_out[1](hidden_states)
544
+
545
+ if input_ndim == 4:
546
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
547
+ batch_size, channel, height, width
548
+ )
549
+
550
+ if attn.residual_connection:
551
+ hidden_states = hidden_states + residual
552
+
553
+ hidden_states = hidden_states / attn.rescale_output_factor
554
+
555
+ return hidden_states
556
+
557
+
558
+ class AttnProcessor2_0:
559
+ r"""
560
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
561
+ """
562
+
563
+ def __init__(self):
564
+ if not hasattr(F, "scaled_dot_product_attention"):
565
+ raise ImportError(
566
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
567
+ )
568
+
569
+ def __call__(
570
+ self,
571
+ attn: Attention,
572
+ hidden_states: torch.FloatTensor,
573
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
574
+ attention_mask: Optional[torch.FloatTensor] = None,
575
+ ) -> torch.FloatTensor:
576
+ residual = hidden_states
577
+
578
+ input_ndim = hidden_states.ndim
579
+
580
+ if input_ndim == 4:
581
+ batch_size, channel, height, width = hidden_states.shape
582
+ hidden_states = hidden_states.view(
583
+ batch_size, channel, height * width
584
+ ).transpose(1, 2)
585
+
586
+ batch_size, sequence_length, _ = (
587
+ hidden_states.shape
588
+ if encoder_hidden_states is None
589
+ else encoder_hidden_states.shape
590
+ )
591
+
592
+ if attention_mask is not None:
593
+ attention_mask = attn.prepare_attention_mask(
594
+ attention_mask, sequence_length, batch_size
595
+ )
596
+ # scaled_dot_product_attention expects attention_mask shape to be
597
+ # (batch, heads, source_length, target_length)
598
+ attention_mask = attention_mask.view(
599
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
600
+ )
601
+
602
+ if attn.group_norm is not None:
603
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
604
+ 1, 2
605
+ )
606
+
607
+ query = attn.to_q(hidden_states)
608
+
609
+ if encoder_hidden_states is None:
610
+ encoder_hidden_states = hidden_states
611
+ elif attn.norm_cross:
612
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
613
+ encoder_hidden_states
614
+ )
615
+
616
+ key = attn.to_k(encoder_hidden_states)
617
+ value = attn.to_v(encoder_hidden_states)
618
+
619
+ inner_dim = key.shape[-1]
620
+ head_dim = inner_dim // attn.heads
621
+
622
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
623
+
624
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
625
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
626
+
627
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
628
+ # TODO: add support for attn.scale when we move to Torch 2.1
629
+ hidden_states = F.scaled_dot_product_attention(
630
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
631
+ )
632
+
633
+ hidden_states = hidden_states.transpose(1, 2).reshape(
634
+ batch_size, -1, attn.heads * head_dim
635
+ )
636
+ hidden_states = hidden_states.to(query.dtype)
637
+
638
+ # linear proj
639
+ hidden_states = attn.to_out[0](hidden_states)
640
+ # dropout
641
+ hidden_states = attn.to_out[1](hidden_states)
642
+
643
+ if input_ndim == 4:
644
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
645
+ batch_size, channel, height, width
646
+ )
647
+
648
+ if attn.residual_connection:
649
+ hidden_states = hidden_states + residual
650
+
651
+ hidden_states = hidden_states / attn.rescale_output_factor
652
+
653
+ return hidden_states
tsr/models/transformer/basic_transformer_block.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
+ from typing import Optional
40
+
41
+ import torch
42
+ import torch.nn.functional as F
43
+ from torch import nn
44
+
45
+ from .attention import Attention
46
+
47
+
48
+ class BasicTransformerBlock(nn.Module):
49
+ r"""
50
+ A basic Transformer block.
51
+
52
+ Parameters:
53
+ dim (`int`): The number of channels in the input and output.
54
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
55
+ attention_head_dim (`int`): The number of channels in each head.
56
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
57
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
58
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
59
+ attention_bias (:
60
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
61
+ only_cross_attention (`bool`, *optional*):
62
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
63
+ double_self_attention (`bool`, *optional*):
64
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
65
+ upcast_attention (`bool`, *optional*):
66
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
67
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
68
+ Whether to use learnable elementwise affine parameters for normalization.
69
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
70
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
71
+ final_dropout (`bool` *optional*, defaults to False):
72
+ Whether to apply a final dropout after the last feed-forward layer.
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ dim: int,
78
+ num_attention_heads: int,
79
+ attention_head_dim: int,
80
+ dropout=0.0,
81
+ cross_attention_dim: Optional[int] = None,
82
+ activation_fn: str = "geglu",
83
+ attention_bias: bool = False,
84
+ only_cross_attention: bool = False,
85
+ double_self_attention: bool = False,
86
+ upcast_attention: bool = False,
87
+ norm_elementwise_affine: bool = True,
88
+ norm_type: str = "layer_norm",
89
+ final_dropout: bool = False,
90
+ ):
91
+ super().__init__()
92
+ self.only_cross_attention = only_cross_attention
93
+
94
+ assert norm_type == "layer_norm"
95
+
96
+ # Define 3 blocks. Each block has its own normalization layer.
97
+ # 1. Self-Attn
98
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
99
+ self.attn1 = Attention(
100
+ query_dim=dim,
101
+ heads=num_attention_heads,
102
+ dim_head=attention_head_dim,
103
+ dropout=dropout,
104
+ bias=attention_bias,
105
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
106
+ upcast_attention=upcast_attention,
107
+ )
108
+
109
+ # 2. Cross-Attn
110
+ if cross_attention_dim is not None or double_self_attention:
111
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
112
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
113
+ # the second cross attention block.
114
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
115
+
116
+ self.attn2 = Attention(
117
+ query_dim=dim,
118
+ cross_attention_dim=(
119
+ cross_attention_dim if not double_self_attention else None
120
+ ),
121
+ heads=num_attention_heads,
122
+ dim_head=attention_head_dim,
123
+ dropout=dropout,
124
+ bias=attention_bias,
125
+ upcast_attention=upcast_attention,
126
+ ) # is self-attn if encoder_hidden_states is none
127
+ else:
128
+ self.norm2 = None
129
+ self.attn2 = None
130
+
131
+ # 3. Feed-forward
132
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
133
+ self.ff = FeedForward(
134
+ dim,
135
+ dropout=dropout,
136
+ activation_fn=activation_fn,
137
+ final_dropout=final_dropout,
138
+ )
139
+
140
+ # let chunk size default to None
141
+ self._chunk_size = None
142
+ self._chunk_dim = 0
143
+
144
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
145
+ # Sets chunk feed-forward
146
+ self._chunk_size = chunk_size
147
+ self._chunk_dim = dim
148
+
149
+ def forward(
150
+ self,
151
+ hidden_states: torch.FloatTensor,
152
+ attention_mask: Optional[torch.FloatTensor] = None,
153
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
154
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
155
+ ) -> torch.FloatTensor:
156
+ # Notice that normalization is always applied before the real computation in the following blocks.
157
+ # 0. Self-Attention
158
+ norm_hidden_states = self.norm1(hidden_states)
159
+
160
+ attn_output = self.attn1(
161
+ norm_hidden_states,
162
+ encoder_hidden_states=(
163
+ encoder_hidden_states if self.only_cross_attention else None
164
+ ),
165
+ attention_mask=attention_mask,
166
+ )
167
+
168
+ hidden_states = attn_output + hidden_states
169
+
170
+ # 3. Cross-Attention
171
+ if self.attn2 is not None:
172
+ norm_hidden_states = self.norm2(hidden_states)
173
+
174
+ attn_output = self.attn2(
175
+ norm_hidden_states,
176
+ encoder_hidden_states=encoder_hidden_states,
177
+ attention_mask=encoder_attention_mask,
178
+ )
179
+ hidden_states = attn_output + hidden_states
180
+
181
+ # 4. Feed-forward
182
+ norm_hidden_states = self.norm3(hidden_states)
183
+
184
+ if self._chunk_size is not None:
185
+ # "feed_forward_chunk_size" can be used to save memory
186
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
187
+ raise ValueError(
188
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
189
+ )
190
+
191
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
192
+ ff_output = torch.cat(
193
+ [
194
+ self.ff(hid_slice)
195
+ for hid_slice in norm_hidden_states.chunk(
196
+ num_chunks, dim=self._chunk_dim
197
+ )
198
+ ],
199
+ dim=self._chunk_dim,
200
+ )
201
+ else:
202
+ ff_output = self.ff(norm_hidden_states)
203
+
204
+ hidden_states = ff_output + hidden_states
205
+
206
+ return hidden_states
207
+
208
+
209
+ class FeedForward(nn.Module):
210
+ r"""
211
+ A feed-forward layer.
212
+
213
+ Parameters:
214
+ dim (`int`): The number of channels in the input.
215
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
216
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
217
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
218
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
219
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
220
+ """
221
+
222
+ def __init__(
223
+ self,
224
+ dim: int,
225
+ dim_out: Optional[int] = None,
226
+ mult: int = 4,
227
+ dropout: float = 0.0,
228
+ activation_fn: str = "geglu",
229
+ final_dropout: bool = False,
230
+ ):
231
+ super().__init__()
232
+ inner_dim = int(dim * mult)
233
+ dim_out = dim_out if dim_out is not None else dim
234
+ linear_cls = nn.Linear
235
+
236
+ if activation_fn == "gelu":
237
+ act_fn = GELU(dim, inner_dim)
238
+ if activation_fn == "gelu-approximate":
239
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
240
+ elif activation_fn == "geglu":
241
+ act_fn = GEGLU(dim, inner_dim)
242
+ elif activation_fn == "geglu-approximate":
243
+ act_fn = ApproximateGELU(dim, inner_dim)
244
+
245
+ self.net = nn.ModuleList([])
246
+ # project in
247
+ self.net.append(act_fn)
248
+ # project dropout
249
+ self.net.append(nn.Dropout(dropout))
250
+ # project out
251
+ self.net.append(linear_cls(inner_dim, dim_out))
252
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
253
+ if final_dropout:
254
+ self.net.append(nn.Dropout(dropout))
255
+
256
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
257
+ for module in self.net:
258
+ hidden_states = module(hidden_states)
259
+ return hidden_states
260
+
261
+
262
+ class GELU(nn.Module):
263
+ r"""
264
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
265
+
266
+ Parameters:
267
+ dim_in (`int`): The number of channels in the input.
268
+ dim_out (`int`): The number of channels in the output.
269
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
270
+ """
271
+
272
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
273
+ super().__init__()
274
+ self.proj = nn.Linear(dim_in, dim_out)
275
+ self.approximate = approximate
276
+
277
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
278
+ if gate.device.type != "mps":
279
+ return F.gelu(gate, approximate=self.approximate)
280
+ # mps: gelu is not implemented for float16
281
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
282
+ dtype=gate.dtype
283
+ )
284
+
285
+ def forward(self, hidden_states):
286
+ hidden_states = self.proj(hidden_states)
287
+ hidden_states = self.gelu(hidden_states)
288
+ return hidden_states
289
+
290
+
291
+ class GEGLU(nn.Module):
292
+ r"""
293
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
294
+
295
+ Parameters:
296
+ dim_in (`int`): The number of channels in the input.
297
+ dim_out (`int`): The number of channels in the output.
298
+ """
299
+
300
+ def __init__(self, dim_in: int, dim_out: int):
301
+ super().__init__()
302
+ linear_cls = nn.Linear
303
+
304
+ self.proj = linear_cls(dim_in, dim_out * 2)
305
+
306
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
307
+ if gate.device.type != "mps":
308
+ return F.gelu(gate)
309
+ # mps: gelu is not implemented for float16
310
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
311
+
312
+ def forward(self, hidden_states, scale: float = 1.0):
313
+ args = ()
314
+ hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
315
+ return hidden_states * self.gelu(gate)
316
+
317
+
318
+ class ApproximateGELU(nn.Module):
319
+ r"""
320
+ The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
321
+ https://arxiv.org/abs/1606.08415.
322
+
323
+ Parameters:
324
+ dim_in (`int`): The number of channels in the input.
325
+ dim_out (`int`): The number of channels in the output.
326
+ """
327
+
328
+ def __init__(self, dim_in: int, dim_out: int):
329
+ super().__init__()
330
+ self.proj = nn.Linear(dim_in, dim_out)
331
+
332
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
333
+ x = self.proj(x)
334
+ return x * torch.sigmoid(1.702 * x)
tsr/models/transformer/transformer_1d.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
+ from dataclasses import dataclass
40
+ from typing import Optional
41
+
42
+ import torch
43
+ import torch.nn.functional as F
44
+ from torch import nn
45
+
46
+ from ...utils import BaseModule
47
+ from .basic_transformer_block import BasicTransformerBlock
48
+
49
+
50
+ class Transformer1D(BaseModule):
51
+ @dataclass
52
+ class Config(BaseModule.Config):
53
+ num_attention_heads: int = 16
54
+ attention_head_dim: int = 88
55
+ in_channels: Optional[int] = None
56
+ out_channels: Optional[int] = None
57
+ num_layers: int = 1
58
+ dropout: float = 0.0
59
+ norm_num_groups: int = 32
60
+ cross_attention_dim: Optional[int] = None
61
+ attention_bias: bool = False
62
+ activation_fn: str = "geglu"
63
+ only_cross_attention: bool = False
64
+ double_self_attention: bool = False
65
+ upcast_attention: bool = False
66
+ norm_type: str = "layer_norm"
67
+ norm_elementwise_affine: bool = True
68
+ gradient_checkpointing: bool = False
69
+
70
+ cfg: Config
71
+
72
+ def configure(self) -> None:
73
+ self.num_attention_heads = self.cfg.num_attention_heads
74
+ self.attention_head_dim = self.cfg.attention_head_dim
75
+ inner_dim = self.num_attention_heads * self.attention_head_dim
76
+
77
+ linear_cls = nn.Linear
78
+
79
+ # 2. Define input layers
80
+ self.in_channels = self.cfg.in_channels
81
+
82
+ self.norm = torch.nn.GroupNorm(
83
+ num_groups=self.cfg.norm_num_groups,
84
+ num_channels=self.cfg.in_channels,
85
+ eps=1e-6,
86
+ affine=True,
87
+ )
88
+ self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)
89
+
90
+ # 3. Define transformers blocks
91
+ self.transformer_blocks = nn.ModuleList(
92
+ [
93
+ BasicTransformerBlock(
94
+ inner_dim,
95
+ self.num_attention_heads,
96
+ self.attention_head_dim,
97
+ dropout=self.cfg.dropout,
98
+ cross_attention_dim=self.cfg.cross_attention_dim,
99
+ activation_fn=self.cfg.activation_fn,
100
+ attention_bias=self.cfg.attention_bias,
101
+ only_cross_attention=self.cfg.only_cross_attention,
102
+ double_self_attention=self.cfg.double_self_attention,
103
+ upcast_attention=self.cfg.upcast_attention,
104
+ norm_type=self.cfg.norm_type,
105
+ norm_elementwise_affine=self.cfg.norm_elementwise_affine,
106
+ )
107
+ for d in range(self.cfg.num_layers)
108
+ ]
109
+ )
110
+
111
+ # 4. Define output layers
112
+ self.out_channels = (
113
+ self.cfg.in_channels
114
+ if self.cfg.out_channels is None
115
+ else self.cfg.out_channels
116
+ )
117
+
118
+ self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)
119
+
120
+ self.gradient_checkpointing = self.cfg.gradient_checkpointing
121
+
122
+ def forward(
123
+ self,
124
+ hidden_states: torch.Tensor,
125
+ encoder_hidden_states: Optional[torch.Tensor] = None,
126
+ attention_mask: Optional[torch.Tensor] = None,
127
+ encoder_attention_mask: Optional[torch.Tensor] = None,
128
+ ):
129
+ """
130
+ The [`Transformer1DModel`] forward method.
131
+
132
+ Args:
133
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
134
+ Input `hidden_states`.
135
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
136
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
137
+ self-attention.
138
+ attention_mask ( `torch.Tensor`, *optional*):
139
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
140
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
141
+ negative values to the attention scores corresponding to "discard" tokens.
142
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
143
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
144
+
145
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
146
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
147
+
148
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
149
+ above. This bias will be added to the cross-attention scores.
150
+
151
+ Returns:
152
+ torch.FloatTensor
153
+ """
154
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
155
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
156
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
157
+ # expects mask of shape:
158
+ # [batch, key_tokens]
159
+ # adds singleton query_tokens dimension:
160
+ # [batch, 1, key_tokens]
161
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
162
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
163
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
164
+ if attention_mask is not None and attention_mask.ndim == 2:
165
+ # assume that mask is expressed as:
166
+ # (1 = keep, 0 = discard)
167
+ # convert mask into a bias that can be added to attention scores:
168
+ # (keep = +0, discard = -10000.0)
169
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
170
+ attention_mask = attention_mask.unsqueeze(1)
171
+
172
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
173
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
174
+ encoder_attention_mask = (
175
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
176
+ ) * -10000.0
177
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
178
+
179
+ # 1. Input
180
+ batch, _, seq_len = hidden_states.shape
181
+ residual = hidden_states
182
+
183
+ hidden_states = self.norm(hidden_states)
184
+ inner_dim = hidden_states.shape[1]
185
+ hidden_states = hidden_states.permute(0, 2, 1).reshape(
186
+ batch, seq_len, inner_dim
187
+ )
188
+ hidden_states = self.proj_in(hidden_states)
189
+
190
+ # 2. Blocks
191
+ for block in self.transformer_blocks:
192
+ if self.training and self.gradient_checkpointing:
193
+ hidden_states = torch.utils.checkpoint.checkpoint(
194
+ block,
195
+ hidden_states,
196
+ attention_mask,
197
+ encoder_hidden_states,
198
+ encoder_attention_mask,
199
+ use_reentrant=False,
200
+ )
201
+ else:
202
+ hidden_states = block(
203
+ hidden_states,
204
+ attention_mask=attention_mask,
205
+ encoder_hidden_states=encoder_hidden_states,
206
+ encoder_attention_mask=encoder_attention_mask,
207
+ )
208
+
209
+ # 3. Output
210
+ hidden_states = self.proj_out(hidden_states)
211
+ hidden_states = (
212
+ hidden_states.reshape(batch, seq_len, inner_dim)
213
+ .permute(0, 2, 1)
214
+ .contiguous()
215
+ )
216
+
217
+ output = hidden_states + residual
218
+
219
+ return output
tsr/system.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from dataclasses import dataclass, field
4
+ from typing import List, Union
5
+
6
+ import numpy as np
7
+ import PIL.Image
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import trimesh
11
+ from einops import rearrange
12
+ from huggingface_hub import hf_hub_download
13
+ from omegaconf import OmegaConf
14
+ from PIL import Image
15
+
16
+ from .models.isosurface import MarchingCubeHelper
17
+ from .utils import (
18
+ BaseModule,
19
+ ImagePreprocessor,
20
+ find_class,
21
+ get_spherical_cameras,
22
+ scale_tensor,
23
+ )
24
+
25
+
26
+ class TSR(BaseModule):
27
+ @dataclass
28
+ class Config(BaseModule.Config):
29
+ cond_image_size: int
30
+
31
+ image_tokenizer_cls: str
32
+ image_tokenizer: dict
33
+
34
+ tokenizer_cls: str
35
+ tokenizer: dict
36
+
37
+ backbone_cls: str
38
+ backbone: dict
39
+
40
+ post_processor_cls: str
41
+ post_processor: dict
42
+
43
+ decoder_cls: str
44
+ decoder: dict
45
+
46
+ renderer_cls: str
47
+ renderer: dict
48
+
49
+ cfg: Config
50
+
51
+ @classmethod
52
+ def from_pretrained(
53
+ cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
54
+ ):
55
+ if os.path.isdir(pretrained_model_name_or_path):
56
+ config_path = os.path.join(pretrained_model_name_or_path, config_name)
57
+ weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
58
+ else:
59
+ config_path = hf_hub_download(
60
+ repo_id=pretrained_model_name_or_path, filename=config_name
61
+ )
62
+ weight_path = hf_hub_download(
63
+ repo_id=pretrained_model_name_or_path, filename=weight_name
64
+ )
65
+
66
+ cfg = OmegaConf.load(config_path)
67
+ OmegaConf.resolve(cfg)
68
+ model = cls(cfg)
69
+ ckpt = torch.load(weight_path, map_location="cpu")
70
+ model.load_state_dict(ckpt)
71
+ return model
72
+
73
+ def configure(self):
74
+ self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
75
+ self.cfg.image_tokenizer
76
+ )
77
+ self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
78
+ self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
79
+ self.post_processor = find_class(self.cfg.post_processor_cls)(
80
+ self.cfg.post_processor
81
+ )
82
+ self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
83
+ self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
84
+ self.image_processor = ImagePreprocessor()
85
+ self.isosurface_helper = None
86
+
87
+ def forward(
88
+ self,
89
+ image: Union[
90
+ PIL.Image.Image,
91
+ np.ndarray,
92
+ torch.FloatTensor,
93
+ List[PIL.Image.Image],
94
+ List[np.ndarray],
95
+ List[torch.FloatTensor],
96
+ ],
97
+ device: str,
98
+ ) -> torch.FloatTensor:
99
+ rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
100
+ device
101
+ )
102
+ batch_size = rgb_cond.shape[0]
103
+
104
+ input_image_tokens: torch.Tensor = self.image_tokenizer(
105
+ rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
106
+ )
107
+
108
+ input_image_tokens = rearrange(
109
+ input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
110
+ )
111
+
112
+ tokens: torch.Tensor = self.tokenizer(batch_size)
113
+
114
+ tokens = self.backbone(
115
+ tokens,
116
+ encoder_hidden_states=input_image_tokens,
117
+ )
118
+
119
+ scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
120
+ return scene_codes
121
+
122
+ def render(
123
+ self,
124
+ scene_codes,
125
+ n_views: int,
126
+ elevation_deg: float = 0.0,
127
+ camera_distance: float = 1.9,
128
+ fovy_deg: float = 40.0,
129
+ height: int = 256,
130
+ width: int = 256,
131
+ return_type: str = "pil",
132
+ ):
133
+ rays_o, rays_d = get_spherical_cameras(
134
+ n_views, elevation_deg, camera_distance, fovy_deg, height, width
135
+ )
136
+ rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)
137
+
138
+ def process_output(image: torch.FloatTensor):
139
+ if return_type == "pt":
140
+ return image
141
+ elif return_type == "np":
142
+ return image.detach().cpu().numpy()
143
+ elif return_type == "pil":
144
+ return Image.fromarray(
145
+ (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
146
+ )
147
+ else:
148
+ raise NotImplementedError
149
+
150
+ images = []
151
+ for scene_code in scene_codes:
152
+ images_ = []
153
+ for i in range(n_views):
154
+ with torch.no_grad():
155
+ image = self.renderer(
156
+ self.decoder, scene_code, rays_o[i], rays_d[i]
157
+ )
158
+ images_.append(process_output(image))
159
+ images.append(images_)
160
+
161
+ return images
162
+
163
+ def set_marching_cubes_resolution(self, resolution: int):
164
+ if (
165
+ self.isosurface_helper is not None
166
+ and self.isosurface_helper.resolution == resolution
167
+ ):
168
+ return
169
+ self.isosurface_helper = MarchingCubeHelper(resolution)
170
+
171
+ def extract_mesh(self, scene_codes, has_vertex_color, resolution: int = 256, threshold: float = 25.0):
172
+ self.set_marching_cubes_resolution(resolution)
173
+ meshes = []
174
+ for scene_code in scene_codes:
175
+ with torch.no_grad():
176
+ density = self.renderer.query_triplane(
177
+ self.decoder,
178
+ scale_tensor(
179
+ self.isosurface_helper.grid_vertices.to(scene_codes.device),
180
+ self.isosurface_helper.points_range,
181
+ (-self.renderer.cfg.radius, self.renderer.cfg.radius),
182
+ ),
183
+ scene_code,
184
+ )["density_act"]
185
+ v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
186
+ v_pos = scale_tensor(
187
+ v_pos,
188
+ self.isosurface_helper.points_range,
189
+ (-self.renderer.cfg.radius, self.renderer.cfg.radius),
190
+ )
191
+ color = None
192
+ if has_vertex_color:
193
+ with torch.no_grad():
194
+ color = self.renderer.query_triplane(
195
+ self.decoder,
196
+ v_pos,
197
+ scene_code,
198
+ )["color"]
199
+ mesh = trimesh.Trimesh(
200
+ vertices=v_pos.cpu().numpy(),
201
+ faces=t_pos_idx.cpu().numpy(),
202
+ vertex_colors=color.cpu().numpy() if has_vertex_color else None,
203
+ )
204
+ meshes.append(mesh)
205
+ return meshes
tsr/utils.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import math
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ import imageio
8
+ import numpy as np
9
+ import PIL.Image
10
+ # import rembg
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import trimesh
15
+ from omegaconf import DictConfig, OmegaConf
16
+ from PIL import Image
17
+
18
+
19
+ def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
20
+ scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
21
+ return scfg
22
+
23
+
24
+ def find_class(cls_string):
25
+ module_string = ".".join(cls_string.split(".")[:-1])
26
+ cls_name = cls_string.split(".")[-1]
27
+ module = importlib.import_module(module_string, package=None)
28
+ cls = getattr(module, cls_name)
29
+ return cls
30
+
31
+
32
+ def get_intrinsic_from_fov(fov, H, W, bs=-1):
33
+ focal_length = 0.5 * H / np.tan(0.5 * fov)
34
+ intrinsic = np.identity(3, dtype=np.float32)
35
+ intrinsic[0, 0] = focal_length
36
+ intrinsic[1, 1] = focal_length
37
+ intrinsic[0, 2] = W / 2.0
38
+ intrinsic[1, 2] = H / 2.0
39
+
40
+ if bs > 0:
41
+ intrinsic = intrinsic[None].repeat(bs, axis=0)
42
+
43
+ return torch.from_numpy(intrinsic)
44
+
45
+
46
+ class BaseModule(nn.Module):
47
+ @dataclass
48
+ class Config:
49
+ pass
50
+
51
+ cfg: Config # add this to every subclass of BaseModule to enable static type checking
52
+
53
+ def __init__(
54
+ self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
55
+ ) -> None:
56
+ super().__init__()
57
+ self.cfg = parse_structured(self.Config, cfg)
58
+ self.configure(*args, **kwargs)
59
+
60
+ def configure(self, *args, **kwargs) -> None:
61
+ raise NotImplementedError
62
+
63
+
64
+ class ImagePreprocessor:
65
+ def convert_and_resize(
66
+ self,
67
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
68
+ size: int,
69
+ ):
70
+ if isinstance(image, PIL.Image.Image):
71
+ image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
72
+ elif isinstance(image, np.ndarray):
73
+ if image.dtype == np.uint8:
74
+ image = torch.from_numpy(image.astype(np.float32) / 255.0)
75
+ else:
76
+ image = torch.from_numpy(image)
77
+ elif isinstance(image, torch.Tensor):
78
+ pass
79
+
80
+ batched = image.ndim == 4
81
+
82
+ if not batched:
83
+ image = image[None, ...]
84
+ image = F.interpolate(
85
+ image.permute(0, 3, 1, 2),
86
+ (size, size),
87
+ mode="bilinear",
88
+ align_corners=False,
89
+ antialias=True,
90
+ ).permute(0, 2, 3, 1)
91
+ if not batched:
92
+ image = image[0]
93
+ return image
94
+
95
+ def __call__(
96
+ self,
97
+ image: Union[
98
+ PIL.Image.Image,
99
+ np.ndarray,
100
+ torch.FloatTensor,
101
+ List[PIL.Image.Image],
102
+ List[np.ndarray],
103
+ List[torch.FloatTensor],
104
+ ],
105
+ size: int,
106
+ ) -> Any:
107
+ if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
108
+ image = self.convert_and_resize(image, size)
109
+ else:
110
+ if not isinstance(image, list):
111
+ image = [image]
112
+ image = [self.convert_and_resize(im, size) for im in image]
113
+ image = torch.stack(image, dim=0)
114
+ return image
115
+
116
+
117
+ def rays_intersect_bbox(
118
+ rays_o: torch.Tensor,
119
+ rays_d: torch.Tensor,
120
+ radius: float,
121
+ near: float = 0.0,
122
+ valid_thresh: float = 0.01,
123
+ ):
124
+ input_shape = rays_o.shape[:-1]
125
+ rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
126
+ rays_d_valid = torch.where(
127
+ rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d
128
+ )
129
+ if type(radius) in [int, float]:
130
+ radius = torch.FloatTensor(
131
+ [[-radius, radius], [-radius, radius], [-radius, radius]]
132
+ ).to(rays_o.device)
133
+ radius = (
134
+ 1.0 - 1.0e-3
135
+ ) * radius # tighten the radius to make sure the intersection point lies in the bounding box
136
+ interx0 = (radius[..., 1] - rays_o) / rays_d_valid
137
+ interx1 = (radius[..., 0] - rays_o) / rays_d_valid
138
+ t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near)
139
+ t_far = torch.maximum(interx0, interx1).amin(dim=-1)
140
+
141
+ # check wheter a ray intersects the bbox or not
142
+ rays_valid = t_far - t_near > valid_thresh
143
+
144
+ t_near[torch.where(~rays_valid)] = 0.0
145
+ t_far[torch.where(~rays_valid)] = 0.0
146
+
147
+ t_near = t_near.view(*input_shape, 1)
148
+ t_far = t_far.view(*input_shape, 1)
149
+ rays_valid = rays_valid.view(*input_shape)
150
+
151
+ return t_near, t_far, rays_valid
152
+
153
+
154
+ def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
155
+ if chunk_size <= 0:
156
+ return func(*args, **kwargs)
157
+ B = None
158
+ for arg in list(args) + list(kwargs.values()):
159
+ if isinstance(arg, torch.Tensor):
160
+ B = arg.shape[0]
161
+ break
162
+ assert (
163
+ B is not None
164
+ ), "No tensor found in args or kwargs, cannot determine batch size."
165
+ out = defaultdict(list)
166
+ out_type = None
167
+ # max(1, B) to support B == 0
168
+ for i in range(0, max(1, B), chunk_size):
169
+ out_chunk = func(
170
+ *[
171
+ arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
172
+ for arg in args
173
+ ],
174
+ **{
175
+ k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
176
+ for k, arg in kwargs.items()
177
+ },
178
+ )
179
+ if out_chunk is None:
180
+ continue
181
+ out_type = type(out_chunk)
182
+ if isinstance(out_chunk, torch.Tensor):
183
+ out_chunk = {0: out_chunk}
184
+ elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
185
+ chunk_length = len(out_chunk)
186
+ out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
187
+ elif isinstance(out_chunk, dict):
188
+ pass
189
+ else:
190
+ print(
191
+ f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
192
+ )
193
+ exit(1)
194
+ for k, v in out_chunk.items():
195
+ v = v if torch.is_grad_enabled() else v.detach()
196
+ out[k].append(v)
197
+
198
+ if out_type is None:
199
+ return None
200
+
201
+ out_merged: Dict[Any, Optional[torch.Tensor]] = {}
202
+ for k, v in out.items():
203
+ if all([vv is None for vv in v]):
204
+ # allow None in return value
205
+ out_merged[k] = None
206
+ elif all([isinstance(vv, torch.Tensor) for vv in v]):
207
+ out_merged[k] = torch.cat(v, dim=0)
208
+ else:
209
+ raise TypeError(
210
+ f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
211
+ )
212
+
213
+ if out_type is torch.Tensor:
214
+ return out_merged[0]
215
+ elif out_type in [tuple, list]:
216
+ return out_type([out_merged[i] for i in range(chunk_length)])
217
+ elif out_type is dict:
218
+ return out_merged
219
+
220
+
221
+ ValidScale = Union[Tuple[float, float], torch.FloatTensor]
222
+
223
+
224
+ def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale):
225
+ if inp_scale is None:
226
+ inp_scale = (0, 1)
227
+ if tgt_scale is None:
228
+ tgt_scale = (0, 1)
229
+ if isinstance(tgt_scale, torch.FloatTensor):
230
+ assert dat.shape[-1] == tgt_scale.shape[-1]
231
+ dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
232
+ dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
233
+ return dat
234
+
235
+
236
+ def get_activation(name) -> Callable:
237
+ if name is None:
238
+ return lambda x: x
239
+ name = name.lower()
240
+ if name == "none":
241
+ return lambda x: x
242
+ elif name == "exp":
243
+ return lambda x: torch.exp(x)
244
+ elif name == "sigmoid":
245
+ return lambda x: torch.sigmoid(x)
246
+ elif name == "tanh":
247
+ return lambda x: torch.tanh(x)
248
+ elif name == "softplus":
249
+ return lambda x: F.softplus(x)
250
+ else:
251
+ try:
252
+ return getattr(F, name)
253
+ except AttributeError:
254
+ raise ValueError(f"Unknown activation function: {name}")
255
+
256
+
257
+ def get_ray_directions(
258
+ H: int,
259
+ W: int,
260
+ focal: Union[float, Tuple[float, float]],
261
+ principal: Optional[Tuple[float, float]] = None,
262
+ use_pixel_centers: bool = True,
263
+ normalize: bool = True,
264
+ ) -> torch.FloatTensor:
265
+ """
266
+ Get ray directions for all pixels in camera coordinate.
267
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
268
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
269
+
270
+ Inputs:
271
+ H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers
272
+ Outputs:
273
+ directions: (H, W, 3), the direction of the rays in camera coordinate
274
+ """
275
+ pixel_center = 0.5 if use_pixel_centers else 0
276
+
277
+ if isinstance(focal, float):
278
+ fx, fy = focal, focal
279
+ cx, cy = W / 2, H / 2
280
+ else:
281
+ fx, fy = focal
282
+ assert principal is not None
283
+ cx, cy = principal
284
+
285
+ i, j = torch.meshgrid(
286
+ torch.arange(W, dtype=torch.float32) + pixel_center,
287
+ torch.arange(H, dtype=torch.float32) + pixel_center,
288
+ indexing="xy",
289
+ )
290
+
291
+ directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1)
292
+
293
+ if normalize:
294
+ directions = F.normalize(directions, dim=-1)
295
+
296
+ return directions
297
+
298
+
299
+ def get_rays(
300
+ directions,
301
+ c2w,
302
+ keepdim=False,
303
+ normalize=False,
304
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
305
+ # Rotate ray directions from camera coordinate to the world coordinate
306
+ assert directions.shape[-1] == 3
307
+
308
+ if directions.ndim == 2: # (N_rays, 3)
309
+ if c2w.ndim == 2: # (4, 4)
310
+ c2w = c2w[None, :, :]
311
+ assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4)
312
+ rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3)
313
+ rays_o = c2w[:, :3, 3].expand(rays_d.shape)
314
+ elif directions.ndim == 3: # (H, W, 3)
315
+ assert c2w.ndim in [2, 3]
316
+ if c2w.ndim == 2: # (4, 4)
317
+ rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
318
+ -1
319
+ ) # (H, W, 3)
320
+ rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
321
+ elif c2w.ndim == 3: # (B, 4, 4)
322
+ rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
323
+ -1
324
+ ) # (B, H, W, 3)
325
+ rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
326
+ elif directions.ndim == 4: # (B, H, W, 3)
327
+ assert c2w.ndim == 3 # (B, 4, 4)
328
+ rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
329
+ -1
330
+ ) # (B, H, W, 3)
331
+ rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
332
+
333
+ if normalize:
334
+ rays_d = F.normalize(rays_d, dim=-1)
335
+ if not keepdim:
336
+ rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
337
+
338
+ return rays_o, rays_d
339
+
340
+
341
+ def get_spherical_cameras(
342
+ n_views: int,
343
+ elevation_deg: float,
344
+ camera_distance: float,
345
+ fovy_deg: float,
346
+ height: int,
347
+ width: int,
348
+ ):
349
+ azimuth_deg = torch.linspace(0, 360.0, n_views + 1)[:n_views]
350
+ elevation_deg = torch.full_like(azimuth_deg, elevation_deg)
351
+ camera_distances = torch.full_like(elevation_deg, camera_distance)
352
+
353
+ elevation = elevation_deg * math.pi / 180
354
+ azimuth = azimuth_deg * math.pi / 180
355
+
356
+ # convert spherical coordinates to cartesian coordinates
357
+ # right hand coordinate system, x back, y right, z up
358
+ # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
359
+ camera_positions = torch.stack(
360
+ [
361
+ camera_distances * torch.cos(elevation) * torch.cos(azimuth),
362
+ camera_distances * torch.cos(elevation) * torch.sin(azimuth),
363
+ camera_distances * torch.sin(elevation),
364
+ ],
365
+ dim=-1,
366
+ )
367
+
368
+ # default scene center at origin
369
+ center = torch.zeros_like(camera_positions)
370
+ # default camera up direction as +z
371
+ up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)
372
+
373
+ fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180
374
+
375
+ lookat = F.normalize(center - camera_positions, dim=-1)
376
+ right = F.normalize(torch.cross(lookat, up), dim=-1)
377
+ up = F.normalize(torch.cross(right, lookat), dim=-1)
378
+ c2w3x4 = torch.cat(
379
+ [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
380
+ dim=-1,
381
+ )
382
+ c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
383
+ c2w[:, 3, 3] = 1.0
384
+
385
+ # get directions by dividing directions_unit_focal by focal length
386
+ focal_length = 0.5 * height / torch.tan(0.5 * fovy)
387
+ directions_unit_focal = get_ray_directions(
388
+ H=height,
389
+ W=width,
390
+ focal=1.0,
391
+ )
392
+ directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1)
393
+ directions[:, :, :, :2] = (
394
+ directions[:, :, :, :2] / focal_length[:, None, None, None]
395
+ )
396
+ # must use normalize=True to normalize directions here
397
+ rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True)
398
+
399
+ return rays_o, rays_d
400
+
401
+
402
+ def remove_background(
403
+ image: PIL.Image.Image,
404
+ rembg_session: Any = None,
405
+ force: bool = False,
406
+ **rembg_kwargs,
407
+ ) -> PIL.Image.Image:
408
+ do_remove = True
409
+ if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
410
+ do_remove = False
411
+ do_remove = do_remove or force
412
+ if do_remove:
413
+ image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
414
+ return image
415
+
416
+
417
+ def resize_foreground(
418
+ image: PIL.Image.Image,
419
+ ratio: float,
420
+ ) -> PIL.Image.Image:
421
+ image = np.array(image)
422
+ assert image.shape[-1] == 4
423
+ alpha = np.where(image[..., 3] > 0)
424
+ y1, y2, x1, x2 = (
425
+ alpha[0].min(),
426
+ alpha[0].max(),
427
+ alpha[1].min(),
428
+ alpha[1].max(),
429
+ )
430
+ # crop the foreground
431
+ fg = image[y1:y2, x1:x2]
432
+ # pad to square
433
+ size = max(fg.shape[0], fg.shape[1])
434
+ ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
435
+ ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
436
+ new_image = np.pad(
437
+ fg,
438
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
439
+ mode="constant",
440
+ constant_values=((0, 0), (0, 0), (0, 0)),
441
+ )
442
+
443
+ # compute padding according to the ratio
444
+ new_size = int(new_image.shape[0] / ratio)
445
+ # pad to size, double side
446
+ ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
447
+ ph1, pw1 = new_size - size - ph0, new_size - size - pw0
448
+ new_image = np.pad(
449
+ new_image,
450
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
451
+ mode="constant",
452
+ constant_values=((0, 0), (0, 0), (0, 0)),
453
+ )
454
+ new_image = PIL.Image.fromarray(new_image)
455
+ return new_image
456
+
457
+
458
+ def save_video(
459
+ frames: List[PIL.Image.Image],
460
+ output_path: str,
461
+ fps: int = 30,
462
+ ):
463
+ # use imageio to save video
464
+ frames = [np.array(frame) for frame in frames]
465
+ writer = imageio.get_writer(output_path, fps=fps)
466
+ for frame in frames:
467
+ writer.append_data(frame)
468
+ writer.close()
469
+
470
+
471
+ def to_gradio_3d_orientation(mesh):
472
+ mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0]))
473
+ mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi/2, [0, 1, 0]))
474
+ return mesh
utils.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import time
4
+
5
+ import numpy as np
6
+ from PIL import Image, ImageOps
7
+ import numpy as np
8
+ import torch
9
+ import xatlas
10
+ from PIL import Image
11
+
12
+ from tsr.system import TSR
13
+ from tsr.utils import save_video
14
+ from tsr.bake_texture import bake_texture
15
+
16
+
17
+ class Timer:
18
+ def __init__(self):
19
+ self.items = {}
20
+ self.time_scale = 1000.0 # ms
21
+ self.time_unit = "ms"
22
+
23
+ def start(self, name: str) -> None:
24
+ if torch.cuda.is_available():
25
+ torch.cuda.synchronize()
26
+ self.items[name] = time.time()
27
+ logging.info(f"{name} ...")
28
+
29
+ def end(self, name: str) -> float:
30
+ if name not in self.items:
31
+ return
32
+ if torch.cuda.is_available():
33
+ torch.cuda.synchronize()
34
+ start_time = self.items.pop(name)
35
+ delta = time.time() - start_time
36
+ t = delta * self.time_scale
37
+ logging.info(f"{name} finished in {t:.2f}{self.time_unit}.")
38
+
39
+
40
+ def initialize_model(pretrained_model_name_or_path="stabilityai/TripoSR",
41
+ chunk_size=8192,
42
+ device="cuda:0" if torch.cuda.is_available() else "cpu"):
43
+ timer.start("Initializing model")
44
+ model = TSR.from_pretrained(
45
+ pretrained_model_name_or_path,
46
+ config_name="config.yaml",
47
+ weight_name="model.ckpt",
48
+ )
49
+ model.renderer.set_chunk_size(chunk_size)
50
+ model.to(device)
51
+ timer.end("Initializing model")
52
+ return model
53
+
54
+
55
+ def remove_background(image_path, output_path, background_value=127, new_size=(425, 425)):
56
+ # Open the image
57
+ image = Image.open(image_path).convert("RGBA")
58
+
59
+ # Split the image into its respective channels
60
+ r, g, b, alpha = image.split()
61
+
62
+ # Convert the alpha channel to binary mask where transparency is 0 and opaque is 255
63
+ alpha = ImageOps.invert(alpha)
64
+
65
+ # Replace the transparent areas with the specified background value
66
+ background = Image.new("L", image.size, color=background_value)
67
+ image_rgb = Image.composite(background, r, alpha), Image.composite(background, g, alpha), Image.composite(background, b, alpha)
68
+
69
+ # Merge the channels back into an image
70
+ image = Image.merge("RGB", image_rgb)
71
+
72
+ # Resize the image to the desired size
73
+ image = image.resize(new_size, Image.LANCZOS)
74
+
75
+ # Save the processed image
76
+ # image.save(output_path)
77
+
78
+ return image
79
+
80
+
81
+ def process_image(image_path, output_dir, no_remove_bg, foreground_ratio):
82
+ timer.start("Processing image")
83
+
84
+ if no_remove_bg:
85
+ rembg_session = None
86
+ image = np.array(Image.open(image_path).convert("RGB"))
87
+ else:
88
+ image = remove_background(image_path ,output_dir)
89
+
90
+ # Save the processed image
91
+ os.makedirs(output_dir, exist_ok=True)
92
+ image.save(os.path.join(output_dir, "processed_input.png"))
93
+
94
+ timer.end("Processing image")
95
+ return image
96
+
97
+
98
+ def run_model(model, image, output_dir, device, render, mc_resolution, model_save_format, bake_texture_flag, texture_resolution):
99
+ logging.info("Running model...")
100
+
101
+ timer.start("Running model")
102
+ with torch.no_grad():
103
+ scene_codes = model([image], device=device)
104
+ timer.end("Running model")
105
+
106
+ out_video_path = None
107
+ if render:
108
+ timer.start("Rendering")
109
+ render_images = model.render(scene_codes, n_views=30, return_type="pil")
110
+ for ri, render_image in enumerate(render_images[0]):
111
+ render_image.save(os.path.join(output_dir, f"render_{ri:03d}.png"))
112
+ out_video_path = os.path.join(output_dir, "render.mp4")
113
+ save_video(
114
+ render_images[0], out_video_path, fps=30
115
+ )
116
+ timer.end("Rendering")
117
+
118
+ timer.start("Extracting mesh")
119
+ meshes = model.extract_mesh(scene_codes, not bake_texture_flag, resolution=mc_resolution)
120
+ timer.end("Extracting mesh")
121
+
122
+ out_mesh_path = os.path.join(output_dir, f"mesh.{model_save_format}")
123
+ if bake_texture_flag:
124
+ out_texture_path = os.path.join(output_dir, "texture.png")
125
+
126
+ timer.start("Baking texture")
127
+ bake_output = bake_texture(meshes[0], model, scene_codes[0], texture_resolution)
128
+ timer.end("Baking texture")
129
+
130
+ timer.start("Exporting mesh and texture")
131
+ xatlas.export(out_mesh_path, meshes[0].vertices[bake_output["vmapping"]], bake_output["indices"], bake_output["uvs"], meshes[0].vertex_normals[bake_output["vmapping"]])
132
+ Image.fromarray((bake_output["colors"] * 255.0).astype(np.uint8)).transpose(Image.FLIP_TOP_BOTTOM).save(out_texture_path)
133
+ timer.end("Exporting mesh and texture")
134
+ else:
135
+ timer.start("Exporting mesh")
136
+ meshes[0].export(out_mesh_path)
137
+ timer.end("Exporting mesh")
138
+
139
+ return out_mesh_path ,out_video_path
140
+
141
+
142
+ logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
143
+ timer = Timer()